diff --git a/docs/docs/get_started/introduction.mdx b/docs/docs/get_started/introduction.mdx
index d3e42904ff7..2cef4cc5f65 100644
--- a/docs/docs/get_started/introduction.mdx
+++ b/docs/docs/get_started/introduction.mdx
@@ -29,6 +29,11 @@ The main value props of the LangChain packages are:
Off-the-shelf chains make it easy to get started. Components make it easy to customize existing chains and build new ones.
+The LangChain libraries themselves are made up of several different packages.
+- **`langchain-core`**: Base abstractions and LangChain Expression Language.
+- **`langchain-community`**: Third-party integrations.
+- **`langchain`**: Chains, agents, and retrieval strategies that make up an application's cognitive architecture.
+
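+For a rough sense of what lives where, here is an illustrative (not exhaustive) set of imports:
+
+```python
+from langchain_core.prompts import ChatPromptTemplate   # base abstractions / LCEL
+from langchain_community.chat_models import ChatOpenAI  # third-party integration
+from langchain.agents import AgentExecutor              # cognitive architecture
+```
+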
## Get started
[Here's](/docs/get_started/installation) how to install LangChain, set up your environment, and start building.
diff --git a/libs/community/Makefile b/libs/community/Makefile
new file mode 100644
index 00000000000..fe1509fe8dd
--- /dev/null
+++ b/libs/community/Makefile
@@ -0,0 +1,69 @@
+.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
+
+# Default target executed when no arguments are given to make.
+all: help
+
+# Define a variable for the test file path.
+TEST_FILE ?= tests/unit_tests/
+
+test:
+ poetry run pytest $(TEST_FILE)
+
+tests:
+ poetry run pytest $(TEST_FILE)
+
+test_watch:
+ poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
+
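+# Sanity check: load every Python file in the package individually so that
+# import-time errors (e.g. a hard import of an optional dependency) surface early.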
+check_imports: langchain_community/**/*.py
+ for f in $^ ; do \
+ python -c "from importlib.machinery import SourceFileLoader; SourceFileLoader('x', '$$f').load_module()" || exit 1; \
+ done
+
+extended_tests:
+ poetry run pytest --only-extended tests/unit_tests
+
+
+######################
+# LINTING AND FORMATTING
+######################
+
+# Define a variable for Python and notebook files.
+PYTHON_FILES=.
+MYPY_CACHE=.mypy_cache
+lint format: PYTHON_FILES=.
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/community --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
+lint_package: PYTHON_FILES=langchain_community
+lint_tests: PYTHON_FILES=tests
+lint_tests: MYPY_CACHE=.mypy_cache_test
+
+lint lint_diff lint_package lint_tests:
+ ./scripts/check_pydantic.sh .
+ ./scripts/check_imports.sh
+ poetry run ruff .
+ [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
+ [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || ( mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) )
+
+format format_diff:
+ poetry run ruff format $(PYTHON_FILES)
+ poetry run ruff --select I --fix $(PYTHON_FILES)
+
+spell_check:
+ poetry run codespell --toml pyproject.toml
+
+spell_fix:
+ poetry run codespell --toml pyproject.toml -w
+
+######################
+# HELP
+######################
+
+help:
+ @echo '----'
+ @echo 'format - run code formatters'
+ @echo 'lint - run linters'
+ @echo 'test - run unit tests'
+ @echo 'tests - run unit tests'
+	@echo 'test TEST_FILE=<test_file> - run all tests in file'
+ @echo 'test_watch - run unit tests in watch mode'
diff --git a/libs/community/README.md b/libs/community/README.md
new file mode 100644
index 00000000000..d0f8c668bdc
--- /dev/null
+++ b/libs/community/README.md
@@ -0,0 +1 @@
+# langchain-community
\ No newline at end of file
diff --git a/libs/community/langchain_community/__init__.py b/libs/community/langchain_community/__init__.py
new file mode 100644
index 00000000000..ac7aeef6faf
--- /dev/null
+++ b/libs/community/langchain_community/__init__.py
@@ -0,0 +1,9 @@
+"""Main entrypoint into package."""
+from importlib import metadata
+
+try:
+ __version__ = metadata.version(__package__)
+except metadata.PackageNotFoundError:
+ # Case where package metadata is not available.
+ __version__ = ""
+del metadata # optional, avoids polluting the results of dir(__package__)
diff --git a/libs/langchain/tests/integration_tests/adapters/__init__.py b/libs/community/langchain_community/adapters/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/adapters/__init__.py
rename to libs/community/langchain_community/adapters/__init__.py
diff --git a/libs/community/langchain_community/adapters/openai.py b/libs/community/langchain_community/adapters/openai.py
new file mode 100644
index 00000000000..0af759ebf5b
--- /dev/null
+++ b/libs/community/langchain_community/adapters/openai.py
@@ -0,0 +1,399 @@
+from __future__ import annotations
+
+import importlib
+from typing import (
+ Any,
+ AsyncIterator,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Sequence,
+ Union,
+ overload,
+)
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+ ToolMessage,
+)
+from langchain_core.pydantic_v1 import BaseModel
+from typing_extensions import Literal
+
+
+async def aenumerate(
+ iterable: AsyncIterator[Any], start: int = 0
+) -> AsyncIterator[tuple[int, Any]]:
+ """Async version of enumerate function."""
+ i = start
+ async for x in iterable:
+ yield i, x
+ i += 1
+
+
+class IndexableBaseModel(BaseModel):
+ """Allows a BaseModel to return its fields by string variable indexing"""
+
+ def __getitem__(self, item: str) -> Any:
+ return getattr(self, item)
+
+
+class Choice(IndexableBaseModel):
+ message: dict
+
+
+class ChatCompletions(IndexableBaseModel):
+ choices: List[Choice]
+
+
+class ChoiceChunk(IndexableBaseModel):
+ delta: dict
+
+
+class ChatCompletionChunk(IndexableBaseModel):
+ choices: List[ChoiceChunk]
+
+
+def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ """Convert a dictionary to a LangChain message.
+
+ Args:
+ _dict: The dictionary.
+
+ Returns:
+ The LangChain message.
+ """
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ # Fix for azure
+ # Also OpenAI returns None for tool invocations
+ content = _dict.get("content", "") or ""
+ additional_kwargs: Dict = {}
+ if _dict.get("function_call"):
+ additional_kwargs["function_call"] = dict(_dict["function_call"])
+ if _dict.get("tool_calls"):
+ additional_kwargs["tool_calls"] = _dict["tool_calls"]
+ return AIMessage(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system":
+ return SystemMessage(content=_dict["content"])
+ elif role == "function":
+ return FunctionMessage(content=_dict["content"], name=_dict["name"])
+ elif role == "tool":
+ return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+def convert_message_to_dict(message: BaseMessage) -> dict:
+ """Convert a LangChain message to a dictionary.
+
+ Args:
+ message: The LangChain message.
+
+ Returns:
+ The dictionary.
+ """
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ if "function_call" in message.additional_kwargs:
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
+ # If function call only, content is None not empty string
+ if message_dict["content"] == "":
+ message_dict["content"] = None
+ if "tool_calls" in message.additional_kwargs:
+ message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
+ # If tool calls only, content is None not empty string
+ if message_dict["content"] == "":
+ message_dict["content"] = None
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.name,
+ }
+ elif isinstance(message, ToolMessage):
+ message_dict = {
+ "role": "tool",
+ "content": message.content,
+ "tool_call_id": message.tool_call_id,
+ }
+ else:
+ raise TypeError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
+
+
+def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
+ """Convert dictionaries representing OpenAI messages to LangChain format.
+
+ Args:
+ messages: List of dictionaries representing OpenAI messages
+
+ Returns:
+ List of LangChain BaseMessage objects.
+ """
+ return [convert_dict_to_message(m) for m in messages]
+
+
+def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
+ _dict: Dict[str, Any] = {}
+ if isinstance(chunk, AIMessageChunk):
+ if i == 0:
+ # Only shows up in the first chunk
+ _dict["role"] = "assistant"
+ if "function_call" in chunk.additional_kwargs:
+ _dict["function_call"] = chunk.additional_kwargs["function_call"]
+            # If the first chunk is a function call, the content must be None,
+            # not an empty string and not a missing key.
+ if i == 0:
+ _dict["content"] = None
+ else:
+ _dict["content"] = chunk.content
+ else:
+ raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
+ # This only happens at the end of streams, and OpenAI returns as empty dict
+ if _dict == {"content": ""}:
+ _dict = {}
+ return _dict
+
+
+def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
+ _dict = _convert_message_chunk(chunk, i)
+ return {"choices": [{"delta": _dict}]}
+
+
+class ChatCompletion:
+ """Chat completion."""
+
+ @overload
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[False] = False,
+ **kwargs: Any,
+ ) -> dict:
+ ...
+
+ @overload
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[True],
+ **kwargs: Any,
+ ) -> Iterable:
+ ...
+
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> Union[dict, Iterable]:
+ models = importlib.import_module("langchain.chat_models")
+ model_cls = getattr(models, provider)
+ model_config = model_cls(**kwargs)
+ converted_messages = convert_openai_messages(messages)
+ if not stream:
+ result = model_config.invoke(converted_messages)
+ return {"choices": [{"message": convert_message_to_dict(result)}]}
+ else:
+ return (
+ _convert_message_chunk_to_delta(c, i)
+ for i, c in enumerate(model_config.stream(converted_messages))
+ )
+
+ @overload
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[False] = False,
+ **kwargs: Any,
+ ) -> dict:
+ ...
+
+ @overload
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[True],
+ **kwargs: Any,
+ ) -> AsyncIterator:
+ ...
+
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> Union[dict, AsyncIterator]:
+ models = importlib.import_module("langchain.chat_models")
+ model_cls = getattr(models, provider)
+ model_config = model_cls(**kwargs)
+ converted_messages = convert_openai_messages(messages)
+ if not stream:
+ result = await model_config.ainvoke(converted_messages)
+ return {"choices": [{"message": convert_message_to_dict(result)}]}
+ else:
+ return (
+ _convert_message_chunk_to_delta(c, i)
+ async for i, c in aenumerate(model_config.astream(converted_messages))
+ )
+
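+# Usage sketch for the ChatCompletion adapter above (illustrative; assumes OpenAI
+# credentials are configured for the default ``ChatOpenAI`` provider):
+#
+#     from langchain_community.adapters import openai as lc_openai
+#
+#     completion = lc_openai.ChatCompletion.create(
+#         messages=[{"role": "user", "content": "hi"}],
+#         model="gpt-3.5-turbo",
+#         temperature=0,
+#     )
+#     # mirrors the OpenAI response shape:
+#     # {"choices": [{"message": {"role": "assistant", "content": "..."}}]}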
+
+def _has_assistant_message(session: ChatSession) -> bool:
+ """Check if chat session has an assistant message."""
+ return any([isinstance(m, AIMessage) for m in session["messages"]])
+
+
+def convert_messages_for_finetuning(
+ sessions: Iterable[ChatSession],
+) -> List[List[dict]]:
+ """Convert messages to a list of lists of dictionaries for fine-tuning.
+
+ Args:
+ sessions: The chat sessions.
+
+ Returns:
+ The list of lists of dictionaries.
+ """
+ return [
+ [convert_message_to_dict(s) for s in session["messages"]]
+ for session in sessions
+ if _has_assistant_message(session)
+ ]
+
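+# Shape sketch (illustrative): sessions with no assistant message are dropped, and
+# each remaining session becomes a list of OpenAI-style dicts, e.g.
+#     [[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}]]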
+
+class Completions:
+ """Completion."""
+
+ @overload
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[False] = False,
+ **kwargs: Any,
+ ) -> ChatCompletions:
+ ...
+
+ @overload
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[True],
+ **kwargs: Any,
+ ) -> Iterable:
+ ...
+
+ @staticmethod
+ def create(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> Union[ChatCompletions, Iterable]:
+ models = importlib.import_module("langchain.chat_models")
+ model_cls = getattr(models, provider)
+ model_config = model_cls(**kwargs)
+ converted_messages = convert_openai_messages(messages)
+ if not stream:
+ result = model_config.invoke(converted_messages)
+ return ChatCompletions(
+ choices=[Choice(message=convert_message_to_dict(result))]
+ )
+ else:
+ return (
+ ChatCompletionChunk(
+ choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
+ )
+ for i, c in enumerate(model_config.stream(converted_messages))
+ )
+
+ @overload
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[False] = False,
+ **kwargs: Any,
+ ) -> ChatCompletions:
+ ...
+
+ @overload
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: Literal[True],
+ **kwargs: Any,
+ ) -> AsyncIterator:
+ ...
+
+ @staticmethod
+ async def acreate(
+ messages: Sequence[Dict[str, Any]],
+ *,
+ provider: str = "ChatOpenAI",
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> Union[ChatCompletions, AsyncIterator]:
+ models = importlib.import_module("langchain.chat_models")
+ model_cls = getattr(models, provider)
+ model_config = model_cls(**kwargs)
+ converted_messages = convert_openai_messages(messages)
+ if not stream:
+ result = await model_config.ainvoke(converted_messages)
+ return ChatCompletions(
+ choices=[Choice(message=convert_message_to_dict(result))]
+ )
+ else:
+ return (
+ ChatCompletionChunk(
+ choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
+ )
+ async for i, c in aenumerate(model_config.astream(converted_messages))
+ )
+
+
+class Chat:
+ def __init__(self) -> None:
+ self.completions = Completions()
+
+
+chat = Chat()
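+
+# ``chat`` mimics the ``openai>=1`` client surface (illustrative sketch; same
+# assumptions as ``ChatCompletion.create`` above):
+#
+#     from langchain_community.adapters.openai import chat
+#
+#     chat.completions.create(
+#         messages=[{"role": "user", "content": "hi"}], model="gpt-3.5-turbo"
+#     )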
diff --git a/libs/community/langchain_community/agent_toolkits/__init__.py b/libs/community/langchain_community/agent_toolkits/__init__.py
new file mode 100644
index 00000000000..9501eb5db14
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/__init__.py
@@ -0,0 +1,78 @@
+"""Agent toolkits contain integrations with various resources and services.
+
+LangChain has a large ecosystem of integrations with various external resources
+like local and remote file systems, APIs and databases.
+
+These integrations allow developers to create versatile applications that combine the
+power of LLMs with the ability to access, interact with and manipulate external
+resources.
+
+When developing an application, developers should inspect the capabilities and
+permissions of the tools that underlie the given agent toolkit, and determine
+whether permissions of the given toolkit are appropriate for the application.
+
+See [Security](https://python.langchain.com/docs/security) for more information.
+"""
+from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
+from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit
+from langchain_community.agent_toolkits.azure_cognitive_services import (
+ AzureCognitiveServicesToolkit,
+)
+from langchain_community.agent_toolkits.conversational_retrieval.openai_functions import ( # noqa: E501
+ create_conversational_retrieval_agent,
+)
+from langchain_community.agent_toolkits.file_management.toolkit import (
+ FileManagementToolkit,
+)
+from langchain_community.agent_toolkits.gmail.toolkit import GmailToolkit
+from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
+from langchain_community.agent_toolkits.json.base import create_json_agent
+from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
+from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit
+from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit
+from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit
+from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit
+from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
+from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
+from langchain_community.agent_toolkits.playwright.toolkit import (
+ PlayWrightBrowserToolkit,
+)
+from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent
+from langchain_community.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
+from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
+from langchain_community.agent_toolkits.slack.toolkit import SlackToolkit
+from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent
+from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
+from langchain_community.agent_toolkits.sql.base import create_sql_agent
+from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
+from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit
+from langchain_community.agent_toolkits.zapier.toolkit import ZapierToolkit
+
+__all__ = [
+ "AINetworkToolkit",
+ "AmadeusToolkit",
+ "AzureCognitiveServicesToolkit",
+ "FileManagementToolkit",
+ "GmailToolkit",
+ "JiraToolkit",
+ "JsonToolkit",
+ "MultionToolkit",
+ "NasaToolkit",
+ "NLAToolkit",
+ "O365Toolkit",
+ "OpenAPIToolkit",
+ "PlayWrightBrowserToolkit",
+ "PowerBIToolkit",
+ "SlackToolkit",
+ "SteamToolkit",
+ "SQLDatabaseToolkit",
+ "SparkSQLToolkit",
+ "ZapierToolkit",
+ "create_json_agent",
+ "create_openapi_agent",
+ "create_pbi_agent",
+ "create_pbi_chat_agent",
+ "create_spark_sql_agent",
+ "create_sql_agent",
+ "create_conversational_retrieval_agent",
+]
diff --git a/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py b/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py
new file mode 100644
index 00000000000..c4295f2ef48
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py
@@ -0,0 +1 @@
+"""AINetwork toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py b/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py
new file mode 100644
index 00000000000..3e0c140e99c
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Literal, Optional
+
+from langchain_core.pydantic_v1 import root_validator
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.ainetwork.app import AINAppOps
+from langchain_community.tools.ainetwork.owner import AINOwnerOps
+from langchain_community.tools.ainetwork.rule import AINRuleOps
+from langchain_community.tools.ainetwork.transfer import AINTransfer
+from langchain_community.tools.ainetwork.utils import authenticate
+from langchain_community.tools.ainetwork.value import AINValueOps
+
+if TYPE_CHECKING:
+ from ain.ain import Ain
+
+
+class AINetworkToolkit(BaseToolkit):
+ """Toolkit for interacting with AINetwork Blockchain.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by reading, creating, updating, or deleting
+ data associated with this service.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ network: Optional[Literal["mainnet", "testnet"]] = "testnet"
+ interface: Optional[Ain] = None
+
+ @root_validator(pre=True)
+ def set_interface(cls, values: dict) -> dict:
+ if not values.get("interface"):
+ values["interface"] = authenticate(network=values.get("network", "testnet"))
+ return values
+
+ class Config:
+ """Pydantic config."""
+
+ validate_all = True
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ AINAppOps(),
+ AINOwnerOps(),
+ AINRuleOps(),
+ AINTransfer(),
+ AINValueOps(),
+ ]
diff --git a/libs/langchain/tests/integration_tests/callbacks/__init__.py b/libs/community/langchain_community/agent_toolkits/amadeus/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/callbacks/__init__.py
rename to libs/community/langchain_community/agent_toolkits/amadeus/__init__.py
diff --git a/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py
new file mode 100644
index 00000000000..90bc5da6476
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport
+from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch
+from langchain_community.tools.amadeus.utils import authenticate
+
+if TYPE_CHECKING:
+ from amadeus import Client
+
+
+class AmadeusToolkit(BaseToolkit):
+ """Toolkit for interacting with Amadeus which offers APIs for travel."""
+
+ client: Client = Field(default_factory=authenticate)
+
+ class Config:
+ """Pydantic config."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ AmadeusClosestAirport(),
+ AmadeusFlightSearch(),
+ ]
diff --git a/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py b/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py
new file mode 100644
index 00000000000..3c2de898ae4
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import sys
+from typing import List
+
+from langchain_core.tools import BaseTool
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsFormRecognizerTool,
+ AzureCogsImageAnalysisTool,
+ AzureCogsSpeech2TextTool,
+ AzureCogsText2SpeechTool,
+ AzureCogsTextAnalyticsHealthTool,
+)
+
+
+class AzureCognitiveServicesToolkit(BaseToolkit):
+ """Toolkit for Azure Cognitive Services."""
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+
+ tools: List[BaseTool] = [
+ AzureCogsFormRecognizerTool(),
+ AzureCogsSpeech2TextTool(),
+ AzureCogsText2SpeechTool(),
+ AzureCogsTextAnalyticsHealthTool(),
+ ]
+
+ # TODO: Remove check once azure-ai-vision supports MacOS.
+ if sys.platform.startswith("linux") or sys.platform.startswith("win"):
+ tools.append(AzureCogsImageAnalysisTool())
+ return tools
diff --git a/libs/community/langchain_community/agent_toolkits/base.py b/libs/community/langchain_community/agent_toolkits/base.py
new file mode 100644
index 00000000000..0315184115e
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/base.py
@@ -0,0 +1,15 @@
+"""Toolkits for agents."""
+from abc import ABC, abstractmethod
+from typing import List
+
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_community.tools import BaseTool
+
+
+class BaseToolkit(BaseModel, ABC):
+ """Base Toolkit representing a collection of related tools."""
+
+ @abstractmethod
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
diff --git a/libs/langchain/tests/integration_tests/chat_models/__init__.py b/libs/community/langchain_community/agent_toolkits/clickup/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/chat_models/__init__.py
rename to libs/community/langchain_community/agent_toolkits/clickup/__init__.py
diff --git a/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py b/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py
new file mode 100644
index 00000000000..23a1c9a76c8
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py
@@ -0,0 +1,108 @@
+from typing import Dict, List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.clickup.prompt import (
+ CLICKUP_FOLDER_CREATE_PROMPT,
+ CLICKUP_GET_ALL_TEAMS_PROMPT,
+ CLICKUP_GET_FOLDERS_PROMPT,
+ CLICKUP_GET_LIST_PROMPT,
+ CLICKUP_GET_SPACES_PROMPT,
+ CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
+ CLICKUP_GET_TASK_PROMPT,
+ CLICKUP_LIST_CREATE_PROMPT,
+ CLICKUP_TASK_CREATE_PROMPT,
+ CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
+ CLICKUP_UPDATE_TASK_PROMPT,
+)
+from langchain_community.tools.clickup.tool import ClickupAction
+from langchain_community.utilities.clickup import ClickupAPIWrapper
+
+
+class ClickupToolkit(BaseToolkit):
+ """Clickup Toolkit.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by reading, creating, updating, or deleting
+ data associated with this service.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_clickup_api_wrapper(
+ cls, clickup_api_wrapper: ClickupAPIWrapper
+ ) -> "ClickupToolkit":
+ operations: List[Dict] = [
+ {
+ "mode": "get_task",
+ "name": "Get task",
+ "description": CLICKUP_GET_TASK_PROMPT,
+ },
+ {
+ "mode": "get_task_attribute",
+ "name": "Get task attribute",
+ "description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
+ },
+ {
+ "mode": "get_teams",
+ "name": "Get Teams",
+ "description": CLICKUP_GET_ALL_TEAMS_PROMPT,
+ },
+ {
+ "mode": "create_task",
+ "name": "Create Task",
+ "description": CLICKUP_TASK_CREATE_PROMPT,
+ },
+ {
+ "mode": "create_list",
+ "name": "Create List",
+ "description": CLICKUP_LIST_CREATE_PROMPT,
+ },
+ {
+ "mode": "create_folder",
+ "name": "Create Folder",
+ "description": CLICKUP_FOLDER_CREATE_PROMPT,
+ },
+ {
+ "mode": "get_list",
+ "name": "Get all lists in the space",
+ "description": CLICKUP_GET_LIST_PROMPT,
+ },
+ {
+ "mode": "get_folders",
+ "name": "Get all folders in the workspace",
+ "description": CLICKUP_GET_FOLDERS_PROMPT,
+ },
+ {
+ "mode": "get_spaces",
+ "name": "Get all spaces in the workspace",
+ "description": CLICKUP_GET_SPACES_PROMPT,
+ },
+ {
+ "mode": "update_task",
+ "name": "Update task",
+ "description": CLICKUP_UPDATE_TASK_PROMPT,
+ },
+ {
+ "mode": "update_task_assignees",
+ "name": "Update task assignees",
+ "description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
+ },
+ ]
+ tools = [
+ ClickupAction(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=clickup_api_wrapper,
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
diff --git a/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py
new file mode 100644
index 00000000000..405ccd39357
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, List, Optional
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.memory import BaseMemory
+from langchain_core.messages import SystemMessage
+from langchain_core.prompts.chat import MessagesPlaceholder
+from langchain_core.tools import BaseTool
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+
+
+def _get_default_system_message() -> SystemMessage:
+ return SystemMessage(
+ content=(
+ "Do your best to answer the questions. "
+ "Feel free to use any tools available to look up "
+ "relevant information, only if necessary"
+ )
+ )
+
+
+def create_conversational_retrieval_agent(
+ llm: BaseLanguageModel,
+ tools: List[BaseTool],
+ remember_intermediate_steps: bool = True,
+ memory_key: str = "chat_history",
+ system_message: Optional[SystemMessage] = None,
+ verbose: bool = False,
+ max_token_limit: int = 2000,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """A convenience method for creating a conversational retrieval agent.
+
+ Args:
+ llm: The language model to use, should be ChatOpenAI
+ tools: A list of tools the agent has access to
+ remember_intermediate_steps: Whether the agent should remember intermediate
+ steps or not. Intermediate steps refer to prior action/observation
+ pairs from previous questions. The benefit of remembering these is if
+ there is relevant information in there, the agent can use it to answer
+ follow up questions. The downside is it will take up more tokens.
+ memory_key: The name of the memory key in the prompt.
+ system_message: The system message to use. By default, a basic one will
+ be used.
+        verbose: Whether the final AgentExecutor should be verbose.
+            Defaults to False.
+ max_token_limit: The max number of tokens to keep around in memory.
+ Defaults to 2000.
+
+ Returns:
+ An agent executor initialized appropriately
+ """
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
+ AgentTokenBufferMemory,
+ )
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
+ from langchain.memory.token_buffer import ConversationTokenBufferMemory
+
+ if remember_intermediate_steps:
+ memory: BaseMemory = AgentTokenBufferMemory(
+ memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
+ )
+ else:
+ memory = ConversationTokenBufferMemory(
+ memory_key=memory_key,
+ return_messages=True,
+ output_key="output",
+ llm=llm,
+ max_token_limit=max_token_limit,
+ )
+
+ _system_message = system_message or _get_default_system_message()
+ prompt = OpenAIFunctionsAgent.create_prompt(
+ system_message=_system_message,
+ extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+ )
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+ return AgentExecutor(
+ agent=agent,
+ tools=tools,
+ memory=memory,
+ verbose=verbose,
+ return_intermediate_steps=remember_intermediate_steps,
+ **kwargs,
+ )
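+
+
+# Usage sketch (illustrative; assumes a function-calling chat model such as
+# ChatOpenAI and a list of ``tools``):
+#
+#     from langchain_community.chat_models import ChatOpenAI
+#
+#     llm = ChatOpenAI(temperature=0)
+#     executor = create_conversational_retrieval_agent(llm, tools)
+#     executor.invoke({"input": "hi"})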
diff --git a/libs/community/langchain_community/agent_toolkits/csv/__init__.py b/libs/community/langchain_community/agent_toolkits/csv/__init__.py
new file mode 100644
index 00000000000..4b049802888
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/csv/__init__.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+from typing import Any
+
+from langchain_core._api.path import as_import_path
+
+
+def __getattr__(name: str) -> Any:
+ """Get attr name."""
+
+ if name == "create_csv_agent":
+ # Get directory of langchain package
+ HERE = Path(__file__).parents[3]
+ here = as_import_path(Path(__file__).parent, relative_to=HERE)
+
+ old_path = "langchain." + here + "." + name
+ new_path = "langchain_experimental." + here + "." + name
+ raise ImportError(
+ "This agent has been moved to langchain experiment. "
+ "This agent relies on python REPL tool under the hood, so to use it "
+ "safely please sandbox the python REPL. "
+ "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
+ "and https://github.com/langchain-ai/langchain/discussions/11680"
+ "To keep using this code as is, install langchain experimental and "
+ f"update your import statement from:\n `{old_path}` to `{new_path}`."
+ )
+ raise AttributeError(f"{name} does not exist")
diff --git a/libs/community/langchain_community/agent_toolkits/file_management/__init__.py b/libs/community/langchain_community/agent_toolkits/file_management/__init__.py
new file mode 100644
index 00000000000..53ce9329f91
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/file_management/__init__.py
@@ -0,0 +1,7 @@
+"""Local file management toolkit."""
+
+from langchain_community.agent_toolkits.file_management.toolkit import (
+ FileManagementToolkit,
+)
+
+__all__ = ["FileManagementToolkit"]
diff --git a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py
new file mode 100644
index 00000000000..538569755b5
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import List, Optional
+
+from langchain_core.pydantic_v1 import root_validator
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.file_management.copy import CopyFileTool
+from langchain_community.tools.file_management.delete import DeleteFileTool
+from langchain_community.tools.file_management.file_search import FileSearchTool
+from langchain_community.tools.file_management.list_dir import ListDirectoryTool
+from langchain_community.tools.file_management.move import MoveFileTool
+from langchain_community.tools.file_management.read import ReadFileTool
+from langchain_community.tools.file_management.write import WriteFileTool
+
+_FILE_TOOLS = {
+ # "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
+ tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
+ for tool_cls in [
+ CopyFileTool,
+ DeleteFileTool,
+ FileSearchTool,
+ MoveFileTool,
+ ReadFileTool,
+ WriteFileTool,
+ ListDirectoryTool,
+ ]
+}
+
+
+class FileManagementToolkit(BaseToolkit):
+ """Toolkit for interacting with local files.
+
+ *Security Notice*: This toolkit provides methods to interact with local files.
+ If providing this toolkit to an agent on an LLM, ensure you scope
+ the agent's permissions to only include the necessary permissions
+ to perform the desired operations.
+
+ By **default** the agent will have access to all files within
+ the root dir and will be able to Copy, Delete, Move, Read, Write
+ and List files in that directory.
+
+ Consider the following:
+ - Limit access to particular directories using `root_dir`.
+ - Use filesystem permissions to restrict access and permissions to only
+ the files and directories required by the agent.
+ - Limit the tools available to the agent to only the file operations
+ necessary for the agent's intended use.
+ - Sandbox the agent by running it in a container.
+
+ See https://python.langchain.com/docs/security for more information.
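+
+    Example (a minimal sketch; the selected tool names must match the defaults
+    defined on the corresponding tool classes):
+        .. code-block:: python
+
+            from langchain_community.agent_toolkits import FileManagementToolkit
+
+            toolkit = FileManagementToolkit(
+                root_dir="/tmp/sandbox",  # scope file operations to this directory
+                selected_tools=["read_file", "write_file", "list_directory"],
+            )
+            tools = toolkit.get_tools()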
+ """
+
+ root_dir: Optional[str] = None
+ """If specified, all file operations are made relative to root_dir."""
+ selected_tools: Optional[List[str]] = None
+ """If provided, only provide the selected tools. Defaults to all."""
+
+ @root_validator
+ def validate_tools(cls, values: dict) -> dict:
+ selected_tools = values.get("selected_tools") or []
+ for tool_name in selected_tools:
+ if tool_name not in _FILE_TOOLS:
+ raise ValueError(
+ f"File Tool of name {tool_name} not supported."
+ f" Permitted tools: {list(_FILE_TOOLS)}"
+ )
+ return values
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
+ tools: List[BaseTool] = []
+ for tool in allowed_tools:
+ tool_cls = _FILE_TOOLS[tool]
+ tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
+ return tools
+
+
+__all__ = ["FileManagementToolkit"]
diff --git a/libs/community/langchain_community/agent_toolkits/github/__init__.py b/libs/community/langchain_community/agent_toolkits/github/__init__.py
new file mode 100644
index 00000000000..bcd9368a52a
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/github/__init__.py
@@ -0,0 +1 @@
+"""GitHub Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/github/toolkit.py b/libs/community/langchain_community/agent_toolkits/github/toolkit.py
new file mode 100644
index 00000000000..7c85504e1c6
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/github/toolkit.py
@@ -0,0 +1,287 @@
+"""GitHub Toolkit."""
+from typing import Dict, List
+
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.github.prompt import (
+ COMMENT_ON_ISSUE_PROMPT,
+ CREATE_BRANCH_PROMPT,
+ CREATE_FILE_PROMPT,
+ CREATE_PULL_REQUEST_PROMPT,
+ CREATE_REVIEW_REQUEST_PROMPT,
+ DELETE_FILE_PROMPT,
+ GET_FILES_FROM_DIRECTORY_PROMPT,
+ GET_ISSUE_PROMPT,
+ GET_ISSUES_PROMPT,
+ GET_PR_PROMPT,
+ LIST_BRANCHES_IN_REPO_PROMPT,
+ LIST_PRS_PROMPT,
+ LIST_PULL_REQUEST_FILES,
+ OVERVIEW_EXISTING_FILES_BOT_BRANCH,
+ OVERVIEW_EXISTING_FILES_IN_MAIN,
+ READ_FILE_PROMPT,
+ SEARCH_CODE_PROMPT,
+ SEARCH_ISSUES_AND_PRS_PROMPT,
+ SET_ACTIVE_BRANCH_PROMPT,
+ UPDATE_FILE_PROMPT,
+)
+from langchain_community.tools.github.tool import GitHubAction
+from langchain_community.utilities.github import GitHubAPIWrapper
+
+
+class NoInput(BaseModel):
+ no_input: str = Field("", description="No input required, e.g. `` (empty string).")
+
+
+class GetIssue(BaseModel):
+ issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`")
+
+
+class CommentOnIssue(BaseModel):
+ input: str = Field(..., description="Follow the required formatting.")
+
+
+class GetPR(BaseModel):
+ pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`")
+
+
+class CreatePR(BaseModel):
+ formatted_pr: str = Field(..., description="Follow the required formatting.")
+
+
+class CreateFile(BaseModel):
+ formatted_file: str = Field(..., description="Follow the required formatting.")
+
+
+class ReadFile(BaseModel):
+ formatted_filepath: str = Field(
+ ...,
+ description=(
+ "The full file path of the file you would like to read where the "
+ "path must NOT start with a slash, e.g. `some_dir/my_file.py`."
+ ),
+ )
+
+
+class UpdateFile(BaseModel):
+ formatted_file_update: str = Field(
+ ..., description="Strictly follow the provided rules."
+ )
+
+
+class DeleteFile(BaseModel):
+ formatted_filepath: str = Field(
+ ...,
+ description=(
+ "The full file path of the file you would like to delete"
+ " where the path must NOT start with a slash, e.g."
+ " `some_dir/my_file.py`. Only input a string,"
+ " not the param name."
+ ),
+ )
+
+
+class DirectoryPath(BaseModel):
+ input: str = Field(
+ "",
+ description=(
+ "The path of the directory, e.g. `some_dir/inner_dir`."
+ " Only input a string, do not include the parameter name."
+ ),
+ )
+
+
+class BranchName(BaseModel):
+ branch_name: str = Field(
+ ..., description="The name of the branch, e.g. `my_branch`."
+ )
+
+
+class SearchCode(BaseModel):
+ search_query: str = Field(
+ ...,
+ description=(
+ "A keyword-focused natural language search"
+ "query for code, e.g. `MyFunctionName()`."
+ ),
+ )
+
+
+class CreateReviewRequest(BaseModel):
+ username: str = Field(
+ ...,
+ description="GitHub username of the user being requested, e.g. `my_username`.",
+ )
+
+
+class SearchIssuesAndPRs(BaseModel):
+ search_query: str = Field(
+ ...,
+ description="Natural language search query, e.g. `My issue title or topic`.",
+ )
+
+
+class GitHubToolkit(BaseToolkit):
+ """GitHub Toolkit.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by creating, deleting, updating, or
+ reading underlying data.
+
+ For example, this toolkit can be used to create issues, pull requests,
+ and comments on GitHub.
+
+ See [Security](https://python.langchain.com/docs/security) for more information.
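+
+    Example (a minimal sketch; assumes GitHub credentials are configured in the
+    environment for ``GitHubAPIWrapper``):
+        .. code-block:: python
+
+            from langchain_community.agent_toolkits.github.toolkit import (
+                GitHubToolkit,
+            )
+            from langchain_community.utilities.github import GitHubAPIWrapper
+
+            toolkit = GitHubToolkit.from_github_api_wrapper(GitHubAPIWrapper())
+            tools = toolkit.get_tools()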
+ """
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_github_api_wrapper(
+ cls, github_api_wrapper: GitHubAPIWrapper
+ ) -> "GitHubToolkit":
+ operations: List[Dict] = [
+ {
+ "mode": "get_issues",
+ "name": "Get Issues",
+ "description": GET_ISSUES_PROMPT,
+ "args_schema": NoInput,
+ },
+ {
+ "mode": "get_issue",
+ "name": "Get Issue",
+ "description": GET_ISSUE_PROMPT,
+ "args_schema": GetIssue,
+ },
+ {
+ "mode": "comment_on_issue",
+ "name": "Comment on Issue",
+ "description": COMMENT_ON_ISSUE_PROMPT,
+ "args_schema": CommentOnIssue,
+ },
+ {
+ "mode": "list_open_pull_requests",
+ "name": "List open pull requests (PRs)",
+ "description": LIST_PRS_PROMPT,
+ "args_schema": NoInput,
+ },
+ {
+ "mode": "get_pull_request",
+ "name": "Get Pull Request",
+ "description": GET_PR_PROMPT,
+ "args_schema": GetPR,
+ },
+ {
+ "mode": "list_pull_request_files",
+ "name": "Overview of files included in PR",
+ "description": LIST_PULL_REQUEST_FILES,
+ "args_schema": GetPR,
+ },
+ {
+ "mode": "create_pull_request",
+ "name": "Create Pull Request",
+ "description": CREATE_PULL_REQUEST_PROMPT,
+ "args_schema": CreatePR,
+ },
+ {
+ "mode": "list_pull_request_files",
+ "name": "List Pull Requests' Files",
+ "description": LIST_PULL_REQUEST_FILES,
+ "args_schema": GetPR,
+ },
+ {
+ "mode": "create_file",
+ "name": "Create File",
+ "description": CREATE_FILE_PROMPT,
+ "args_schema": CreateFile,
+ },
+ {
+ "mode": "read_file",
+ "name": "Read File",
+ "description": READ_FILE_PROMPT,
+ "args_schema": ReadFile,
+ },
+ {
+ "mode": "update_file",
+ "name": "Update File",
+ "description": UPDATE_FILE_PROMPT,
+ "args_schema": UpdateFile,
+ },
+ {
+ "mode": "delete_file",
+ "name": "Delete File",
+ "description": DELETE_FILE_PROMPT,
+ "args_schema": DeleteFile,
+ },
+ {
+ "mode": "list_files_in_main_branch",
+ "name": "Overview of existing files in Main branch",
+ "description": OVERVIEW_EXISTING_FILES_IN_MAIN,
+ "args_schema": NoInput,
+ },
+ {
+ "mode": "list_files_in_bot_branch",
+ "name": "Overview of files in current working branch",
+ "description": OVERVIEW_EXISTING_FILES_BOT_BRANCH,
+ "args_schema": NoInput,
+ },
+ {
+ "mode": "list_branches_in_repo",
+ "name": "List branches in this repository",
+ "description": LIST_BRANCHES_IN_REPO_PROMPT,
+ "args_schema": NoInput,
+ },
+ {
+ "mode": "set_active_branch",
+ "name": "Set active branch",
+ "description": SET_ACTIVE_BRANCH_PROMPT,
+ "args_schema": BranchName,
+ },
+ {
+ "mode": "create_branch",
+ "name": "Create a new branch",
+ "description": CREATE_BRANCH_PROMPT,
+ "args_schema": BranchName,
+ },
+ {
+ "mode": "get_files_from_directory",
+ "name": "Get files from a directory",
+ "description": GET_FILES_FROM_DIRECTORY_PROMPT,
+ "args_schema": DirectoryPath,
+ },
+ {
+ "mode": "search_issues_and_prs",
+ "name": "Search issues and pull requests",
+ "description": SEARCH_ISSUES_AND_PRS_PROMPT,
+ "args_schema": SearchIssuesAndPRs,
+ },
+ {
+ "mode": "search_code",
+ "name": "Search code",
+ "description": SEARCH_CODE_PROMPT,
+ "args_schema": SearchCode,
+ },
+ {
+ "mode": "create_review_request",
+ "name": "Create review request",
+ "description": CREATE_REVIEW_REQUEST_PROMPT,
+ "args_schema": CreateReviewRequest,
+ },
+ ]
+ tools = [
+ GitHubAction(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=github_api_wrapper,
+ args_schema=action.get("args_schema", None),
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
diff --git a/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py b/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py
new file mode 100644
index 00000000000..7d3ca720636
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py
@@ -0,0 +1 @@
+"""GitLab Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py b/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py
new file mode 100644
index 00000000000..1a56b821f5b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py
@@ -0,0 +1,94 @@
+"""GitHub Toolkit."""
+from typing import Dict, List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.gitlab.prompt import (
+ COMMENT_ON_ISSUE_PROMPT,
+ CREATE_FILE_PROMPT,
+ CREATE_PULL_REQUEST_PROMPT,
+ DELETE_FILE_PROMPT,
+ GET_ISSUE_PROMPT,
+ GET_ISSUES_PROMPT,
+ READ_FILE_PROMPT,
+ UPDATE_FILE_PROMPT,
+)
+from langchain_community.tools.gitlab.tool import GitLabAction
+from langchain_community.utilities.gitlab import GitLabAPIWrapper
+
+
+class GitLabToolkit(BaseToolkit):
+ """GitLab Toolkit.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by creating, deleting, updating, or
+ reading underlying data.
+
+ For example, this toolkit can be used to create issues, pull requests,
+ and comments on GitLab.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_gitlab_api_wrapper(
+ cls, gitlab_api_wrapper: GitLabAPIWrapper
+ ) -> "GitLabToolkit":
+ operations: List[Dict] = [
+ {
+ "mode": "get_issues",
+ "name": "Get Issues",
+ "description": GET_ISSUES_PROMPT,
+ },
+ {
+ "mode": "get_issue",
+ "name": "Get Issue",
+ "description": GET_ISSUE_PROMPT,
+ },
+ {
+ "mode": "comment_on_issue",
+ "name": "Comment on Issue",
+ "description": COMMENT_ON_ISSUE_PROMPT,
+ },
+ {
+ "mode": "create_pull_request",
+ "name": "Create Pull Request",
+ "description": CREATE_PULL_REQUEST_PROMPT,
+ },
+ {
+ "mode": "create_file",
+ "name": "Create File",
+ "description": CREATE_FILE_PROMPT,
+ },
+ {
+ "mode": "read_file",
+ "name": "Read File",
+ "description": READ_FILE_PROMPT,
+ },
+ {
+ "mode": "update_file",
+ "name": "Update File",
+ "description": UPDATE_FILE_PROMPT,
+ },
+ {
+ "mode": "delete_file",
+ "name": "Delete File",
+ "description": DELETE_FILE_PROMPT,
+ },
+ ]
+ tools = [
+ GitLabAction(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=gitlab_api_wrapper,
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
diff --git a/libs/community/langchain_community/agent_toolkits/gmail/__init__.py b/libs/community/langchain_community/agent_toolkits/gmail/__init__.py
new file mode 100644
index 00000000000..02e7f81659f
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/gmail/__init__.py
@@ -0,0 +1 @@
+"""Gmail toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py b/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py
new file mode 100644
index 00000000000..1e4af3fd614
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.gmail.create_draft import GmailCreateDraft
+from langchain_community.tools.gmail.get_message import GmailGetMessage
+from langchain_community.tools.gmail.get_thread import GmailGetThread
+from langchain_community.tools.gmail.search import GmailSearch
+from langchain_community.tools.gmail.send_message import GmailSendMessage
+from langchain_community.tools.gmail.utils import build_resource_service
+
+if TYPE_CHECKING:
+ # This is for linting and IDE typehints
+ from googleapiclient.discovery import Resource
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ from googleapiclient.discovery import Resource
+ except ImportError:
+ pass
+
+
+SCOPES = ["https://mail.google.com/"]
+
+
+class GmailToolkit(BaseToolkit):
+ """Toolkit for interacting with Gmail.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by reading, creating, updating, or deleting
+ data associated with this service.
+
+ For example, this toolkit can be used to send emails on behalf of the
+ associated account.
+
+ See https://python.langchain.com/docs/security for more information.
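+
+    Example (a minimal sketch; assumes Gmail API credentials are available to
+    ``build_resource_service``, the default factory for ``api_resource``):
+        .. code-block:: python
+
+            from langchain_community.agent_toolkits import GmailToolkit
+
+            toolkit = GmailToolkit()
+            tools = toolkit.get_tools()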
+ """
+
+ api_resource: Resource = Field(default_factory=build_resource_service)
+
+ class Config:
+ """Pydantic config."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ GmailCreateDraft(api_resource=self.api_resource),
+ GmailSendMessage(api_resource=self.api_resource),
+ GmailSearch(api_resource=self.api_resource),
+ GmailGetMessage(api_resource=self.api_resource),
+ GmailGetThread(api_resource=self.api_resource),
+ ]
diff --git a/libs/community/langchain_community/agent_toolkits/jira/__init__.py b/libs/community/langchain_community/agent_toolkits/jira/__init__.py
new file mode 100644
index 00000000000..9f7c67558fa
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/jira/__init__.py
@@ -0,0 +1 @@
+"""Jira Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/jira/toolkit.py b/libs/community/langchain_community/agent_toolkits/jira/toolkit.py
new file mode 100644
index 00000000000..425c8d4c065
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/jira/toolkit.py
@@ -0,0 +1,70 @@
+from typing import Dict, List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.jira.prompt import (
+ JIRA_CATCH_ALL_PROMPT,
+ JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
+ JIRA_GET_ALL_PROJECTS_PROMPT,
+ JIRA_ISSUE_CREATE_PROMPT,
+ JIRA_JQL_PROMPT,
+)
+from langchain_community.tools.jira.tool import JiraAction
+from langchain_community.utilities.jira import JiraAPIWrapper
+
+
+class JiraToolkit(BaseToolkit):
+ """Jira Toolkit.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+    the state of a service; e.g., by creating, deleting, updating, or
+ reading underlying data.
+
+ See https://python.langchain.com/docs/security for more information.
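+
+    Example (a minimal sketch; assumes Jira credentials are configured in the
+    environment for ``JiraAPIWrapper``):
+        .. code-block:: python
+
+            from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
+            from langchain_community.utilities.jira import JiraAPIWrapper
+
+            toolkit = JiraToolkit.from_jira_api_wrapper(JiraAPIWrapper())
+            tools = toolkit.get_tools()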
+ """
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit":
+ operations: List[Dict] = [
+ {
+ "mode": "jql",
+ "name": "JQL Query",
+ "description": JIRA_JQL_PROMPT,
+ },
+ {
+ "mode": "get_projects",
+ "name": "Get Projects",
+ "description": JIRA_GET_ALL_PROJECTS_PROMPT,
+ },
+ {
+ "mode": "create_issue",
+ "name": "Create Issue",
+ "description": JIRA_ISSUE_CREATE_PROMPT,
+ },
+ {
+ "mode": "other",
+ "name": "Catch all Jira API call",
+ "description": JIRA_CATCH_ALL_PROMPT,
+ },
+ {
+ "mode": "create_page",
+ "name": "Create confluence page",
+ "description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
+ },
+ ]
+ tools = [
+ JiraAction(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=jira_api_wrapper,
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
diff --git a/libs/community/langchain_community/agent_toolkits/json/__init__.py b/libs/community/langchain_community/agent_toolkits/json/__init__.py
new file mode 100644
index 00000000000..bfab0ec6f83
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/json/__init__.py
@@ -0,0 +1 @@
+"""Json agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/json/base.py b/libs/community/langchain_community/agent_toolkits/json/base.py
new file mode 100644
index 00000000000..06830d90228
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/json/base.py
@@ -0,0 +1,59 @@
+"""Json agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+
+from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
+from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+
+
+def create_json_agent(
+ llm: BaseLanguageModel,
+ toolkit: JsonToolkit,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = JSON_PREFIX,
+ suffix: str = JSON_SUFFIX,
+ format_instructions: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a json agent from an LLM and tools."""
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = toolkit.get_tools()
+ prompt_params = (
+ {"format_instructions": format_instructions}
+ if format_instructions is not None
+ else {}
+ )
+ prompt = ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix,
+ suffix=suffix,
+ input_variables=input_variables,
+ **prompt_params,
+ )
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/json/prompt.py b/libs/community/langchain_community/agent_toolkits/json/prompt.py
new file mode 100644
index 00000000000..a3b7584aca2
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/json/prompt.py
@@ -0,0 +1,25 @@
+# flake8: noqa
+
+JSON_PREFIX = """You are an agent designed to interact with JSON.
+Your goal is to return a final answer by interacting with the JSON.
+You have access to the following tools which help you learn more about the JSON you are interacting with.
+Only use the below tools. Only use the information returned by the below tools to construct your final answer.
+Do not make up any information that is not contained in the JSON.
+Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python.
+You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`.
+If you have not seen a key in one of those responses, you cannot use it.
+You should only add one key at a time to the path. You cannot add multiple keys at once.
+If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again.
+
+If the question does not seem to be related to the JSON, just return "I don't know" as the answer.
+Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON.
+
+Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly".
+In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path.
+Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it.
+"""
+JSON_SUFFIX = """Begin!"
+
+Question: {input}
+Thought: I should look at the keys that exist in data to see what I have access to
+{agent_scratchpad}"""
diff --git a/libs/community/langchain_community/agent_toolkits/json/toolkit.py b/libs/community/langchain_community/agent_toolkits/json/toolkit.py
new file mode 100644
index 00000000000..8a4aa00a6ea
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/json/toolkit.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.json.tool import (
+ JsonGetValueTool,
+ JsonListKeysTool,
+ JsonSpec,
+)
+
+
+class JsonToolkit(BaseToolkit):
+ """Toolkit for interacting with a JSON spec."""
+
+ spec: JsonSpec
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ JsonListKeysTool(spec=self.spec),
+ JsonGetValueTool(spec=self.spec),
+ ]
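
A minimal usage sketch for the JSON toolkit and agent above, assuming an OpenAI API key is configured in the environment and the `langchain` package is installed (the agent classes are imported from it); the spec dict is illustrative:

from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec

# Wrap a toy spec dict so the agent can list keys and fetch values from it.
spec = JsonSpec(dict_={"openapi": "3.0.0", "info": {"title": "Demo API"}}, max_value_length=4000)
agent = create_json_agent(llm=OpenAI(temperature=0), toolkit=JsonToolkit(spec=spec), verbose=True)
agent.run("What is the title of this API?")
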
diff --git a/libs/community/langchain_community/agent_toolkits/multion/__init__.py b/libs/community/langchain_community/agent_toolkits/multion/__init__.py
new file mode 100644
index 00000000000..56c7215199b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/multion/__init__.py
@@ -0,0 +1 @@
+"""MultiOn Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/multion/toolkit.py b/libs/community/langchain_community/agent_toolkits/multion/toolkit.py
new file mode 100644
index 00000000000..1d8f9029112
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/multion/toolkit.py
@@ -0,0 +1,33 @@
+"""MultiOn agent."""
+from __future__ import annotations
+
+from typing import List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.multion.close_session import MultionCloseSession
+from langchain_community.tools.multion.create_session import MultionCreateSession
+from langchain_community.tools.multion.update_session import MultionUpdateSession
+
+
+class MultionToolkit(BaseToolkit):
+ """Toolkit for interacting with the Browser Agent.
+
+ **Security Note**: This toolkit contains tools that interact with the
+ user's browser via the multion API which grants an agent
+ access to the user's browser.
+
+ Please review the documentation for the multion API to understand
+ the security implications of using this toolkit.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ class Config:
+ """Pydantic config."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()]
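
A short sketch of the toolkit above, assuming the `multion` client package is installed and the user has logged in with it:

from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit

toolkit = MultionToolkit()
# The three session tools wrap create/update/close calls against the MultiOn API.
print([tool.name for tool in toolkit.get_tools()])
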
diff --git a/libs/community/langchain_community/agent_toolkits/nasa/__init__.py b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py
new file mode 100644
index 00000000000..a13c3ec706c
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py
@@ -0,0 +1 @@
+"""NASA Toolkit"""
diff --git a/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py
new file mode 100644
index 00000000000..46edd98af3f
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py
@@ -0,0 +1,57 @@
+from typing import Dict, List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.nasa.prompt import (
+ NASA_CAPTIONS_PROMPT,
+ NASA_MANIFEST_PROMPT,
+ NASA_METADATA_PROMPT,
+ NASA_SEARCH_PROMPT,
+)
+from langchain_community.tools.nasa.tool import NasaAction
+from langchain_community.utilities.nasa import NasaAPIWrapper
+
+
+class NasaToolkit(BaseToolkit):
+ """Nasa Toolkit."""
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_nasa_api_wrapper(cls, nasa_api_wrapper: NasaAPIWrapper) -> "NasaToolkit":
+ operations: List[Dict] = [
+ {
+ "mode": "search_media",
+ "name": "Search NASA Image and Video Library media",
+ "description": NASA_SEARCH_PROMPT,
+ },
+ {
+ "mode": "get_media_metadata_manifest",
+ "name": "Get NASA Image and Video Library media metadata manifest",
+ "description": NASA_MANIFEST_PROMPT,
+ },
+ {
+ "mode": "get_media_metadata_location",
+ "name": "Get NASA Image and Video Library media metadata location",
+ "description": NASA_METADATA_PROMPT,
+ },
+ {
+ "mode": "get_video_captions_location",
+ "name": "Get NASA Image and Video Library video captions location",
+ "description": NASA_CAPTIONS_PROMPT,
+ },
+ ]
+ tools = [
+ NasaAction(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=nasa_api_wrapper,
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
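
A short sketch wiring the NASA wrapper into the toolkit; it assumes the NASA Image and Video Library endpoints are reachable without an API key:

from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit
from langchain_community.utilities.nasa import NasaAPIWrapper

toolkit = NasaToolkit.from_nasa_api_wrapper(NasaAPIWrapper())
for tool in toolkit.get_tools():
    # One NasaAction per operation mode defined above.
    print(tool.name, "->", tool.mode)
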
diff --git a/libs/langchain/tests/integration_tests/document_loaders/parsers/__init__.py b/libs/community/langchain_community/agent_toolkits/nla/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/document_loaders/parsers/__init__.py
rename to libs/community/langchain_community/agent_toolkits/nla/__init__.py
diff --git a/libs/community/langchain_community/agent_toolkits/nla/tool.py b/libs/community/langchain_community/agent_toolkits/nla/tool.py
new file mode 100644
index 00000000000..b947c7df977
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/nla/tool.py
@@ -0,0 +1,58 @@
+"""Tool for interacting with a single API with natural language definition."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.tools import Tool
+
+from langchain_community.tools.openapi.utils.api_models import APIOperation
+from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
+from langchain_community.utilities.requests import Requests
+
+if TYPE_CHECKING:
+ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
+
+
+class NLATool(Tool):
+ """Natural Language API Tool."""
+
+ @classmethod
+ def from_open_api_endpoint_chain(
+ cls, chain: OpenAPIEndpointChain, api_title: str
+ ) -> "NLATool":
+ """Convert an endpoint chain to an API endpoint tool."""
+ expanded_name = (
+ f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}'
+ )
+ description = (
+ f"I'm an AI from {api_title}. Instruct what you want,"
+ " and I'll assist via an API with description:"
+ f" {chain.api_operation.description}"
+ )
+ return cls(name=expanded_name, func=chain.run, description=description)
+
+ @classmethod
+ def from_llm_and_method(
+ cls,
+ llm: BaseLanguageModel,
+ path: str,
+ method: str,
+ spec: OpenAPISpec,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ return_intermediate_steps: bool = False,
+ **kwargs: Any,
+ ) -> "NLATool":
+ """Instantiate the tool from the specified path and method."""
+ api_operation = APIOperation.from_openapi_spec(spec, path, method)
+ chain = OpenAPIEndpointChain.from_api_operation(
+ api_operation,
+ llm,
+ requests=requests,
+ verbose=verbose,
+ return_intermediate_steps=return_intermediate_steps,
+ **kwargs,
+ )
+ return cls.from_open_api_endpoint_chain(chain, spec.info.title)
diff --git a/libs/community/langchain_community/agent_toolkits/nla/toolkit.py b/libs/community/langchain_community/agent_toolkits/nla/toolkit.py
new file mode 100644
index 00000000000..ae02aeb53bc
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/nla/toolkit.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+from typing import Any, List, Optional, Sequence
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.agent_toolkits.nla.tool import NLATool
+from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
+from langchain_community.tools.plugin import AIPlugin
+from langchain_community.utilities.requests import Requests
+
+
+class NLAToolkit(BaseToolkit):
+ """Natural Language API Toolkit.
+
+ *Security Note*: This toolkit creates tools that enable making calls
+ to an OpenAPI-compliant API.
+
+ The tools created by this toolkit may be able to make GET, POST,
+ PATCH, PUT, DELETE requests to any of the exposed endpoints on
+ the API.
+
+ Control who can use this toolkit.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ nla_tools: Sequence[NLATool] = Field(...)
+ """List of API Endpoint Tools."""
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools for all the API operations."""
+ return list(self.nla_tools)
+
+ @staticmethod
+ def _get_http_operation_tools(
+ llm: BaseLanguageModel,
+ spec: OpenAPISpec,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> List[NLATool]:
+ """Get the tools for all the API operations."""
+ if not spec.paths:
+ return []
+ http_operation_tools = []
+ for path in spec.paths:
+ for method in spec.get_methods_for_path(path):
+ endpoint_tool = NLATool.from_llm_and_method(
+ llm=llm,
+ path=path,
+ method=method,
+ spec=spec,
+ requests=requests,
+ verbose=verbose,
+ **kwargs,
+ )
+ http_operation_tools.append(endpoint_tool)
+ return http_operation_tools
+
+ @classmethod
+ def from_llm_and_spec(
+ cls,
+ llm: BaseLanguageModel,
+ spec: OpenAPISpec,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> NLAToolkit:
+ """Instantiate the toolkit by creating tools for each operation."""
+ http_operation_tools = cls._get_http_operation_tools(
+ llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
+ )
+ return cls(nla_tools=http_operation_tools)
+
+ @classmethod
+ def from_llm_and_url(
+ cls,
+ llm: BaseLanguageModel,
+ open_api_url: str,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> NLAToolkit:
+ """Instantiate the toolkit from an OpenAPI Spec URL"""
+ spec = OpenAPISpec.from_url(open_api_url)
+ return cls.from_llm_and_spec(
+ llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
+ )
+
+ @classmethod
+ def from_llm_and_ai_plugin(
+ cls,
+ llm: BaseLanguageModel,
+ ai_plugin: AIPlugin,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> NLAToolkit:
+ """Instantiate the toolkit from an OpenAPI Spec URL"""
+ spec = OpenAPISpec.from_url(ai_plugin.api.url)
+ # TODO: Merge optional Auth information with the `requests` argument
+ return cls.from_llm_and_spec(
+ llm=llm,
+ spec=spec,
+ requests=requests,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_llm_and_ai_plugin_url(
+ cls,
+ llm: BaseLanguageModel,
+ ai_plugin_url: str,
+ requests: Optional[Requests] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> NLAToolkit:
+ """Instantiate the toolkit from an OpenAPI Spec URL"""
+ plugin = AIPlugin.from_url(ai_plugin_url)
+ return cls.from_llm_and_ai_plugin(
+ llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
+ )
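
A sketch of building the toolkit straight from a spec URL, assuming an OpenAI key is configured and `langchain` is installed (the endpoint chains come from it); the URL is a placeholder:

from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit
from langchain_community.llms import OpenAI

# One NLATool is created per (path, method) pair found in the spec.
toolkit = NLAToolkit.from_llm_and_url(OpenAI(temperature=0), "https://example.com/openapi.json")
print([tool.name for tool in toolkit.get_tools()])
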
diff --git a/libs/community/langchain_community/agent_toolkits/office365/__init__.py b/libs/community/langchain_community/agent_toolkits/office365/__init__.py
new file mode 100644
index 00000000000..acd0a87f955
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/office365/__init__.py
@@ -0,0 +1 @@
+"""Office365 toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/office365/toolkit.py b/libs/community/langchain_community/agent_toolkits/office365/toolkit.py
new file mode 100644
index 00000000000..990264ce1db
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/office365/toolkit.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.office365.create_draft_message import (
+ O365CreateDraftMessage,
+)
+from langchain_community.tools.office365.events_search import O365SearchEvents
+from langchain_community.tools.office365.messages_search import O365SearchEmails
+from langchain_community.tools.office365.send_event import O365SendEvent
+from langchain_community.tools.office365.send_message import O365SendMessage
+from langchain_community.tools.office365.utils import authenticate
+
+if TYPE_CHECKING:
+ from O365 import Account
+
+
+class O365Toolkit(BaseToolkit):
+ """Toolkit for interacting with Office 365.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+ the state of a service; e.g., by reading, creating, updating, deleting
+ data associated with this service.
+
+ For example, this toolkit can be used to search through emails and events,
+ send messages and event invites, and create draft messages.
+
+ Please make sure that the permissions given by this toolkit
+ are appropriate for your use case.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ account: Account = Field(default_factory=authenticate)
+
+ class Config:
+ """Pydantic config."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ O365SearchEvents(),
+ O365CreateDraftMessage(),
+ O365SearchEmails(),
+ O365SendEvent(),
+ O365SendMessage(),
+ ]
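
A short sketch, assuming O365 client credentials are configured the way the `authenticate` helper expects (e.g., CLIENT_ID and CLIENT_SECRET environment variables):

from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit

toolkit = O365Toolkit()  # triggers authentication via the default account factory
print([tool.name for tool in toolkit.get_tools()])
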
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/__init__.py b/libs/community/langchain_community/agent_toolkits/openapi/__init__.py
new file mode 100644
index 00000000000..5d06e271cd0
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/__init__.py
@@ -0,0 +1 @@
+"""OpenAPI spec agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/base.py b/libs/community/langchain_community/agent_toolkits/openapi/base.py
new file mode 100644
index 00000000000..63a513811ea
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/base.py
@@ -0,0 +1,83 @@
+"""OpenAPI spec agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+
+from langchain_community.agent_toolkits.openapi.prompt import (
+ OPENAPI_PREFIX,
+ OPENAPI_SUFFIX,
+)
+from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+
+
+def create_openapi_agent(
+ llm: BaseLanguageModel,
+ toolkit: OpenAPIToolkit,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = OPENAPI_PREFIX,
+ suffix: str = OPENAPI_SUFFIX,
+ format_instructions: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ max_iterations: Optional[int] = 15,
+ max_execution_time: Optional[float] = None,
+ early_stopping_method: str = "force",
+ verbose: bool = False,
+ return_intermediate_steps: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct an OpenAPI agent from an LLM and tools.
+
+ *Security Note*: When creating an OpenAPI agent, check the permissions
+ and capabilities of the underlying toolkit.
+
+ For example, the default implementation of OpenAPIToolkit
+ uses the RequestsToolkit, which contains tools to make arbitrary
+ network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE).
+
+ Control who can submit requests using this toolkit and
+ what network access it has.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = toolkit.get_tools()
+ prompt_params = (
+ {"format_instructions": format_instructions}
+ if format_instructions is not None
+ else {}
+ )
+ prompt = ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix,
+ suffix=suffix,
+ input_variables=input_variables,
+ **prompt_params,
+ )
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ return_intermediate_steps=return_intermediate_steps,
+ max_iterations=max_iterations,
+ max_execution_time=max_execution_time,
+ early_stopping_method=early_stopping_method,
+ **(agent_executor_kwargs or {}),
+ )
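
An end-to-end sketch of this constructor, assuming an OpenAI key is configured and `openapi.yaml` is a local copy of the target spec (both placeholders); OpenAPIToolkit and TextRequestsWrapper come from the toolkit module added later in this change:

import yaml

from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec
from langchain_community.utilities.requests import TextRequestsWrapper

with open("openapi.yaml") as f:
    raw_spec = yaml.safe_load(f)

llm = OpenAI(temperature=0)
# The toolkit bundles the raw requests tools plus a json_explorer agent over the spec.
toolkit = OpenAPIToolkit.from_llm(llm, JsonSpec(dict_=raw_spec), TextRequestsWrapper())
agent = create_openapi_agent(llm=llm, toolkit=toolkit, verbose=True)
agent.run("What endpoints does this API expose?")
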
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py
new file mode 100644
index 00000000000..e95e011e395
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py
@@ -0,0 +1,374 @@
+"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
+import json
+import re
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional
+
+import yaml
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool, Tool
+
+from langchain_community.agent_toolkits.openapi.planner_prompt import (
+ API_CONTROLLER_PROMPT,
+ API_CONTROLLER_TOOL_DESCRIPTION,
+ API_CONTROLLER_TOOL_NAME,
+ API_ORCHESTRATOR_PROMPT,
+ API_PLANNER_PROMPT,
+ API_PLANNER_TOOL_DESCRIPTION,
+ API_PLANNER_TOOL_NAME,
+ PARSING_DELETE_PROMPT,
+ PARSING_GET_PROMPT,
+ PARSING_PATCH_PROMPT,
+ PARSING_POST_PROMPT,
+ PARSING_PUT_PROMPT,
+ REQUESTS_DELETE_TOOL_DESCRIPTION,
+ REQUESTS_GET_TOOL_DESCRIPTION,
+ REQUESTS_PATCH_TOOL_DESCRIPTION,
+ REQUESTS_POST_TOOL_DESCRIPTION,
+ REQUESTS_PUT_TOOL_DESCRIPTION,
+)
+from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
+from langchain_community.llms import OpenAI
+from langchain_community.tools.requests.tool import BaseRequestsTool
+from langchain_community.utilities.requests import RequestsWrapper
+
+#
+# Requests tools with LLM-instructed extraction of truncated responses.
+#
+# Of course, truncating so bluntly may lose a lot of valuable
+# information in the response.
+# However, the goal for now is to have only a single inference step.
+MAX_RESPONSE_LENGTH = 5000
+"""Maximum length of the response to be returned."""
+
+
+def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
+ from langchain.chains.llm import LLMChain
+
+ return LLMChain(
+ llm=OpenAI(),
+ prompt=prompt,
+ )
+
+
+def _get_default_llm_chain_factory(
+ prompt: BasePromptTemplate,
+) -> Callable[[], Any]:
+ """Returns a default LLMChain factory."""
+ return partial(_get_default_llm_chain, prompt)
+
+
+class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
+ """Requests GET tool with LLM-instructed extraction of truncated responses."""
+
+ name: str = "requests_get"
+ """Tool name."""
+ description = REQUESTS_GET_TOOL_DESCRIPTION
+ """Tool description."""
+ response_length: Optional[int] = MAX_RESPONSE_LENGTH
+ """Maximum length of the response to be returned."""
+ llm_chain: Any = Field(
+ default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
+ )
+ """LLMChain used to extract the response."""
+
+ def _run(self, text: str) -> str:
+ from langchain.output_parsers.json import parse_json_markdown
+
+ try:
+ data = parse_json_markdown(text)
+ except json.JSONDecodeError as e:
+ raise e
+ data_params = data.get("params")
+ response = self.requests_wrapper.get(data["url"], params=data_params)
+ response = response[: self.response_length]
+ return self.llm_chain.predict(
+ response=response, instructions=data["output_instructions"]
+ ).strip()
+
+ async def _arun(self, text: str) -> str:
+ raise NotImplementedError()
+
+
+class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
+ """Requests POST tool with LLM-instructed extraction of truncated responses."""
+
+ name: str = "requests_post"
+ """Tool name."""
+ description = REQUESTS_POST_TOOL_DESCRIPTION
+ """Tool description."""
+ response_length: Optional[int] = MAX_RESPONSE_LENGTH
+ """Maximum length of the response to be returned."""
+ llm_chain: Any = Field(
+ default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
+ )
+ """LLMChain used to extract the response."""
+
+ def _run(self, text: str) -> str:
+ from langchain.output_parsers.json import parse_json_markdown
+
+ try:
+ data = parse_json_markdown(text)
+ except json.JSONDecodeError as e:
+ raise e
+ response = self.requests_wrapper.post(data["url"], data["data"])
+ response = response[: self.response_length]
+ return self.llm_chain.predict(
+ response=response, instructions=data["output_instructions"]
+ ).strip()
+
+ async def _arun(self, text: str) -> str:
+ raise NotImplementedError()
+
+
+class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
+ """Requests PATCH tool with LLM-instructed extraction of truncated responses."""
+
+ name: str = "requests_patch"
+ """Tool name."""
+ description = REQUESTS_PATCH_TOOL_DESCRIPTION
+ """Tool description."""
+ response_length: Optional[int] = MAX_RESPONSE_LENGTH
+ """Maximum length of the response to be returned."""
+ llm_chain: Any = Field(
+ default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
+ )
+ """LLMChain used to extract the response."""
+
+ def _run(self, text: str) -> str:
+ from langchain.output_parsers.json import parse_json_markdown
+
+ try:
+ data = parse_json_markdown(text)
+ except json.JSONDecodeError as e:
+ raise e
+ response = self.requests_wrapper.patch(data["url"], data["data"])
+ response = response[: self.response_length]
+ return self.llm_chain.predict(
+ response=response, instructions=data["output_instructions"]
+ ).strip()
+
+ async def _arun(self, text: str) -> str:
+ raise NotImplementedError()
+
+
+class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
+ """Requests PUT tool with LLM-instructed extraction of truncated responses."""
+
+ name: str = "requests_put"
+ """Tool name."""
+ description = REQUESTS_PUT_TOOL_DESCRIPTION
+ """Tool description."""
+ response_length: Optional[int] = MAX_RESPONSE_LENGTH
+ """Maximum length of the response to be returned."""
+ llm_chain: Any = Field(
+ default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
+ )
+ """LLMChain used to extract the response."""
+
+ def _run(self, text: str) -> str:
+ from langchain.output_parsers.json import parse_json_markdown
+
+ try:
+ data = parse_json_markdown(text)
+ except json.JSONDecodeError as e:
+ raise e
+ response = self.requests_wrapper.put(data["url"], data["data"])
+ response = response[: self.response_length]
+ return self.llm_chain.predict(
+ response=response, instructions=data["output_instructions"]
+ ).strip()
+
+ async def _arun(self, text: str) -> str:
+ raise NotImplementedError()
+
+
+class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
+ """A tool that sends a DELETE request and parses the response."""
+
+ name: str = "requests_delete"
+ """The name of the tool."""
+ description = REQUESTS_DELETE_TOOL_DESCRIPTION
+ """The description of the tool."""
+
+ response_length: Optional[int] = MAX_RESPONSE_LENGTH
+ """The maximum length of the response."""
+ llm_chain: Any = Field(
+ default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
+ )
+ """The LLM chain used to parse the response."""
+
+ def _run(self, text: str) -> str:
+ from langchain.output_parsers.json import parse_json_markdown
+
+ try:
+ data = parse_json_markdown(text)
+ except json.JSONDecodeError as e:
+ raise e
+ response = self.requests_wrapper.delete(data["url"])
+ response = response[: self.response_length]
+ return self.llm_chain.predict(
+ response=response, instructions=data["output_instructions"]
+ ).strip()
+
+ async def _arun(self, text: str) -> str:
+ raise NotImplementedError()
+
+
+#
+# Orchestrator, planner, controller.
+#
+def _create_api_planner_tool(
+ api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
+) -> Tool:
+ from langchain.chains.llm import LLMChain
+
+ endpoint_descriptions = [
+ f"{name} {description}" for name, description, _ in api_spec.endpoints
+ ]
+ prompt = PromptTemplate(
+ template=API_PLANNER_PROMPT,
+ input_variables=["query"],
+ partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
+ )
+ chain = LLMChain(llm=llm, prompt=prompt)
+ tool = Tool(
+ name=API_PLANNER_TOOL_NAME,
+ description=API_PLANNER_TOOL_DESCRIPTION,
+ func=chain.run,
+ )
+ return tool
+
+
+def _create_api_controller_agent(
+ api_url: str,
+ api_docs: str,
+ requests_wrapper: RequestsWrapper,
+ llm: BaseLanguageModel,
+) -> Any:
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
+ post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
+ tools: List[BaseTool] = [
+ RequestsGetToolWithParsing(
+ requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
+ ),
+ RequestsPostToolWithParsing(
+ requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
+ ),
+ ]
+ prompt = PromptTemplate(
+ template=API_CONTROLLER_PROMPT,
+ input_variables=["input", "agent_scratchpad"],
+ partial_variables={
+ "api_url": api_url,
+ "api_docs": api_docs,
+ "tool_names": ", ".join([tool.name for tool in tools]),
+ "tool_descriptions": "\n".join(
+ [f"{tool.name}: {tool.description}" for tool in tools]
+ ),
+ },
+ )
+ agent = ZeroShotAgent(
+ llm_chain=LLMChain(llm=llm, prompt=prompt),
+ allowed_tools=[tool.name for tool in tools],
+ )
+ return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
+
+
+def _create_api_controller_tool(
+ api_spec: ReducedOpenAPISpec,
+ requests_wrapper: RequestsWrapper,
+ llm: BaseLanguageModel,
+) -> Tool:
+ """Expose controller as a tool.
+
+ The tool is invoked with a plan from the planner, and dynamically
+ creates a controller agent with relevant documentation only to
+ constrain the context.
+ """
+
+ base_url = api_spec.servers[0]["url"] # TODO: do better.
+
+ def _create_and_run_api_controller_agent(plan_str: str) -> str:
+ pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
+ matches = re.findall(pattern, plan_str)
+ endpoint_names = [
+ "{method} {route}".format(method=method, route=route.split("?")[0])
+ for method, route in matches
+ ]
+ docs_str = ""
+ for endpoint_name in endpoint_names:
+ found_match = False
+ for name, _, docs in api_spec.endpoints:
+ regex_name = re.compile(re.sub(r"\{.*?\}", ".*", name))
+ if regex_name.match(endpoint_name):
+ found_match = True
+ docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
+ if not found_match:
+ raise ValueError(f"{endpoint_name} endpoint does not exist.")
+
+ agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
+ return agent.run(plan_str)
+
+ return Tool(
+ name=API_CONTROLLER_TOOL_NAME,
+ func=_create_and_run_api_controller_agent,
+ description=API_CONTROLLER_TOOL_DESCRIPTION,
+ )
+
+
+def create_openapi_agent(
+ api_spec: ReducedOpenAPISpec,
+ requests_wrapper: RequestsWrapper,
+ llm: BaseLanguageModel,
+ shared_memory: Optional[Any] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ verbose: bool = True,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> Any:
+ """Instantiate OpenAI API planner and controller for a given spec.
+
+ Inject credentials via requests_wrapper.
+
+ We use a top-level "orchestrator" agent to invoke the planner and controller,
+ rather than a top-level planner
+ that invokes a controller with its plan. This is to keep the planner simple.
+ """
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = [
+ _create_api_planner_tool(api_spec, llm),
+ _create_api_controller_tool(api_spec, requests_wrapper, llm),
+ ]
+ prompt = PromptTemplate(
+ template=API_ORCHESTRATOR_PROMPT,
+ input_variables=["input", "agent_scratchpad"],
+ partial_variables={
+ "tool_names": ", ".join([tool.name for tool in tools]),
+ "tool_descriptions": "\n".join(
+ [f"{tool.name}: {tool.description}" for tool in tools]
+ ),
+ },
+ )
+ agent = ZeroShotAgent(
+ llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
+ allowed_tools=[tool.name for tool in tools],
+ **kwargs,
+ )
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
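
A sketch of the planner-based flow, with placeholder file name, auth header, and question; credentials are injected through the RequestsWrapper as the docstring notes:

import yaml

from langchain_community.agent_toolkits.openapi import planner
from langchain_community.agent_toolkits.openapi.spec import reduce_openapi_spec
from langchain_community.llms import OpenAI
from langchain_community.utilities.requests import RequestsWrapper

with open("openapi.yaml") as f:
    reduced_spec = reduce_openapi_spec(yaml.safe_load(f))

# The wrapper carries whatever auth headers the target API needs.
requests_wrapper = RequestsWrapper(headers={"Authorization": "Bearer <token>"})
agent = planner.create_openapi_agent(reduced_spec, requests_wrapper, OpenAI(temperature=0))
agent.run("List the items available in the catalog.")
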
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py b/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py
new file mode 100644
index 00000000000..ec99e823e06
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py
@@ -0,0 +1,235 @@
+# flake8: noqa
+
+from langchain_core.prompts.prompt import PromptTemplate
+
+
+API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
+
+You should:
+1) evaluate whether the user query can be solved by the API documented below. If no, say why.
+2) if yes, generate a plan of API calls and say what they are doing step by step.
+3) If the plan includes a DELETE call, you should always return an ask from the User for authorization first unless the User has specifically asked to delete something.
+
+You should only use API endpoints documented below ("Endpoints you can use:").
+You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should first request authorization from the User.
+Some user queries can be resolved in a single API call, but some will require several API calls.
+The plan will be passed to an API controller that can format it into web requests and return the responses.
+
+----
+
+Here are some examples:
+
+Fake endpoints for examples:
+GET /user to get information about the current user
+GET /products/search search across products
+POST /users/{{id}}/cart to add products to a user's cart
+PATCH /users/{{id}}/cart to update a user's cart
+PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart
+DELETE /users/{{id}}/cart to delete a user's cart
+
+User query: tell me a joke
+Plan: Sorry, this API's domain is shopping, not comedy.
+
+User query: I want to buy a couch
+Plan: 1. GET /products with a query param to search for couches
+2. GET /user to find the user's id
+3. POST /users/{{id}}/cart to add a couch to the user's cart
+
+User query: I want to add a lamp to my cart
+Plan: 1. GET /products with a query param to search for lamps
+2. GET /user to find the user's id
+3. PATCH /users/{{id}}/cart to add a lamp to the user's cart
+
+User query: I want to add a coupon to my cart
+Plan: 1. GET /user to find the user's id
+2. PUT /users/{{id}}/coupon to apply the coupon
+
+User query: I want to delete my cart
+Plan: 1. GET /user to find the user's id
+2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed.
+3. DELETE /users/{{id}}/cart to delete the user's cart
+
+User query: I want to start a new cart
+Plan: 1. GET /user to find the user's id
+2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization.
+3. Are you sure you want to delete your cart?
+----
+
+Here are endpoints you can use. Do not reference any of the endpoints above.
+
+{endpoints}
+
+----
+
+User query: {query}
+Plan:"""
+API_PLANNER_TOOL_NAME = "api_planner"
+API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller."
+
+# Execution.
+API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
+If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
+
+
+Here is documentation on the API:
+Base url: {api_url}
+Endpoints:
+{api_docs}
+
+
+Here are tools to execute requests against the API: {tool_descriptions}
+
+
+Starting below, you should follow this format:
+
+Plan: the plan of API calls to execute
+Thought: you should always think about what to do
+Action: the action to take, should be one of the tools [{tool_names}]
+Action Input: the input to the action
+Observation: the output of the action
+... (this Thought/Action/Action Input/Observation can repeat N times)
+Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
+Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
+
+
+Begin!
+
+Plan: {input}
+Thought:
+{agent_scratchpad}
+"""
+API_CONTROLLER_TOOL_NAME = "api_controller"
+API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
+
+# Orchestrate planning + execution.
+# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
+# keeping planning (and specifically the planning prompt) simple.
+API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources.
+Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls.
+You should always plan your API calls first, and then execute the plan second.
+If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something.
+You should never return information without executing the api_controller tool.
+
+
+Here are the tools to plan and execute API requests: {tool_descriptions}
+
+
+Starting below, you should follow this format:
+
+User query: the query a User wants help with related to the API
+Thought: you should always think about what to do
+Action: the action to take, should be one of the tools [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can repeat N times)
+Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create
+Final Answer: the final output from executing the plan
+
+
+Example:
+User query: can you add some trendy stuff to my shopping cart.
+Thought: I should plan API calls first.
+Action: api_planner
+Action Input: I need to find the right API calls to add trendy items to the user's shopping cart
+Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids
+2) GET /user to get user
+3) POST /cart to post the trending items to the user's cart
+Thought: I'm ready to execute the API calls.
+Action: api_controller
+Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids
+2) GET /user to get user
+3) POST /cart to post the trending items to the user's cart
+...
+
+Begin!
+
+User query: {input}
+Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
+{agent_scratchpad}"""
+
+REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
+Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions".
+The value of "url" should be a string.
+The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint.
+If parameters are not needed, or not available, leave it empty.
+The value of "output_instructions" should be instructions on what information to extract from the response,
+for example the id(s) for a resource(s) that the GET request fetches.
+"""
+
+PARSING_GET_PROMPT = PromptTemplate(
+ template="""Here is an API response:\n\n{response}\n\n====
+Your task is to extract some information according to these instructions: {instructions}
+When working with API objects, you should usually use ids over names.
+If the response indicates an error, you should instead output a summary of the error.
+
+Output:""",
+ input_variables=["response", "instructions"],
+)
+
+REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
+Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
+The value of "url" should be a string.
+The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
+The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
+Always use double quotes for strings in the json string."""
+
+PARSING_POST_PROMPT = PromptTemplate(
+ template="""Here is an API response:\n\n{response}\n\n====
+Your task is to extract some information according to these instructions: {instructions}
+When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
+If the response indicates an error, you should instead output a summary of the error.
+
+Output:""",
+ input_variables=["response", "instructions"],
+)
+
+REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website.
+Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
+The value of "url" should be a string.
+The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url.
+The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates.
+Always use double quotes for strings in the json string."""
+
+PARSING_PATCH_PROMPT = PromptTemplate(
+ template="""Here is an API response:\n\n{response}\n\n====
+Your task is to extract some information according to these instructions: {instructions}
+When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
+If the response indicates an error, you should instead output a summary of the error.
+
+Output:""",
+ input_variables=["response", "instructions"],
+)
+
+REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website.
+Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
+The value of "url" should be a string.
+The value of "data" should be a dictionary of key-value pairs you want to PUT to the url.
+The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates.
+Always use double quotes for strings in the json string."""
+
+PARSING_PUT_PROMPT = PromptTemplate(
+ template="""Here is an API response:\n\n{response}\n\n====
+Your task is to extract some information according to these instructions: {instructions}
+When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
+If the response indicates an error, you should instead output a summary of the error.
+
+Output:""",
+ input_variables=["response", "instructions"],
+)
+
+REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE.
+Input to the tool should be a json string with 2 keys: "url", and "output_instructions".
+The value of "url" should be a string.
+The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates.
+Always use double quotes for strings in the json string.
+ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING."""
+
+PARSING_DELETE_PROMPT = PromptTemplate(
+ template="""Here is an API response:\n\n{response}\n\n====
+Your task is to extract some information according to these instructions: {instructions}
+When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
+If the response indicates an error, you should instead output a summary of the error.
+
+Output:""",
+ input_variables=["response", "instructions"],
+)
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/prompt.py b/libs/community/langchain_community/agent_toolkits/openapi/prompt.py
new file mode 100644
index 00000000000..0484f5bf5d9
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/prompt.py
@@ -0,0 +1,29 @@
+# flake8: noqa
+
+OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec.
+
+If the question does not seem related to the API, return I don't know. Do not make up an answer.
+Only use information provided by the tools to construct your response.
+
+First, find the base URL needed to make the request.
+
+Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question.
+
+Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters.
+
+Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed.
+
+Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters.
+If you get a not found error, ensure that you are using a path that actually exists in the spec.
+"""
+OPENAPI_SUFFIX = """Begin!
+
+Question: {input}
+Thought: I should explore the spec to find the base url for the API.
+{agent_scratchpad}"""
+
+DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request.
+Example inputs to this tool:
+ 'What are the required query parameters for a GET request to the /bar endpoint?'
+ 'What are the required parameters in the request body for a POST request to the /foo endpoint?'
+Always give this tool a specific question."""
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/spec.py b/libs/community/langchain_community/agent_toolkits/openapi/spec.py
new file mode 100644
index 00000000000..29b529fc71e
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/spec.py
@@ -0,0 +1,75 @@
+"""Quick and dirty representation for OpenAPI specs."""
+
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from langchain_core.utils.json_schema import dereference_refs
+
+
+@dataclass(frozen=True)
+class ReducedOpenAPISpec:
+ """A reduced OpenAPI spec.
+
+ This is a quick and dirty representation for OpenAPI specs.
+
+ Attributes:
+ servers: The servers in the spec.
+ description: The description of the spec.
+ endpoints: The endpoints in the spec.
+ """
+
+ servers: List[dict]
+ description: str
+ endpoints: List[Tuple[str, str, dict]]
+
+
+def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
+ """Simplify/distill/minify a spec somehow.
+
+ I want a smaller target for retrieval and (more importantly)
+ I want smaller results from retrieval.
+ I was hoping https://openapi.tools/ would have some useful bits
+ to this end, but doesn't seem so.
+ """
+ # 1. Consider only get, post, patch, put, delete endpoints.
+ endpoints = [
+ (f"{operation_name.upper()} {route}", docs.get("description"), docs)
+ for route, operation in spec["paths"].items()
+ for operation_name, docs in operation.items()
+ if operation_name in ["get", "post", "patch", "put", "delete"]
+ ]
+
+ # 2. Replace any refs so that complete docs are retrieved.
+ # Note: probably want to do this post-retrieval, it blows up the size of the spec.
+ if dereference:
+ endpoints = [
+ (name, description, dereference_refs(docs, full_schema=spec))
+ for name, description, docs in endpoints
+ ]
+
+ # 3. Strip docs down to required request args + happy path response.
+ def reduce_endpoint_docs(docs: dict) -> dict:
+ out = {}
+ if docs.get("description"):
+ out["description"] = docs.get("description")
+ if docs.get("parameters"):
+ out["parameters"] = [
+ parameter
+ for parameter in docs.get("parameters", [])
+ if parameter.get("required")
+ ]
+ if "200" in docs["responses"]:
+ out["responses"] = docs["responses"]["200"]
+ if docs.get("requestBody"):
+ out["requestBody"] = docs.get("requestBody")
+ return out
+
+ endpoints = [
+ (name, description, reduce_endpoint_docs(docs))
+ for name, description, docs in endpoints
+ ]
+ return ReducedOpenAPISpec(
+ servers=spec["servers"],
+ description=spec["info"].get("description", ""),
+ endpoints=endpoints,
+ )
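
A small sketch of the reducer on its own; the file name is a placeholder:

import yaml

from langchain_community.agent_toolkits.openapi.spec import reduce_openapi_spec

with open("openapi.yaml") as f:
    reduced = reduce_openapi_spec(yaml.safe_load(f))

# Each endpoint is a ("VERB /route", description, stripped-docs) tuple.
print(reduced.description)
print([name for name, _, _ in reduced.endpoints])
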
diff --git a/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py
new file mode 100644
index 00000000000..5b7f3fcbd52
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py
@@ -0,0 +1,90 @@
+"""Requests toolkit."""
+from __future__ import annotations
+
+from typing import Any, List
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.tools import Tool
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.agent_toolkits.json.base import create_json_agent
+from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
+from langchain_community.agent_toolkits.openapi.prompt import DESCRIPTION
+from langchain_community.tools import BaseTool
+from langchain_community.tools.json.tool import JsonSpec
+from langchain_community.tools.requests.tool import (
+ RequestsDeleteTool,
+ RequestsGetTool,
+ RequestsPatchTool,
+ RequestsPostTool,
+ RequestsPutTool,
+)
+from langchain_community.utilities.requests import TextRequestsWrapper
+
+
+class RequestsToolkit(BaseToolkit):
+ """Toolkit for making REST requests.
+
+ *Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
+ and DELETE requests to an API.
+
+ Exercise care in who is allowed to use this toolkit. If exposing
+ to end users, consider that users will be able to make arbitrary
+ requests on behalf of the server hosting the code. For example,
+ users could ask the server to make a request to a private API
+ that is only accessible from the server.
+
+ Control who can submit requests using this toolkit and
+ what network access it has.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ requests_wrapper: TextRequestsWrapper
+
+ def get_tools(self) -> List[BaseTool]:
+ """Return a list of tools."""
+ return [
+ RequestsGetTool(requests_wrapper=self.requests_wrapper),
+ RequestsPostTool(requests_wrapper=self.requests_wrapper),
+ RequestsPatchTool(requests_wrapper=self.requests_wrapper),
+ RequestsPutTool(requests_wrapper=self.requests_wrapper),
+ RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
+ ]
+
+
+class OpenAPIToolkit(BaseToolkit):
+ """Toolkit for interacting with an OpenAPI API.
+
+ *Security Note*: This toolkit contains tools that can read and modify
+ the state of a service; e.g., by creating, deleting, updating, or
+ reading underlying data.
+
+ For example, this toolkit can be used to delete data exposed via
+ an OpenAPI compliant API.
+ """
+
+ json_agent: Any
+ requests_wrapper: TextRequestsWrapper
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ json_agent_tool = Tool(
+ name="json_explorer",
+ func=self.json_agent.run,
+ description=DESCRIPTION,
+ )
+ request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
+ return [*request_toolkit.get_tools(), json_agent_tool]
+
+ @classmethod
+ def from_llm(
+ cls,
+ llm: BaseLanguageModel,
+ json_spec: JsonSpec,
+ requests_wrapper: TextRequestsWrapper,
+ **kwargs: Any,
+ ) -> OpenAPIToolkit:
+ """Create json agent from llm, then initialize."""
+ json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
+ return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
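
The RequestsToolkit defined above can also be used on its own when only the raw HTTP tools are needed, as in this minimal sketch:

from langchain_community.agent_toolkits.openapi.toolkit import RequestsToolkit
from langchain_community.utilities.requests import TextRequestsWrapper

toolkit = RequestsToolkit(requests_wrapper=TextRequestsWrapper())
# GET, POST, PATCH, PUT, and DELETE tools, all sharing the same wrapper.
print([tool.name for tool in toolkit.get_tools()])
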
diff --git a/libs/community/langchain_community/agent_toolkits/playwright/__init__.py b/libs/community/langchain_community/agent_toolkits/playwright/__init__.py
new file mode 100644
index 00000000000..7fc7f6d9950
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/playwright/__init__.py
@@ -0,0 +1,6 @@
+"""Playwright browser toolkit."""
+from langchain_community.agent_toolkits.playwright.toolkit import (
+ PlayWrightBrowserToolkit,
+)
+
+__all__ = ["PlayWrightBrowserToolkit"]
diff --git a/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py
new file mode 100644
index 00000000000..5a78e8c21a8
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py
@@ -0,0 +1,110 @@
+"""Playwright web browser toolkit."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional, Type, cast
+
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.tools import BaseTool
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools.playwright.base import (
+ BaseBrowserTool,
+ lazy_import_playwright_browsers,
+)
+from langchain_community.tools.playwright.click import ClickTool
+from langchain_community.tools.playwright.current_page import CurrentWebPageTool
+from langchain_community.tools.playwright.extract_hyperlinks import (
+ ExtractHyperlinksTool,
+)
+from langchain_community.tools.playwright.extract_text import ExtractTextTool
+from langchain_community.tools.playwright.get_elements import GetElementsTool
+from langchain_community.tools.playwright.navigate import NavigateTool
+from langchain_community.tools.playwright.navigate_back import NavigateBackTool
+
+if TYPE_CHECKING:
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.sync_api import Browser as SyncBrowser
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.sync_api import Browser as SyncBrowser
+ except ImportError:
+ pass
+
+
+class PlayWrightBrowserToolkit(BaseToolkit):
+ """Toolkit for PlayWright browser tools.
+
+ **Security Note**: This toolkit provides code to control a web-browser.
+
+ Be careful when exposing this toolkit to end users. The tools in the toolkit
+ are capable of navigating to arbitrary webpages, clicking on arbitrary
+ elements, and extracting arbitrary text and hyperlinks from webpages.
+
+ Specifically, by default this toolkit allows navigating to:
+
+ - Any URL (including any internal network URLs)
+ - And local files
+
+ If exposing to end-users, consider limiting network access to the
+ server that hosts the agent; in addition, it is advised
+ to create a custom NavigateTool with an args_schema that limits the URLs
+ that can be navigated to (e.g., only allow navigating to URLs that
+ start with a particular prefix).
+
+ Remember to scope permissions to the minimal permissions necessary for
+ the application. If the default tool selection is not appropriate for
+ the application, consider creating a custom toolkit with the appropriate
+ tools.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ sync_browser: Optional["SyncBrowser"] = None
+ async_browser: Optional["AsyncBrowser"] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator
+ def validate_imports_and_browser_provided(cls, values: dict) -> dict:
+ """Check that the arguments are valid."""
+ lazy_import_playwright_browsers()
+ if values.get("async_browser") is None and values.get("sync_browser") is None:
+ raise ValueError("Either async_browser or sync_browser must be specified.")
+ return values
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ tool_classes: List[Type[BaseBrowserTool]] = [
+ ClickTool,
+ NavigateTool,
+ NavigateBackTool,
+ ExtractTextTool,
+ ExtractHyperlinksTool,
+ GetElementsTool,
+ CurrentWebPageTool,
+ ]
+
+ tools = [
+ tool_cls.from_browser(
+ sync_browser=self.sync_browser, async_browser=self.async_browser
+ )
+ for tool_cls in tool_classes
+ ]
+ return cast(List[BaseTool], tools)
+
+ @classmethod
+ def from_browser(
+ cls,
+ sync_browser: Optional[SyncBrowser] = None,
+ async_browser: Optional[AsyncBrowser] = None,
+ ) -> PlayWrightBrowserToolkit:
+ """Instantiate the toolkit."""
+ # This is to raise a better error than the forward ref ones Pydantic would have
+ lazy_import_playwright_browsers()
+ return cls(sync_browser=sync_browser, async_browser=async_browser)
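
A sketch using a synchronous browser, assuming the `playwright` package and its browser binaries are installed (e.g., via `playwright install`):

from playwright.sync_api import sync_playwright

from langchain_community.agent_toolkits.playwright import PlayWrightBrowserToolkit

pw = sync_playwright().start()
browser = pw.chromium.launch(headless=True)
toolkit = PlayWrightBrowserToolkit.from_browser(sync_browser=browser)
print([tool.name for tool in toolkit.get_tools()])
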
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py b/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py
new file mode 100644
index 00000000000..42a9b09ac7e
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py
@@ -0,0 +1 @@
+"""Power BI agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/base.py b/libs/community/langchain_community/agent_toolkits/powerbi/base.py
new file mode 100644
index 00000000000..06de2b97a7b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/base.py
@@ -0,0 +1,73 @@
+"""Power BI agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+
+from langchain_community.agent_toolkits.powerbi.prompt import (
+ POWERBI_PREFIX,
+ POWERBI_SUFFIX,
+)
+from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
+from langchain_community.utilities.powerbi import PowerBIDataset
+
+if TYPE_CHECKING:
+ from langchain.agents import AgentExecutor
+
+
+def create_pbi_agent(
+ llm: BaseLanguageModel,
+ toolkit: Optional[PowerBIToolkit] = None,
+ powerbi: Optional[PowerBIDataset] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = POWERBI_PREFIX,
+ suffix: str = POWERBI_SUFFIX,
+ format_instructions: Optional[str] = None,
+ examples: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ top_k: int = 10,
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a Power BI agent from an LLM and tools."""
+ from langchain.agents import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ if toolkit is None:
+ if powerbi is None:
+ raise ValueError("Must provide either a toolkit or powerbi dataset")
+ toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
+ tools = toolkit.get_tools()
+ tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
+ prompt_params = (
+ {"format_instructions": format_instructions}
+ if format_instructions is not None
+ else {}
+ )
+ agent = ZeroShotAgent(
+ llm_chain=LLMChain(
+ llm=llm,
+ prompt=ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix.format(top_k=top_k).format(tables=tables),
+ suffix=suffix,
+ input_variables=input_variables,
+ **prompt_params,
+ ),
+ callback_manager=callback_manager, # type: ignore
+ verbose=verbose,
+ ),
+ allowed_tools=[tool.name for tool in tools],
+ **kwargs,
+ )
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
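
A sketch of the dataset path through this constructor; the dataset id, table name, and Azure credential are illustrative assumptions, and the exact PowerBIDataset parameters should be checked against that wrapper:

from azure.identity import DefaultAzureCredential

from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent
from langchain_community.llms import OpenAI
from langchain_community.utilities.powerbi import PowerBIDataset

dataset = PowerBIDataset(
    dataset_id="<dataset-id>",  # assumption: the GUID of your dataset
    table_names=["Sales"],  # assumption: tables you want exposed to the agent
    credential=DefaultAzureCredential(),
)
agent = create_pbi_agent(llm=OpenAI(temperature=0), powerbi=dataset, verbose=True)
agent.run("How many rows are in the Sales table?")
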
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py b/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py
new file mode 100644
index 00000000000..acad61b4422
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py
@@ -0,0 +1,71 @@
+"""Power BI agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models.chat_models import BaseChatModel
+
+from langchain_community.agent_toolkits.powerbi.prompt import (
+ POWERBI_CHAT_PREFIX,
+ POWERBI_CHAT_SUFFIX,
+)
+from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
+from langchain_community.utilities.powerbi import PowerBIDataset
+
+if TYPE_CHECKING:
+ from langchain.agents import AgentExecutor
+ from langchain.agents.agent import AgentOutputParser
+ from langchain.memory.chat_memory import BaseChatMemory
+
+
+def create_pbi_chat_agent(
+ llm: BaseChatModel,
+ toolkit: Optional[PowerBIToolkit] = None,
+ powerbi: Optional[PowerBIDataset] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ output_parser: Optional[AgentOutputParser] = None,
+ prefix: str = POWERBI_CHAT_PREFIX,
+ suffix: str = POWERBI_CHAT_SUFFIX,
+ examples: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ memory: Optional[BaseChatMemory] = None,
+ top_k: int = 10,
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a Power BI agent from a Chat LLM and tools.
+
+    If you supply only a Power BI dataset and no toolkit, the same LLM is used
+    for both the toolkit and the agent.
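+
+    Example (an illustrative sketch, not a runnable snippet; the dataset id,
+    ``credential`` and ``chat_llm`` below are placeholders):
+
+    .. code-block:: python
+
+        from langchain_community.utilities.powerbi import PowerBIDataset
+
+        dataset = PowerBIDataset(
+            dataset_id="<dataset_id>",
+            table_names=["my_table"],
+            credential=credential,  # e.g. an azure.identity credential
+        )
+        agent_executor = create_pbi_chat_agent(
+            llm=chat_llm,
+            powerbi=dataset,
+            verbose=True,
+        )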
+ """
+ from langchain.agents import AgentExecutor
+ from langchain.agents.conversational_chat.base import ConversationalChatAgent
+ from langchain.memory import ConversationBufferMemory
+
+ if toolkit is None:
+ if powerbi is None:
+ raise ValueError("Must provide either a toolkit or powerbi dataset")
+ toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
+ tools = toolkit.get_tools()
+ tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
+ agent = ConversationalChatAgent.from_llm_and_tools(
+ llm=llm,
+ tools=tools,
+ system_message=prefix.format(top_k=top_k).format(tables=tables),
+ human_message=suffix,
+ input_variables=input_variables,
+ callback_manager=callback_manager,
+ output_parser=output_parser,
+ verbose=verbose,
+ **kwargs,
+ )
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ memory=memory
+ or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py b/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py
new file mode 100644
index 00000000000..673a6bed29b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py
@@ -0,0 +1,38 @@
+# flake8: noqa
+"""Prompts for PowerBI agent."""
+
+
+POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset.
+
+Agent has access to a tool that can write a query based on the question and then run it against PowerBI, Microsoft's business intelligence tool. Questions from users should be interpreted as related to the available dataset and not as general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer.
+
+Given an input question, ask to run the question against the dataset, then look at the results and return the answer. The answer should be a complete sentence that answers the question; if multiple rows are requested, write them in an easily readable format for a human, and represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
+"""
+
+POWERBI_SUFFIX = """Begin!
+
+Question: {input}
+Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question.
+{agent_scratchpad}"""
+
+POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset.
+
+Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset, it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values; if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}.
+
+The answer should be a complete sentence that answers the question; if multiple rows are requested, write them in an easily readable format for a human, and represent numbers in readable ways, like 1M instead of 1000000.
+"""
+
+POWERBI_CHAT_SUFFIX = """TOOLS
+------
+Assistant can ask the user to use tools to look up information that may be helpful in answering the user's original question. The tools the human can use are:
+
+{{tools}}
+
+{format_instructions}
+
+USER'S INPUT
+--------------------
+Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
+
+{{{{input}}}}
+"""
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py b/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py
new file mode 100644
index 00000000000..1a89d77ad2a
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py
@@ -0,0 +1,108 @@
+"""Toolkit for interacting with a Power BI dataset."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.prompts import PromptTemplate
+from langchain_core.prompts.chat import (
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ SystemMessagePromptTemplate,
+)
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.powerbi.prompt import (
+ QUESTION_TO_QUERY_BASE,
+ SINGLE_QUESTION_TO_QUERY,
+ USER_INPUT,
+)
+from langchain_community.tools.powerbi.tool import (
+ InfoPowerBITool,
+ ListPowerBITool,
+ QueryPowerBITool,
+)
+from langchain_community.utilities.powerbi import PowerBIDataset
+
+if TYPE_CHECKING:
+ from langchain.chains.llm import LLMChain
+
+
+class PowerBIToolkit(BaseToolkit):
+ """Toolkit for interacting with Power BI dataset.
+
+ *Security Note*: This toolkit interacts with an external service.
+
+ Control access to who can use this toolkit.
+
+ Make sure that the capabilities given by this toolkit to the calling
+ code are appropriately scoped to the application.
+
+ See https://python.langchain.com/docs/security for more information.
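+
+    Example (an illustrative sketch; ``llm``, ``credential`` and the dataset
+    parameters are placeholders rather than working values):
+
+    .. code-block:: python
+
+        from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
+        from langchain_community.utilities.powerbi import PowerBIDataset
+
+        toolkit = PowerBIToolkit(
+            powerbi=PowerBIDataset(
+                dataset_id="<dataset_id>",
+                table_names=["my_table"],
+                credential=credential,
+            ),
+            llm=llm,
+        )
+        tools = toolkit.get_tools()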
+ """
+
+ powerbi: PowerBIDataset = Field(exclude=True)
+ llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
+ examples: Optional[str] = None
+ max_iterations: int = 5
+ callback_manager: Optional[BaseCallbackManager] = None
+ output_token_limit: Optional[int] = None
+ tiktoken_model_name: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ QueryPowerBITool(
+ llm_chain=self._get_chain(),
+ powerbi=self.powerbi,
+ examples=self.examples,
+ max_iterations=self.max_iterations,
+ output_token_limit=self.output_token_limit,
+ tiktoken_model_name=self.tiktoken_model_name,
+ ),
+ InfoPowerBITool(powerbi=self.powerbi),
+ ListPowerBITool(powerbi=self.powerbi),
+ ]
+
+ def _get_chain(self) -> LLMChain:
+ """Construct the chain based on the callback manager and model type."""
+ from langchain.chains.llm import LLMChain
+
+ if isinstance(self.llm, BaseLanguageModel):
+ return LLMChain(
+ llm=self.llm,
+ callback_manager=self.callback_manager
+ if self.callback_manager
+ else None,
+ prompt=PromptTemplate(
+ template=SINGLE_QUESTION_TO_QUERY,
+ input_variables=["tool_input", "tables", "schemas", "examples"],
+ ),
+ )
+
+ system_prompt = SystemMessagePromptTemplate(
+ prompt=PromptTemplate(
+ template=QUESTION_TO_QUERY_BASE,
+ input_variables=["tables", "schemas", "examples"],
+ )
+ )
+ human_prompt = HumanMessagePromptTemplate(
+ prompt=PromptTemplate(
+ template=USER_INPUT,
+ input_variables=["tool_input"],
+ )
+ )
+ return LLMChain(
+ llm=self.llm,
+ callback_manager=self.callback_manager if self.callback_manager else None,
+ prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/slack/__init__.py b/libs/community/langchain_community/agent_toolkits/slack/__init__.py
new file mode 100644
index 00000000000..1ec5ae704ce
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/slack/__init__.py
@@ -0,0 +1 @@
+"""Slack toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/slack/toolkit.py b/libs/community/langchain_community/agent_toolkits/slack/toolkit.py
new file mode 100644
index 00000000000..bc0f09ff1db
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/slack/toolkit.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.slack.get_channel import SlackGetChannel
+from langchain_community.tools.slack.get_message import SlackGetMessage
+from langchain_community.tools.slack.schedule_message import SlackScheduleMessage
+from langchain_community.tools.slack.send_message import SlackSendMessage
+from langchain_community.tools.slack.utils import login
+
+if TYPE_CHECKING:
+ from slack_sdk import WebClient
+
+
+class SlackToolkit(BaseToolkit):
+ """Toolkit for interacting with Slack."""
+
+ client: WebClient = Field(default_factory=login)
+
+ class Config:
+ """Pydantic config."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ SlackGetChannel(),
+ SlackGetMessage(),
+ SlackScheduleMessage(),
+ SlackSendMessage(),
+ ]
diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py b/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py
new file mode 100644
index 00000000000..4308c079443
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py
@@ -0,0 +1 @@
+"""Spark SQL agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/base.py b/libs/community/langchain_community/agent_toolkits/spark_sql/base.py
new file mode 100644
index 00000000000..dca019b0425
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/spark_sql/base.py
@@ -0,0 +1,70 @@
+"""Spark SQL agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager, Callbacks
+from langchain_core.language_models import BaseLanguageModel
+
+from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
+from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+
+
+def create_spark_sql_agent(
+ llm: BaseLanguageModel,
+ toolkit: SparkSQLToolkit,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ callbacks: Callbacks = None,
+ prefix: str = SQL_PREFIX,
+ suffix: str = SQL_SUFFIX,
+ format_instructions: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ top_k: int = 10,
+ max_iterations: Optional[int] = 15,
+ max_execution_time: Optional[float] = None,
+ early_stopping_method: str = "force",
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a Spark SQL agent from an LLM and tools."""
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = toolkit.get_tools()
+ prefix = prefix.format(top_k=top_k)
+ prompt_params = (
+ {"format_instructions": format_instructions}
+ if format_instructions is not None
+ else {}
+ )
+ prompt = ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix,
+ suffix=suffix,
+ input_variables=input_variables,
+ **prompt_params,
+ )
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ callbacks=callbacks,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ callbacks=callbacks,
+ verbose=verbose,
+ max_iterations=max_iterations,
+ max_execution_time=max_execution_time,
+ early_stopping_method=early_stopping_method,
+ **(agent_executor_kwargs or {}),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py b/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py
new file mode 100644
index 00000000000..b499085d3fe
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py
@@ -0,0 +1,21 @@
+# flake8: noqa
+
+SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
+Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
+Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
+You can order the results by a relevant column to return the most interesting examples in the database.
+Never query for all the columns from a specific table, only ask for the relevant columns given the question.
+You have access to tools for interacting with the database.
+Only use the below tools. Only use the information returned by the below tools to construct your final answer.
+You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+
+If the question does not seem related to the database, just return "I don't know" as the answer.
+"""
+
+SQL_SUFFIX = """Begin!
+
+Question: {input}
+Thought: I should look at the tables in the database to see what I can query.
+{agent_scratchpad}"""
diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py b/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py
new file mode 100644
index 00000000000..fd0363132d0
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py
@@ -0,0 +1,36 @@
+"""Toolkit for interacting with Spark SQL."""
+from typing import List
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.spark_sql.tool import (
+ InfoSparkSQLTool,
+ ListSparkSQLTool,
+ QueryCheckerTool,
+ QuerySparkSQLTool,
+)
+from langchain_community.utilities.spark_sql import SparkSQL
+
+
+class SparkSQLToolkit(BaseToolkit):
+ """Toolkit for interacting with Spark SQL."""
+
+ db: SparkSQL = Field(exclude=True)
+ llm: BaseLanguageModel = Field(exclude=True)
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return [
+ QuerySparkSQLTool(db=self.db),
+ InfoSparkSQLTool(db=self.db),
+ ListSparkSQLTool(db=self.db),
+ QueryCheckerTool(db=self.db, llm=self.llm),
+ ]
diff --git a/libs/community/langchain_community/agent_toolkits/sql/__init__.py b/libs/community/langchain_community/agent_toolkits/sql/__init__.py
new file mode 100644
index 00000000000..74293a52391
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/sql/__init__.py
@@ -0,0 +1 @@
+"""SQL agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py
new file mode 100644
index 00000000000..c2451f7c231
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/sql/base.py
@@ -0,0 +1,108 @@
+"""SQL agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import AIMessage, SystemMessage
+from langchain_core.prompts.chat import (
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ MessagesPlaceholder,
+)
+
+from langchain_community.agent_toolkits.sql.prompt import (
+ SQL_FUNCTIONS_SUFFIX,
+ SQL_PREFIX,
+ SQL_SUFFIX,
+)
+from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
+from langchain_community.tools import BaseTool
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.agent_types import AgentType
+
+
+def create_sql_agent(
+ llm: BaseLanguageModel,
+ toolkit: SQLDatabaseToolkit,
+ agent_type: Optional[AgentType] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = SQL_PREFIX,
+ suffix: Optional[str] = None,
+ format_instructions: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ top_k: int = 10,
+ max_iterations: Optional[int] = 15,
+ max_execution_time: Optional[float] = None,
+ early_stopping_method: str = "force",
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ extra_tools: Sequence[BaseTool] = (),
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct an SQL agent from an LLM and tools."""
+ from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
+ from langchain.agents.agent_types import AgentType
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
+ from langchain.chains.llm import LLMChain
+
+ agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
+ tools = toolkit.get_tools() + list(extra_tools)
+ prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
+ agent: BaseSingleActionAgent
+
+ if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
+ prompt_params = (
+ {"format_instructions": format_instructions}
+ if format_instructions is not None
+ else {}
+ )
+ prompt = ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix,
+ suffix=suffix or SQL_SUFFIX,
+ input_variables=input_variables,
+ **prompt_params,
+ )
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+
+ elif agent_type == AgentType.OPENAI_FUNCTIONS:
+ messages = [
+ SystemMessage(content=prefix),
+ HumanMessagePromptTemplate.from_template("{input}"),
+ AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ input_variables = ["input", "agent_scratchpad"]
+ _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
+
+ agent = OpenAIFunctionsAgent(
+ llm=llm,
+ prompt=_prompt,
+ tools=tools,
+ callback_manager=callback_manager,
+ **kwargs,
+ )
+ else:
+ raise ValueError(f"Agent type {agent_type} not supported at the moment.")
+
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ max_iterations=max_iterations,
+ max_execution_time=max_execution_time,
+ early_stopping_method=early_stopping_method,
+ **(agent_executor_kwargs or {}),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/sql/prompt.py b/libs/community/langchain_community/agent_toolkits/sql/prompt.py
new file mode 100644
index 00000000000..92464da4b9b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/sql/prompt.py
@@ -0,0 +1,23 @@
+# flake8: noqa
+
+SQL_PREFIX = """You are an agent designed to interact with a SQL database.
+Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
+Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
+You can order the results by a relevant column to return the most interesting examples in the database.
+Never query for all the columns from a specific table, only ask for the relevant columns given the question.
+You have access to tools for interacting with the database.
+Only use the below tools. Only use the information returned by the below tools to construct your final answer.
+You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+
+If the question does not seem related to the database, just return "I don't know" as the answer.
+"""
+
+SQL_SUFFIX = """Begin!
+
+Question: {input}
+Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
+{agent_scratchpad}"""
+
+SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
diff --git a/libs/community/langchain_community/agent_toolkits/sql/toolkit.py b/libs/community/langchain_community/agent_toolkits/sql/toolkit.py
new file mode 100644
index 00000000000..382fcbb4856
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/sql/toolkit.py
@@ -0,0 +1,71 @@
+"""Toolkit for interacting with an SQL database."""
+from typing import List
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.sql_database.tool import (
+ InfoSQLDatabaseTool,
+ ListSQLDatabaseTool,
+ QuerySQLCheckerTool,
+ QuerySQLDataBaseTool,
+)
+from langchain_community.utilities.sql_database import SQLDatabase
+
+
+class SQLDatabaseToolkit(BaseToolkit):
+ """Toolkit for interacting with SQL databases."""
+
+ db: SQLDatabase = Field(exclude=True)
+ llm: BaseLanguageModel = Field(exclude=True)
+
+ @property
+ def dialect(self) -> str:
+ """Return string representation of SQL dialect to use."""
+ return self.db.dialect
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
+ info_sql_database_tool_description = (
+ "Input to this tool is a comma-separated list of tables, output is the "
+ "schema and sample rows for those tables. "
+ "Be sure that the tables actually exist by calling "
+ f"{list_sql_database_tool.name} first! "
+ "Example Input: table1, table2, table3"
+ )
+ info_sql_database_tool = InfoSQLDatabaseTool(
+ db=self.db, description=info_sql_database_tool_description
+ )
+ query_sql_database_tool_description = (
+ "Input to this tool is a detailed and correct SQL query, output is a "
+ "result from the database. If the query is not correct, an error message "
+ "will be returned. If an error is returned, rewrite the query, check the "
+ "query, and try again. If you encounter an issue with Unknown column "
+ f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
+ "to query the correct table fields."
+ )
+ query_sql_database_tool = QuerySQLDataBaseTool(
+ db=self.db, description=query_sql_database_tool_description
+ )
+ query_sql_checker_tool_description = (
+ "Use this tool to double check if your query is correct before executing "
+ "it. Always use this tool before executing a query with "
+ f"{query_sql_database_tool.name}!"
+ )
+ query_sql_checker_tool = QuerySQLCheckerTool(
+ db=self.db, llm=self.llm, description=query_sql_checker_tool_description
+ )
+ return [
+ query_sql_database_tool,
+ info_sql_database_tool,
+ list_sql_database_tool,
+ query_sql_checker_tool,
+ ]
diff --git a/libs/community/langchain_community/agent_toolkits/steam/__init__.py b/libs/community/langchain_community/agent_toolkits/steam/__init__.py
new file mode 100644
index 00000000000..f9998108242
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/steam/__init__.py
@@ -0,0 +1 @@
+"""Steam Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/steam/toolkit.py b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py
new file mode 100644
index 00000000000..1fe57b032bf
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py
@@ -0,0 +1,48 @@
+"""Steam Toolkit."""
+from typing import List
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.steam.prompt import (
+ STEAM_GET_GAMES_DETAILS,
+ STEAM_GET_RECOMMENDED_GAMES,
+)
+from langchain_community.tools.steam.tool import SteamWebAPIQueryRun
+from langchain_community.utilities.steam import SteamWebAPIWrapper
+
+
+class SteamToolkit(BaseToolkit):
+ """Steam Toolkit."""
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_steam_api_wrapper(
+ cls, steam_api_wrapper: SteamWebAPIWrapper
+ ) -> "SteamToolkit":
+ operations: List[dict] = [
+ {
+ "mode": "get_games_details",
+ "name": "Get Games Details",
+ "description": STEAM_GET_GAMES_DETAILS,
+ },
+ {
+ "mode": "get_recommended_games",
+ "name": "Get Recommended Games",
+ "description": STEAM_GET_RECOMMENDED_GAMES,
+ },
+ ]
+ tools = [
+ SteamWebAPIQueryRun(
+ name=action["name"],
+ description=action["description"],
+ mode=action["mode"],
+ api_wrapper=steam_api_wrapper,
+ )
+ for action in operations
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ return self.tools
diff --git a/libs/community/langchain_community/agent_toolkits/vectorstore/base.py b/libs/community/langchain_community/agent_toolkits/vectorstore/base.py
new file mode 100644
index 00000000000..7354dfb0df2
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/vectorstore/base.py
@@ -0,0 +1,106 @@
+"""VectorStore agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.language_models import BaseLanguageModel
+
+from langchain_community.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX
+from langchain_community.agent_toolkits.vectorstore.toolkit import (
+ VectorStoreRouterToolkit,
+ VectorStoreToolkit,
+)
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+
+
+def create_vectorstore_agent(
+ llm: BaseLanguageModel,
+ toolkit: VectorStoreToolkit,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = PREFIX,
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a VectorStore agent from an LLM and tools.
+
+ Args:
+ llm (BaseLanguageModel): LLM that will be used by the agent
+ toolkit (VectorStoreToolkit): Set of tools for the agent
+        callback_manager (Optional[BaseCallbackManager], optional): Callback manager for the agent. Defaults to None.
+        prefix (str, optional): The prefix prompt for the agent. If not provided, uses the default PREFIX.
+        verbose (bool, optional): Whether to show the content of the scratchpad. Defaults to False.
+        agent_executor_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments to pass to the AgentExecutor. Defaults to None.
+ **kwargs: Additional named parameters to pass to the ZeroShotAgent.
+
+ Returns:
+        AgentExecutor: A callable AgentExecutor object. Either call it directly or use its run method with a query to get a response.
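+
+    Example (an illustrative sketch; ``vectorstore`` and ``llm`` are assumed to be
+    constructed elsewhere, e.g. with FAISS and an OpenAI model):
+
+    .. code-block:: python
+
+        from langchain_community.agent_toolkits.vectorstore.toolkit import (
+            VectorStoreInfo,
+            VectorStoreToolkit,
+        )
+
+        vectorstore_info = VectorStoreInfo(
+            name="state_of_union",
+            description="the most recent state of the Union address",
+            vectorstore=vectorstore,
+        )
+        toolkit = VectorStoreToolkit(vectorstore_info=vectorstore_info, llm=llm)
+        agent_executor = create_vectorstore_agent(llm=llm, toolkit=toolkit, verbose=True)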
+ """ # noqa: E501
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = toolkit.get_tools()
+ prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
+
+
+def create_vectorstore_router_agent(
+ llm: BaseLanguageModel,
+ toolkit: VectorStoreRouterToolkit,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: str = ROUTER_PREFIX,
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a VectorStore router agent from an LLM and tools.
+
+ Args:
+ llm (BaseLanguageModel): LLM that will be used by the agent
+ toolkit (VectorStoreRouterToolkit): Set of tools for the agent which have routing capability with multiple vector stores
+        callback_manager (Optional[BaseCallbackManager], optional): Callback manager for the agent. Defaults to None.
+        prefix (str, optional): The prefix prompt for the router agent. If not provided, uses the default ROUTER_PREFIX.
+        verbose (bool, optional): Whether to show the content of the scratchpad. Defaults to False.
+        agent_executor_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments to pass to the AgentExecutor. Defaults to None.
+ **kwargs: Additional named parameters to pass to the ZeroShotAgent.
+
+ Returns:
+        AgentExecutor: A callable AgentExecutor object. Either call it directly or use its run method with a query to get a response.
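+
+    Example (an illustrative sketch; ``vectorstore_a``, ``vectorstore_b`` and ``llm``
+    are assumed to exist already):
+
+    .. code-block:: python
+
+        from langchain_community.agent_toolkits.vectorstore.toolkit import (
+            VectorStoreInfo,
+            VectorStoreRouterToolkit,
+        )
+
+        router_toolkit = VectorStoreRouterToolkit(
+            vectorstores=[
+                VectorStoreInfo(
+                    name="docs_a", description="first corpus", vectorstore=vectorstore_a
+                ),
+                VectorStoreInfo(
+                    name="docs_b", description="second corpus", vectorstore=vectorstore_b
+                ),
+            ],
+            llm=llm,
+        )
+        agent_executor = create_vectorstore_router_agent(llm=llm, toolkit=router_toolkit)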
+ """ # noqa: E501
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.mrkl.base import ZeroShotAgent
+ from langchain.chains.llm import LLMChain
+
+ tools = toolkit.get_tools()
+ prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ callback_manager=callback_manager,
+ )
+ tool_names = [tool.name for tool in tools]
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+ return AgentExecutor.from_agent_and_tools(
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ **(agent_executor_kwargs or {}),
+ )
diff --git a/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py b/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py
new file mode 100644
index 00000000000..fd8fc13ba0d
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+from typing import Any
+
+from langchain_core._api.path import as_import_path
+
+
+def __getattr__(name: str) -> Any:
+ """Get attr name."""
+
+ if name == "create_xorbits_agent":
+ # Get directory of langchain package
+ HERE = Path(__file__).parents[3]
+ here = as_import_path(Path(__file__).parent, relative_to=HERE)
+
+ old_path = "langchain." + here + "." + name
+ new_path = "langchain_experimental." + here + "." + name
+ raise ImportError(
+            "This agent has been moved to langchain_experimental. "
+            "This agent relies on the Python REPL tool under the hood, so to "
+            "use it safely please sandbox the Python REPL. "
+            "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
+            "and https://github.com/langchain-ai/langchain/discussions/11680. "
+            "To keep using this code as is, install langchain-experimental and "
+ f"update your import statement from:\n `{old_path}` to `{new_path}`."
+ )
+ raise AttributeError(f"{name} does not exist")
diff --git a/libs/community/langchain_community/agent_toolkits/zapier/__init__.py b/libs/community/langchain_community/agent_toolkits/zapier/__init__.py
new file mode 100644
index 00000000000..faef4a32531
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/zapier/__init__.py
@@ -0,0 +1 @@
+"""Zapier Toolkit."""
diff --git a/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py b/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py
new file mode 100644
index 00000000000..f5f335efdc8
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py
@@ -0,0 +1,60 @@
+"""[DEPRECATED] Zapier Toolkit."""
+from typing import List
+
+from langchain_core._api import warn_deprecated
+
+from langchain_community.agent_toolkits.base import BaseToolkit
+from langchain_community.tools import BaseTool
+from langchain_community.tools.zapier.tool import ZapierNLARunAction
+from langchain_community.utilities.zapier import ZapierNLAWrapper
+
+
+class ZapierToolkit(BaseToolkit):
+ """Zapier Toolkit."""
+
+ tools: List[BaseTool] = []
+
+ @classmethod
+ def from_zapier_nla_wrapper(
+ cls, zapier_nla_wrapper: ZapierNLAWrapper
+ ) -> "ZapierToolkit":
+ """Create a toolkit from a ZapierNLAWrapper."""
+ actions = zapier_nla_wrapper.list()
+ tools = [
+ ZapierNLARunAction(
+ action_id=action["id"],
+ zapier_description=action["description"],
+ params_schema=action["params"],
+ api_wrapper=zapier_nla_wrapper,
+ )
+ for action in actions
+ ]
+ return cls(tools=tools)
+
+ @classmethod
+ async def async_from_zapier_nla_wrapper(
+ cls, zapier_nla_wrapper: ZapierNLAWrapper
+ ) -> "ZapierToolkit":
+ """Create a toolkit from a ZapierNLAWrapper."""
+ actions = await zapier_nla_wrapper.alist()
+ tools = [
+ ZapierNLARunAction(
+ action_id=action["id"],
+ zapier_description=action["description"],
+ params_schema=action["params"],
+ api_wrapper=zapier_nla_wrapper,
+ )
+ for action in actions
+ ]
+ return cls(tools=tools)
+
+ def get_tools(self) -> List[BaseTool]:
+ """Get the tools in the toolkit."""
+ warn_deprecated(
+ since="0.0.319",
+ message=(
+ "This tool will be deprecated on 2023-11-17. See "
+ "https://nla.zapier.com/sunset/ for details"
+ ),
+ )
+ return self.tools
diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py
new file mode 100644
index 00000000000..a0070c2e59a
--- /dev/null
+++ b/libs/community/langchain_community/cache.py
@@ -0,0 +1,1556 @@
+"""
+.. warning::
+ Beta Feature!
+
+**Cache** provides an optional caching layer for LLMs.
+
+Cache is useful for two reasons:
+
+- It can save you money by reducing the number of API calls you make to the LLM
+ provider if you're often requesting the same completion multiple times.
+- It can speed up your application by reducing the number of API calls you make
+ to the LLM provider.
+
+Cache directly competes with Memory. See documentation for Pros and Cons.
+
+**Class hierarchy:**
+
+.. code-block::
+
+    BaseCache --> <name>Cache  # Examples: InMemoryCache, RedisCache, GPTCache
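+
+For example, a cache is usually activated globally (a minimal sketch; assumes the
+``langchain`` package is installed, which provides ``set_llm_cache``):
+
+.. code-block:: python
+
+    from langchain.globals import set_llm_cache
+
+    from langchain_community.cache import InMemoryCache
+
+    # Every LLM call made through LangChain will now consult this cache first.
+    set_llm_cache(InMemoryCache())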
+"""
+from __future__ import annotations
+
+import hashlib
+import inspect
+import json
+import logging
+import uuid
+import warnings
+from datetime import timedelta
+from functools import lru_cache
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+from sqlalchemy import Column, Integer, Row, String, create_engine, select
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm import Session
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+
+from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.llms import LLM, get_prompts
+from langchain_core.load.dump import dumps
+from langchain_core.load.load import loads
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.utils import get_from_env
+
+from langchain_community.vectorstores.redis import Redis as RedisVectorstore
+
+logger = logging.getLogger(__file__)
+
+if TYPE_CHECKING:
+ import momento
+ from cassandra.cluster import Session as CassandraSession
+
+
+def _hash(_input: str) -> str:
+ """Use a deterministic hashing approach."""
+ return hashlib.md5(_input.encode()).hexdigest()
+
+
+def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
+ """Dump generations to json.
+
+ Args:
+ generations (RETURN_VAL_TYPE): A list of language model generations.
+
+ Returns:
+ str: Json representing a list of generations.
+
+ Warning: would not work well with arbitrary subclasses of `Generation`
+ """
+ return json.dumps([generation.dict() for generation in generations])
+
+
+def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
+ """Load generations from json.
+
+ Args:
+ generations_json (str): A string of json representing a list of generations.
+
+ Raises:
+ ValueError: Could not decode json string to list of generations.
+
+ Returns:
+ RETURN_VAL_TYPE: A list of generations.
+
+ Warning: would not work well with arbitrary subclasses of `Generation`
+ """
+ try:
+ results = json.loads(generations_json)
+ return [Generation(**generation_dict) for generation_dict in results]
+ except json.JSONDecodeError:
+ raise ValueError(
+ f"Could not decode json to list of generations: {generations_json}"
+ )
+
+
+def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
+ """
+ Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
+
+ Args:
+ generations (RETURN_VAL_TYPE): A list of language model generations.
+
+ Returns:
+ str: a single string representing a list of generations.
+
+    This function (+ its counterpart `_loads_generations`) relies on
+    the dumps/loads pair with Reviver, so it is able to deal
+    with all subclasses of Generation.
+
+    Each item in the list is `dumps`ed to a string,
+    then the whole list of strings is dumped to a single JSON string.
+ """
+ return json.dumps([dumps(_item) for _item in generations])
+
+
+def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
+ """
+ Deserialization of a string into a generic RETURN_VAL_TYPE
+ (i.e. a sequence of `Generation`).
+
+ See `_dumps_generations`, the inverse of this function.
+
+ Args:
+ generations_str (str): A string representing a list of generations.
+
+    Compatible with the legacy cache-blob format.
+    Does not raise exceptions for malformed entries, just logs a warning
+    and returns None: the caller should be prepared for such a cache miss.
+
+ Returns:
+ RETURN_VAL_TYPE: A list of generations.
+ """
+ try:
+ generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
+ return generations
+ except (json.JSONDecodeError, TypeError):
+ # deferring the (soft) handling to after the legacy-format attempt
+ pass
+
+ try:
+ gen_dicts = json.loads(generations_str)
+ # not relying on `_load_generations_from_json` (which could disappear):
+ generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
+ logger.warning(
+ f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
+ )
+ return generations
+ except (json.JSONDecodeError, TypeError):
+ logger.warning(
+ f"Malformed/unparsable cached blob encountered: '{generations_str}'"
+ )
+ return None
+
+
+class InMemoryCache(BaseCache):
+ """Cache that stores things in memory."""
+
+ def __init__(self) -> None:
+ """Initialize with empty cache."""
+ self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ return self._cache.get((prompt, llm_string), None)
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ self._cache[(prompt, llm_string)] = return_val
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache."""
+ self._cache = {}
+
+
+Base = declarative_base()
+
+
+class FullLLMCache(Base): # type: ignore
+ """SQLite table for full LLM Cache (all generations)."""
+
+ __tablename__ = "full_llm_cache"
+ prompt = Column(String, primary_key=True)
+ llm = Column(String, primary_key=True)
+ idx = Column(Integer, primary_key=True)
+ response = Column(String)
+
+
+class SQLAlchemyCache(BaseCache):
+    """Cache that uses SQLAlchemy as a backend."""
+
+ def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
+ """Initialize by creating all tables."""
+ self.engine = engine
+ self.cache_schema = cache_schema
+ self.cache_schema.metadata.create_all(self.engine)
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ stmt = (
+ select(self.cache_schema.response)
+ .where(self.cache_schema.prompt == prompt) # type: ignore
+ .where(self.cache_schema.llm == llm_string)
+ .order_by(self.cache_schema.idx)
+ )
+ with Session(self.engine) as session:
+ rows = session.execute(stmt).fetchall()
+ if rows:
+ try:
+ return [loads(row[0]) for row in rows]
+ except Exception:
+ logger.warning(
+ "Retrieving a cache value that could not be deserialized "
+ "properly. This is likely due to the cache being in an "
+ "older format. Please recreate your cache to avoid this "
+ "error."
+ )
+ # In a previous life we stored the raw text directly
+ # in the table, so assume it's in that format.
+ return [Generation(text=row[0]) for row in rows]
+ return None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update based on prompt and llm_string."""
+ items = [
+ self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
+ for i, gen in enumerate(return_val)
+ ]
+ with Session(self.engine) as session, session.begin():
+ for item in items:
+ session.merge(item)
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache."""
+ with Session(self.engine) as session:
+ session.query(self.cache_schema).delete()
+ session.commit()
+
+
+class SQLiteCache(SQLAlchemyCache):
+ """Cache that uses SQLite as a backend."""
+
+ def __init__(self, database_path: str = ".langchain.db"):
+ """Initialize by creating the engine and all tables."""
+ engine = create_engine(f"sqlite:///{database_path}")
+ super().__init__(engine)
+
+
+class UpstashRedisCache(BaseCache):
+ """Cache that uses Upstash Redis as a backend."""
+
+ def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
+ """
+ Initialize an instance of UpstashRedisCache.
+
+ This method initializes an object with Upstash Redis caching capabilities.
+ It takes a `redis_` parameter, which should be an instance of an Upstash Redis
+ client class, allowing the object to interact with Upstash Redis
+ server for caching purposes.
+
+ Parameters:
+ redis_: An instance of Upstash Redis client class
+ (e.g., Redis) used for caching.
+ This allows the object to communicate with
+                the Upstash Redis server for caching operations.
+ ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
+ If provided, it sets the time duration for how long cached
+ items will remain valid. If not provided, cached items will not
+ have an automatic expiration.
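+
+        Example (an illustrative sketch; the REST URL and token below are
+        placeholders for your own Upstash credentials):
+
+        .. code-block:: python
+
+            from upstash_redis import Redis
+
+            from langchain_community.cache import UpstashRedisCache
+
+            cache = UpstashRedisCache(
+                redis_=Redis(
+                    url="<UPSTASH_REDIS_REST_URL>",
+                    token="<UPSTASH_REDIS_REST_TOKEN>",
+                ),
+                ttl=3600,
+            )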
+ """
+ try:
+ from upstash_redis import Redis
+ except ImportError:
+ raise ValueError(
+ "Could not import upstash_redis python package. "
+ "Please install it with `pip install upstash_redis`."
+ )
+ if not isinstance(redis_, Redis):
+ raise ValueError("Please pass in Upstash Redis object.")
+ self.redis = redis_
+ self.ttl = ttl
+
+ def _key(self, prompt: str, llm_string: str) -> str:
+ """Compute key from prompt and llm_string"""
+ return _hash(prompt + llm_string)
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ generations = []
+ # Read from a HASH
+ results = self.redis.hgetall(self._key(prompt, llm_string))
+ if results:
+ for _, text in results.items():
+ generations.append(Generation(text=text))
+ return generations if generations else None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ for gen in return_val:
+ if not isinstance(gen, Generation):
+ raise ValueError(
+                    "UpstashRedisCache only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
+ if isinstance(gen, ChatGeneration):
+ warnings.warn(
+ "NOTE: Generation has not been cached. UpstashRedisCache does not"
+ " support caching ChatModel outputs."
+ )
+ return
+ # Write to a HASH
+ key = self._key(prompt, llm_string)
+
+ mapping = {
+ str(idx): generation.text for idx, generation in enumerate(return_val)
+ }
+ self.redis.hset(key=key, values=mapping)
+
+ if self.ttl is not None:
+ self.redis.expire(key, self.ttl)
+
+ def clear(self, **kwargs: Any) -> None:
+ """
+ Clear cache. If `asynchronous` is True, flush asynchronously.
+ This flushes the *whole* db.
+ """
+ asynchronous = kwargs.get("asynchronous", False)
+ if asynchronous:
+ asynchronous = "ASYNC"
+ else:
+ asynchronous = "SYNC"
+ self.redis.flushdb(flush_type=asynchronous)
+
+
+class RedisCache(BaseCache):
+ """Cache that uses Redis as a backend."""
+
+ def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
+ """
+ Initialize an instance of RedisCache.
+
+ This method initializes an object with Redis caching capabilities.
+ It takes a `redis_` parameter, which should be an instance of a Redis
+ client class, allowing the object to interact with a Redis
+ server for caching purposes.
+
+ Parameters:
+ redis_ (Any): An instance of a Redis client class
+ (e.g., redis.Redis) used for caching.
+ This allows the object to communicate with a
+ Redis server for caching operations.
+ ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
+ If provided, it sets the time duration for how long cached
+ items will remain valid. If not provided, cached items will not
+ have an automatic expiration.
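+
+        Example (an illustrative sketch; assumes a Redis server is reachable on
+        localhost):
+
+        .. code-block:: python
+
+            import redis
+
+            from langchain_community.cache import RedisCache
+
+            cache = RedisCache(redis_=redis.Redis(host="localhost", port=6379), ttl=3600)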
+ """
+ try:
+ from redis import Redis
+ except ImportError:
+ raise ValueError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+ if not isinstance(redis_, Redis):
+ raise ValueError("Please pass in Redis object.")
+ self.redis = redis_
+ self.ttl = ttl
+
+ def _key(self, prompt: str, llm_string: str) -> str:
+ """Compute key from prompt and llm_string"""
+ return _hash(prompt + llm_string)
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ generations = []
+ # Read from a Redis HASH
+ results = self.redis.hgetall(self._key(prompt, llm_string))
+ if results:
+ for _, text in results.items():
+ try:
+ generations.append(loads(text))
+ except Exception:
+ logger.warning(
+ "Retrieving a cache value that could not be deserialized "
+ "properly. This is likely due to the cache being in an "
+ "older format. Please recreate your cache to avoid this "
+ "error."
+ )
+ # In a previous life we stored the raw text directly
+ # in the table, so assume it's in that format.
+ generations.append(Generation(text=text))
+ return generations if generations else None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ for gen in return_val:
+ if not isinstance(gen, Generation):
+ raise ValueError(
+ "RedisCache only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
+ # Write to a Redis HASH
+ key = self._key(prompt, llm_string)
+
+ with self.redis.pipeline() as pipe:
+ pipe.hset(
+ key,
+ mapping={
+ str(idx): dumps(generation)
+ for idx, generation in enumerate(return_val)
+ },
+ )
+ if self.ttl is not None:
+ pipe.expire(key, self.ttl)
+
+ pipe.execute()
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache. If `asynchronous` is True, flush asynchronously."""
+ asynchronous = kwargs.get("asynchronous", False)
+ self.redis.flushdb(asynchronous=asynchronous, **kwargs)
+
+
+class RedisSemanticCache(BaseCache):
+ """Cache that uses Redis as a vector-store backend."""
+
+ # TODO - implement a TTL policy in Redis
+
+ DEFAULT_SCHEMA = {
+ "content_key": "prompt",
+ "text": [
+ {"name": "prompt"},
+ ],
+ "extra": [{"name": "return_val"}, {"name": "llm_string"}],
+ }
+
+ def __init__(
+ self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
+ ):
+        """Initialize the semantic cache with a Redis URL and an embedding provider.
+
+ Args:
+ redis_url (str): URL to connect to Redis.
+            embedding (Embeddings): Embedding provider for semantic encoding and search.
+            score_threshold (float, optional): Score threshold for a semantic-similarity cache hit. Defaults to 0.2.
+
+ Example:
+
+ .. code-block:: python
+
+            from langchain.globals import set_llm_cache
+
+ from langchain_community.cache import RedisSemanticCache
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ set_llm_cache(RedisSemanticCache(
+ redis_url="redis://localhost:6379",
+ embedding=OpenAIEmbeddings()
+ ))
+
+ """
+ self._cache_dict: Dict[str, RedisVectorstore] = {}
+ self.redis_url = redis_url
+ self.embedding = embedding
+ self.score_threshold = score_threshold
+
+ def _index_name(self, llm_string: str) -> str:
+ hashed_index = _hash(llm_string)
+ return f"cache:{hashed_index}"
+
+ def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
+ index_name = self._index_name(llm_string)
+
+ # return vectorstore client for the specific llm string
+ if index_name in self._cache_dict:
+ return self._cache_dict[index_name]
+
+ # create new vectorstore client for the specific llm string
+ try:
+ self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
+ embedding=self.embedding,
+ index_name=index_name,
+ redis_url=self.redis_url,
+ schema=cast(Dict, self.DEFAULT_SCHEMA),
+ )
+ except ValueError:
+ redis = RedisVectorstore(
+ embedding=self.embedding,
+ index_name=index_name,
+ redis_url=self.redis_url,
+ index_schema=cast(Dict, self.DEFAULT_SCHEMA),
+ )
+ _embedding = self.embedding.embed_query(text="test")
+ redis._create_index_if_not_exist(dim=len(_embedding))
+ self._cache_dict[index_name] = redis
+
+ return self._cache_dict[index_name]
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear semantic cache for a given llm_string."""
+ index_name = self._index_name(kwargs["llm_string"])
+ if index_name in self._cache_dict:
+ self._cache_dict[index_name].drop_index(
+ index_name=index_name, delete_documents=True, redis_url=self.redis_url
+ )
+ del self._cache_dict[index_name]
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ llm_cache = self._get_llm_cache(llm_string)
+ generations: List = []
+ # Read from a Hash
+ results = llm_cache.similarity_search(
+ query=prompt,
+ k=1,
+ distance_threshold=self.score_threshold,
+ )
+ if results:
+ for document in results:
+ try:
+ generations.extend(loads(document.metadata["return_val"]))
+ except Exception:
+ logger.warning(
+ "Retrieving a cache value that could not be deserialized "
+ "properly. This is likely due to the cache being in an "
+ "older format. Please recreate your cache to avoid this "
+ "error."
+ )
+ # In a previous life we stored the raw text directly
+ # in the table, so assume it's in that format.
+ generations.extend(
+ _load_generations_from_json(document.metadata["return_val"])
+ )
+ return generations if generations else None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ for gen in return_val:
+ if not isinstance(gen, Generation):
+ raise ValueError(
+ "RedisSemanticCache only supports caching of "
+ f"normal LLM generations, got {type(gen)}"
+ )
+ llm_cache = self._get_llm_cache(llm_string)
+
+ metadata = {
+ "llm_string": llm_string,
+ "prompt": prompt,
+ "return_val": dumps([g for g in return_val]),
+ }
+ llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
+
+
+class GPTCache(BaseCache):
+ """Cache that uses GPTCache as a backend."""
+
+ def __init__(
+ self,
+ init_func: Union[
+ Callable[[Any, str], None], Callable[[Any], None], None
+ ] = None,
+ ):
+ """Initialize by passing in init function (default: `None`).
+
+ Args:
+ init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
+ (default: `None`)
+
+ Example:
+ .. code-block:: python
+
+ # Initialize GPTCache with a custom init function
+ import gptcache
+ from gptcache.processor.pre import get_prompt
+            from gptcache.manager.factory import manager_factory
+            from langchain.globals import set_llm_cache
+
+            # Avoid multiple caches using the same file,
+            # causing different llm model caches to affect each other
+
+            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
+ cache_obj.init(
+ pre_embedding_func=get_prompt,
+ data_manager=manager_factory(
+ manager="map",
+ data_dir=f"map_cache_{llm}"
+ ),
+ )
+
+ set_llm_cache(GPTCache(init_gptcache))
+
+ """
+ try:
+ import gptcache # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import gptcache python package. "
+ "Please install it with `pip install gptcache`."
+ )
+
+ self.init_gptcache_func: Union[
+ Callable[[Any, str], None], Callable[[Any], None], None
+ ] = init_func
+ self.gptcache_dict: Dict[str, Any] = {}
+
+ def _new_gptcache(self, llm_string: str) -> Any:
+ """New gptcache object"""
+ from gptcache import Cache
+ from gptcache.manager.factory import get_data_manager
+ from gptcache.processor.pre import get_prompt
+
+ _gptcache = Cache()
+ if self.init_gptcache_func is not None:
+ sig = inspect.signature(self.init_gptcache_func)
+ if len(sig.parameters) == 2:
+ self.init_gptcache_func(_gptcache, llm_string) # type: ignore[call-arg]
+ else:
+ self.init_gptcache_func(_gptcache) # type: ignore[call-arg]
+ else:
+ _gptcache.init(
+ pre_embedding_func=get_prompt,
+ data_manager=get_data_manager(data_path=llm_string),
+ )
+
+ self.gptcache_dict[llm_string] = _gptcache
+ return _gptcache
+
+ def _get_gptcache(self, llm_string: str) -> Any:
+ """Get a cache object.
+
+ When the corresponding llm model cache does not exist, it will be created."""
+ _gptcache = self.gptcache_dict.get(llm_string, None)
+ if not _gptcache:
+ _gptcache = self._new_gptcache(llm_string)
+ return _gptcache
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up the cache data.
+ First, retrieve the corresponding cache object using the `llm_string` parameter,
+ and then retrieve the data from the cache based on the `prompt`.
+ """
+ from gptcache.adapter.api import get
+
+ _gptcache = self._get_gptcache(llm_string)
+
+ res = get(prompt, cache_obj=_gptcache)
+ if res:
+ return [
+ Generation(**generation_dict) for generation_dict in json.loads(res)
+ ]
+ return None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache.
+ First, retrieve the corresponding cache object using the `llm_string` parameter,
+ and then store the `prompt` and `return_val` in the cache object.
+ """
+ for gen in return_val:
+ if not isinstance(gen, Generation):
+ raise ValueError(
+ "GPTCache only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
+ from gptcache.adapter.api import put
+
+ _gptcache = self._get_gptcache(llm_string)
+ handled_data = json.dumps([generation.dict() for generation in return_val])
+ put(prompt, handled_data, cache_obj=_gptcache)
+ return None
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache."""
+ from gptcache import Cache
+
+ for gptcache_instance in self.gptcache_dict.values():
+ gptcache_instance = cast(Cache, gptcache_instance)
+ gptcache_instance.flush()
+
+ self.gptcache_dict.clear()
+
+
+def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
+ """Create cache if it doesn't exist.
+
+ Raises:
+ SdkException: Momento service or network error
+ Exception: Unexpected response
+ """
+ from momento.responses import CreateCache
+
+ create_cache_response = cache_client.create_cache(cache_name)
+ if isinstance(create_cache_response, CreateCache.Success) or isinstance(
+ create_cache_response, CreateCache.CacheAlreadyExists
+ ):
+ return None
+ elif isinstance(create_cache_response, CreateCache.Error):
+ raise create_cache_response.inner_exception
+ else:
+ raise Exception(f"Unexpected response cache creation: {create_cache_response}")
+
+
+def _validate_ttl(ttl: Optional[timedelta]) -> None:
+ if ttl is not None and ttl <= timedelta(seconds=0):
+ raise ValueError(f"ttl must be positive but was {ttl}.")
+
+
+class MomentoCache(BaseCache):
+ """Cache that uses Momento as a backend. See https://gomomento.com/"""
+
+ def __init__(
+ self,
+ cache_client: momento.CacheClient,
+ cache_name: str,
+ *,
+ ttl: Optional[timedelta] = None,
+ ensure_cache_exists: bool = True,
+ ):
+ """Instantiate a prompt cache using Momento as a backend.
+
+ Note: to instantiate the cache client passed to MomentoCache,
+ you must have a Momento account. See https://gomomento.com/.
+
+ Args:
+ cache_client (CacheClient): The Momento cache client.
+ cache_name (str): The name of the cache to use to store the data.
+ ttl (Optional[timedelta], optional): The time to live for the cache items.
+ Defaults to None, i.e. use the client default TTL.
+ ensure_cache_exists (bool, optional): Create the cache if it doesn't
+ exist. Defaults to True.
+
+ Raises:
+ ImportError: Momento python package is not installed.
+ TypeError: cache_client is not of type momento.CacheClient
+ ValueError: ttl is non-null and non-positive
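+
+ Example (a minimal sketch; it assumes the `momento` package is installed,
+ a `MOMENTO_API_KEY` environment variable holds your API key, and
+ `set_llm_cache` is imported, e.g. from `langchain.globals`):
+
+ from datetime import timedelta
+
+ # build the underlying CacheClient from env credentials,
+ # with a one-day default TTL
+ cache = MomentoCache.from_client_params(
+ "langchain", ttl=timedelta(days=1)
+ )
+ set_llm_cache(cache)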
+ """
+ try:
+ from momento import CacheClient
+ except ImportError:
+ raise ImportError(
+ "Could not import momento python package. "
+ "Please install it with `pip install momento`."
+ )
+ if not isinstance(cache_client, CacheClient):
+ raise TypeError("cache_client must be a momento.CacheClient object.")
+ _validate_ttl(ttl)
+ if ensure_cache_exists:
+ _ensure_cache_exists(cache_client, cache_name)
+
+ self.cache_client = cache_client
+ self.cache_name = cache_name
+ self.ttl = ttl
+
+ @classmethod
+ def from_client_params(
+ cls,
+ cache_name: str,
+ ttl: timedelta,
+ *,
+ configuration: Optional[momento.config.Configuration] = None,
+ api_key: Optional[str] = None,
+ auth_token: Optional[str] = None, # for backwards compatibility
+ **kwargs: Any,
+ ) -> MomentoCache:
+ """Construct cache from CacheClient parameters."""
+ try:
+ from momento import CacheClient, Configurations, CredentialProvider
+ except ImportError:
+ raise ImportError(
+ "Could not import momento python package. "
+ "Please install it with `pip install momento`."
+ )
+ if configuration is None:
+ configuration = Configurations.Laptop.v1()
+
+ # Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
+ try:
+ api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
+ except ValueError:
+ api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
+ credentials = CredentialProvider.from_string(api_key)
+ cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
+ return cls(cache_client, cache_name, ttl=ttl, **kwargs)
+
+ def __key(self, prompt: str, llm_string: str) -> str:
+ """Compute cache key from prompt and associated model and settings.
+
+ Args:
+ prompt (str): The prompt run through the language model.
+ llm_string (str): The language model version and settings.
+
+ Returns:
+ str: The cache key.
+ """
+ return _hash(prompt + llm_string)
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Lookup llm generations in cache by prompt and associated model and settings.
+
+ Args:
+ prompt (str): The prompt run through the language model.
+ llm_string (str): The language model version and settings.
+
+ Raises:
+ SdkException: Momento service or network error
+
+ Returns:
+ Optional[RETURN_VAL_TYPE]: A list of language model generations.
+ """
+ from momento.responses import CacheGet
+
+ generations: RETURN_VAL_TYPE = []
+
+ get_response = self.cache_client.get(
+ self.cache_name, self.__key(prompt, llm_string)
+ )
+ if isinstance(get_response, CacheGet.Hit):
+ value = get_response.value_string
+ generations = _load_generations_from_json(value)
+ elif isinstance(get_response, CacheGet.Miss):
+ pass
+ elif isinstance(get_response, CacheGet.Error):
+ raise get_response.inner_exception
+ return generations if generations else None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Store llm generations in cache.
+
+ Args:
+ prompt (str): The prompt run through the language model.
+ llm_string (str): The language model string.
+ return_val (RETURN_VAL_TYPE): A list of language model generations.
+
+ Raises:
+ SdkException: Momento service or network error
+ Exception: Unexpected response
+ """
+ for gen in return_val:
+ if not isinstance(gen, Generation):
+ raise ValueError(
+ "Momento only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
+ key = self.__key(prompt, llm_string)
+ value = _dump_generations_to_json(return_val)
+ set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
+ from momento.responses import CacheSet
+
+ if isinstance(set_response, CacheSet.Success):
+ pass
+ elif isinstance(set_response, CacheSet.Error):
+ raise set_response.inner_exception
+ else:
+ raise Exception(f"Unexpected response: {set_response}")
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear the cache.
+
+ Raises:
+ SdkException: Momento service or network error
+ """
+ from momento.responses import CacheFlush
+
+ flush_response = self.cache_client.flush_cache(self.cache_name)
+ if isinstance(flush_response, CacheFlush.Success):
+ pass
+ elif isinstance(flush_response, CacheFlush.Error):
+ raise flush_response.inner_exception
+
+
+CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache"
+CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None
+
+
+class CassandraCache(BaseCache):
+ """
+ Cache that uses Cassandra / Astra DB as a backend.
+
+ It uses a single Cassandra table.
+ The lookup keys (which together form the primary key) are:
+ - prompt, a string
+ - llm_string, a deterministic str representation of the model parameters.
+ (needed to prevent same-prompt-different-model collisions)
+ """
+
+ def __init__(
+ self,
+ session: Optional[CassandraSession] = None,
+ keyspace: Optional[str] = None,
+ table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
+ ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
+ skip_provisioning: bool = False,
+ ):
+ """
+ Initialize with a ready session and a keyspace name.
+ Args:
+ session (cassandra.cluster.Session): an open Cassandra session
+ keyspace (str): the keyspace to use for storing the cache
+ table_name (str): name of the Cassandra table to use as cache
+ ttl_seconds (optional int): time-to-live for cache entries
+ (default: None, i.e. forever)
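+
+ Example (a minimal sketch; `my_session` stands for an already-open
+ `cassandra.cluster.Session` and the keyspace name is illustrative):
+
+ cache = CassandraCache(
+ session=my_session,
+ keyspace="my_keyspace",
+ ttl_seconds=3600, # entries expire after one hour
+ )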
+ """
+ try:
+ from cassio.table import ElasticCassandraTable
+ except (ImportError, ModuleNotFoundError):
+ raise ValueError(
+ "Could not import cassio python package. "
+ "Please install it with `pip install cassio`."
+ )
+
+ self.session = session
+ self.keyspace = keyspace
+ self.table_name = table_name
+ self.ttl_seconds = ttl_seconds
+
+ self.kv_cache = ElasticCassandraTable(
+ session=self.session,
+ keyspace=self.keyspace,
+ table=self.table_name,
+ keys=["llm_string", "prompt"],
+ primary_key_type=["TEXT", "TEXT"],
+ ttl_seconds=self.ttl_seconds,
+ skip_provisioning=skip_provisioning,
+ )
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ item = self.kv_cache.get(
+ llm_string=_hash(llm_string),
+ prompt=_hash(prompt),
+ )
+ if item is not None:
+ generations = _loads_generations(item["body_blob"])
+ # this protects against malformed cached items:
+ if generations is not None:
+ return generations
+ else:
+ return None
+ else:
+ return None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ blob = _dumps_generations(return_val)
+ self.kv_cache.put(
+ llm_string=_hash(llm_string),
+ prompt=_hash(prompt),
+ body_blob=blob,
+ )
+
+ def delete_through_llm(
+ self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+ ) -> None:
+ """
+ A wrapper around `delete` with the LLM being passed.
+ In case the llm(prompt) calls have a `stop` param, you should pass it here
+ """
+ llm_string = get_prompts(
+ {**llm.dict(), **{"stop": stop}},
+ [],
+ )[1]
+ return self.delete(prompt, llm_string=llm_string)
+
+ def delete(self, prompt: str, llm_string: str) -> None:
+ """Evict from cache if there's an entry."""
+ return self.kv_cache.delete(
+ llm_string=_hash(llm_string),
+ prompt=_hash(prompt),
+ )
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache. This is for all LLMs at once."""
+ self.kv_cache.clear()
+
+
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache"
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None
+CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
+
+
+class CassandraSemanticCache(BaseCache):
+ """
+ Cache that uses Cassandra as a vector-store backend for semantic
+ (i.e. similarity-based) lookup.
+
+ It uses a single (vector) Cassandra table and stores, in principle,
+ cached values from several LLMs, so the LLM's llm_string is part
+ of the rows' primary keys.
+
+ The similarity is based on one of several distance metrics (default: "dot").
+ If choosing another metric, the default threshold should be re-tuned accordingly.
+ """
+
+ def __init__(
+ self,
+ session: Optional[CassandraSession],
+ keyspace: Optional[str],
+ embedding: Embeddings,
+ table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
+ distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
+ score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
+ ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
+ skip_provisioning: bool = False,
+ ):
+ """
+ Initialize the cache with all relevant parameters.
+ Args:
+ session (cassandra.cluster.Session): an open Cassandra session
+ keyspace (str): the keyspace to use for storing the cache
+ embedding (Embedding): Embedding provider for semantic
+ encoding and search.
+ table_name (str): name of the Cassandra (vector) table
+ to use as cache
+ distance_metric (str, 'dot'): which measure to adopt for
+ similarity searches
+ score_threshold (optional float): numeric value to use as
+ cutoff for the similarity searches
+ ttl_seconds (optional int): time-to-live for cache entries
+ (default: None, i.e. forever)
+ The default score threshold is tuned to the default metric.
+ Tune it carefully yourself if switching to another distance metric.
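+
+ Example (a minimal sketch; `my_session` is an open Cassandra session and
+ `my_embeddings` is any Embeddings implementation you already use):
+
+ semantic_cache = CassandraSemanticCache(
+ session=my_session,
+ keyspace="my_keyspace",
+ embedding=my_embeddings,
+ # keep the default "dot" metric and 0.85 score threshold
+ )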
+ """
+ try:
+ from cassio.table import MetadataVectorCassandraTable
+ except (ImportError, ModuleNotFoundError):
+ raise ValueError(
+ "Could not import cassio python package. "
+ "Please install it with `pip install cassio`."
+ )
+ self.session = session
+ self.keyspace = keyspace
+ self.embedding = embedding
+ self.table_name = table_name
+ self.distance_metric = distance_metric
+ self.score_threshold = score_threshold
+ self.ttl_seconds = ttl_seconds
+
+ # The contract for this class has separate lookup and update:
+ # in order to spare some embedding calculations we cache them between
+ # the two calls.
+ # Note: each instance of this class has its own `_get_embedding` with
+ # its own lru.
+ @lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+ def _cache_embedding(text: str) -> List[float]:
+ return self.embedding.embed_query(text=text)
+
+ self._get_embedding = _cache_embedding
+ self.embedding_dimension = self._get_embedding_dimension()
+
+ self.table = MetadataVectorCassandraTable(
+ session=self.session,
+ keyspace=self.keyspace,
+ table=self.table_name,
+ primary_key_type=["TEXT"],
+ vector_dimension=self.embedding_dimension,
+ ttl_seconds=self.ttl_seconds,
+ metadata_indexing=("allow", {"_llm_string_hash"}),
+ skip_provisioning=skip_provisioning,
+ )
+
+ def _get_embedding_dimension(self) -> int:
+ return len(self._get_embedding(text="This is a sample sentence."))
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ embedding_vector = self._get_embedding(text=prompt)
+ llm_string_hash = _hash(llm_string)
+ body = _dumps_generations(return_val)
+ metadata = {
+ "_prompt": prompt,
+ "_llm_string_hash": llm_string_hash,
+ }
+ row_id = f"{_hash(prompt)}-{llm_string_hash}"
+ #
+ self.table.put(
+ body_blob=body,
+ vector=embedding_vector,
+ row_id=row_id,
+ metadata=metadata,
+ )
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ hit_with_id = self.lookup_with_id(prompt, llm_string)
+ if hit_with_id is not None:
+ return hit_with_id[1]
+ else:
+ return None
+
+ def lookup_with_id(
+ self, prompt: str, llm_string: str
+ ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+ """
+ Look up based on prompt and llm_string.
+ If there are hits, return (document_id, cached_entry)
+ """
+ prompt_embedding: List[float] = self._get_embedding(text=prompt)
+ hits = list(
+ self.table.metric_ann_search(
+ vector=prompt_embedding,
+ metadata={"_llm_string_hash": _hash(llm_string)},
+ n=1,
+ metric=self.distance_metric,
+ metric_threshold=self.score_threshold,
+ )
+ )
+ if hits:
+ hit = hits[0]
+ generations = _loads_generations(hit["body_blob"])
+ if generations is not None:
+ # this protects against malformed cached items:
+ return (
+ hit["row_id"],
+ generations,
+ )
+ else:
+ return None
+ else:
+ return None
+
+ def lookup_with_id_through_llm(
+ self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+ ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+ llm_string = get_prompts(
+ {**llm.dict(), **{"stop": stop}},
+ [],
+ )[1]
+ return self.lookup_with_id(prompt, llm_string=llm_string)
+
+ def delete_by_document_id(self, document_id: str) -> None:
+ """
+ Given this is a "similarity search" cache, an invalidation pattern
+ that makes sense is first a lookup to get an ID, and then deleting
+ with that ID. This is for the second step.
+ """
+ self.table.delete(row_id=document_id)
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear the *whole* semantic cache."""
+ self.table.clear()
+
+
+class FullMd5LLMCache(Base): # type: ignore
+ """SQLite table for full LLM Cache (all generations)."""
+
+ __tablename__ = "full_md5_llm_cache"
+ id = Column(String, primary_key=True)
+ prompt_md5 = Column(String, index=True)
+ llm = Column(String, index=True)
+ idx = Column(Integer, index=True)
+ prompt = Column(String)
+ response = Column(String)
+
+
+class SQLAlchemyMd5Cache(BaseCache):
+ """Cache that uses SQAlchemy as a backend."""
+
+ def __init__(
+ self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
+ ):
+ """Initialize by creating all tables."""
+ self.engine = engine
+ self.cache_schema = cache_schema
+ self.cache_schema.metadata.create_all(self.engine)
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ rows = self._search_rows(prompt, llm_string)
+ if rows:
+ return [loads(row[0]) for row in rows]
+ return None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update based on prompt and llm_string."""
+ self._delete_previous(prompt, llm_string)
+ prompt_md5 = self.get_md5(prompt)
+ items = [
+ self.cache_schema(
+ id=str(uuid.uuid1()),
+ prompt=prompt,
+ prompt_md5=prompt_md5,
+ llm=llm_string,
+ response=dumps(gen),
+ idx=i,
+ )
+ for i, gen in enumerate(return_val)
+ ]
+ with Session(self.engine) as session, session.begin():
+ for item in items:
+ session.merge(item)
+
+ def _delete_previous(self, prompt: str, llm_string: str) -> None:
+ # Select the full ORM entities (not just the response column) so that
+ # each row can be passed to `session.delete`.
+ stmt = (
+ select(self.cache_schema)
+ .where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore
+ .where(self.cache_schema.llm == llm_string)
+ .where(self.cache_schema.prompt == prompt)
+ .order_by(self.cache_schema.idx)
+ )
+ with Session(self.engine) as session, session.begin():
+ for item in session.execute(stmt).scalars().all():
+ session.delete(item)
+
+ def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
+ prompt_md5 = self.get_md5(prompt)
+ stmt = (
+ select(self.cache_schema.response)
+ .where(self.cache_schema.prompt_md5 == prompt_md5) # type: ignore
+ .where(self.cache_schema.llm == llm_string)
+ .where(self.cache_schema.prompt == prompt)
+ .order_by(self.cache_schema.idx)
+ )
+ with Session(self.engine) as session:
+ return session.execute(stmt).fetchall()
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache."""
+ with Session(self.engine) as session, session.begin():
+ session.query(self.cache_schema).delete()
+
+ @staticmethod
+ def get_md5(input_string: str) -> str:
+ return hashlib.md5(input_string.encode()).hexdigest()
+
+
+ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
+
+
+class AstraDBCache(BaseCache):
+ """
+ Cache that uses Astra DB as a backend.
+
+ It uses a single collection as a kv store
+ The lookup keys, combined in the _id of the documents, are:
+ - prompt, a string
+ - llm_string, a deterministic str representation of the model parameters.
+ (needed to prevent same-prompt-different-model collisions)
+ """
+
+ def __init__(
+ self,
+ *,
+ collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
+ token: Optional[str] = None,
+ api_endpoint: Optional[str] = None,
+ astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
+ namespace: Optional[str] = None,
+ ):
+ """
+ Create an AstraDB cache using a collection for storage.
+
+ Args (only keyword-arguments accepted):
+ collection_name (str): name of the Astra DB collection to create/use.
+ token (Optional[str]): API token for Astra DB usage.
+ api_endpoint (Optional[str]): full URL to the API endpoint,
+ such as "https://-us-east1.apps.astra.datastax.com".
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
+ you can pass an already-created 'astrapy.db.AstraDB' instance.
+ namespace (Optional[str]): namespace (aka keyspace) where the
+ collection is created. Defaults to the database's "default namespace".
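+
+ Example (a minimal sketch; the endpoint and token below are placeholders
+ for your own Astra DB credentials):
+
+ cache = AstraDBCache(
+ api_endpoint="https://<database-id>-us-east1.apps.astra.datastax.com",
+ token="AstraCS:...",
+ )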
+ """
+ try:
+ from astrapy.db import (
+ AstraDB as LibAstraDB,
+ )
+ except (ImportError, ModuleNotFoundError):
+ raise ImportError(
+ "Could not import a recent astrapy python package. "
+ "Please install it with `pip install --upgrade astrapy`."
+ )
+ # Conflicting-arg checks:
+ if astra_db_client is not None:
+ if token is not None or api_endpoint is not None:
+ raise ValueError(
+ "You cannot pass 'astra_db_client' to AstraDB if passing "
+ "'token' and 'api_endpoint'."
+ )
+
+ self.collection_name = collection_name
+ self.token = token
+ self.api_endpoint = api_endpoint
+ self.namespace = namespace
+
+ if astra_db_client is not None:
+ self.astra_db = astra_db_client
+ else:
+ self.astra_db = LibAstraDB(
+ token=self.token,
+ api_endpoint=self.api_endpoint,
+ namespace=self.namespace,
+ )
+ self.collection = self.astra_db.create_collection(
+ collection_name=self.collection_name,
+ )
+
+ @staticmethod
+ def _make_id(prompt: str, llm_string: str) -> str:
+ return f"{_hash(prompt)}#{_hash(llm_string)}"
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ doc_id = self._make_id(prompt, llm_string)
+ item = self.collection.find_one(
+ filter={
+ "_id": doc_id,
+ },
+ projection={
+ "body_blob": 1,
+ },
+ )["data"]["document"]
+ if item is not None:
+ generations = _loads_generations(item["body_blob"])
+ # this protects against malformed cached items:
+ if generations is not None:
+ return generations
+ else:
+ return None
+ else:
+ return None
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ doc_id = self._make_id(prompt, llm_string)
+ blob = _dumps_generations(return_val)
+ self.collection.upsert(
+ {
+ "_id": doc_id,
+ "body_blob": blob,
+ },
+ )
+
+ def delete_through_llm(
+ self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+ ) -> None:
+ """
+ A wrapper around `delete` with the LLM being passed.
+ In case the llm(prompt) calls have a `stop` param, you should pass it here
+ """
+ llm_string = get_prompts(
+ {**llm.dict(), **{"stop": stop}},
+ [],
+ )[1]
+ return self.delete(prompt, llm_string=llm_string)
+
+ def delete(self, prompt: str, llm_string: str) -> None:
+ """Evict from cache if there's an entry."""
+ doc_id = self._make_id(prompt, llm_string)
+ return self.collection.delete_one(doc_id)
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear cache. This is for all LLMs at once."""
+ self.astra_db.truncate_collection(self.collection_name)
+
+
+ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
+ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
+ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
+
+
+class AstraDBSemanticCache(BaseCache):
+ """
+ Cache that uses Astra DB as a vector-store backend for semantic
+ (i.e. similarity-based) lookup.
+
+ It uses a single (vector) collection and can store
+ cached values from several LLMs, so the LLM's 'llm_string' is stored
+ in the document metadata.
+
+ You can choose the preferred similarity (or use the API default) --
+ remember the threshold might require metric-dependent tuning.
+ """
+
+ def __init__(
+ self,
+ *,
+ collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
+ token: Optional[str] = None,
+ api_endpoint: Optional[str] = None,
+ astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
+ namespace: Optional[str] = None,
+ embedding: Embeddings,
+ metric: Optional[str] = None,
+ similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
+ ):
+ """
+ Initialize the cache with all relevant parameters.
+ Args:
+
+ collection_name (str): name of the Astra DB collection to create/use.
+ token (Optional[str]): API token for Astra DB usage.
+ api_endpoint (Optional[str]): full URL to the API endpoint,
+ such as "https://-us-east1.apps.astra.datastax.com".
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
+ you can pass an already-created 'astrapy.db.AstraDB' instance.
+ namespace (Optional[str]): namespace (aka keyspace) where the
+ collection is created. Defaults to the database's "default namespace".
+ embedding (Embedding): Embedding provider for semantic
+ encoding and search.
+ metric: the function to use for evaluating similarity of text embeddings.
+ Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
+ similarity_threshold (float, optional): the minimum similarity
+ for accepting a (semantic-search) match.
+
+ The default score threshold is tuned to the default metric.
+ Tune it carefully yourself if switching to another distance metric.
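+
+ Example (a minimal sketch; credentials are placeholders and
+ `my_embeddings` is any Embeddings implementation):
+
+ semantic_cache = AstraDBSemanticCache(
+ api_endpoint="https://<database-id>-us-east1.apps.astra.datastax.com",
+ token="AstraCS:...",
+ embedding=my_embeddings,
+ similarity_threshold=0.85, # the default; re-tune if you change `metric`
+ )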
+ """
+ try:
+ from astrapy.db import (
+ AstraDB as LibAstraDB,
+ )
+ except (ImportError, ModuleNotFoundError):
+ raise ImportError(
+ "Could not import a recent astrapy python package. "
+ "Please install it with `pip install --upgrade astrapy`."
+ )
+ # Conflicting-arg checks:
+ if astra_db_client is not None:
+ if token is not None or api_endpoint is not None:
+ raise ValueError(
+ "You cannot pass 'astra_db_client' to AstraDB if passing "
+ "'token' and 'api_endpoint'."
+ )
+
+ self.embedding = embedding
+ self.metric = metric
+ self.similarity_threshold = similarity_threshold
+
+ # The contract for this class has separate lookup and update:
+ # in order to spare some embedding calculations we cache them between
+ # the two calls.
+ # Note: each instance of this class has its own `_get_embedding` with
+ # its own lru.
+ @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+ def _cache_embedding(text: str) -> List[float]:
+ return self.embedding.embed_query(text=text)
+
+ self._get_embedding = _cache_embedding
+ self.embedding_dimension = self._get_embedding_dimension()
+
+ self.collection_name = collection_name
+ self.token = token
+ self.api_endpoint = api_endpoint
+ self.namespace = namespace
+
+ if astra_db_client is not None:
+ self.astra_db = astra_db_client
+ else:
+ self.astra_db = LibAstraDB(
+ token=self.token,
+ api_endpoint=self.api_endpoint,
+ namespace=self.namespace,
+ )
+ self.collection = self.astra_db.create_collection(
+ collection_name=self.collection_name,
+ dimension=self.embedding_dimension,
+ metric=self.metric,
+ )
+
+ def _get_embedding_dimension(self) -> int:
+ return len(self._get_embedding(text="This is a sample sentence."))
+
+ @staticmethod
+ def _make_id(prompt: str, llm_string: str) -> str:
+ return f"{_hash(prompt)}#{_hash(llm_string)}"
+
+ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+ """Update cache based on prompt and llm_string."""
+ doc_id = self._make_id(prompt, llm_string)
+ llm_string_hash = _hash(llm_string)
+ embedding_vector = self._get_embedding(text=prompt)
+ body = _dumps_generations(return_val)
+ #
+ self.collection.upsert(
+ {
+ "_id": doc_id,
+ "body_blob": body,
+ "llm_string_hash": llm_string_hash,
+ "$vector": embedding_vector,
+ }
+ )
+
+ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+ """Look up based on prompt and llm_string."""
+ hit_with_id = self.lookup_with_id(prompt, llm_string)
+ if hit_with_id is not None:
+ return hit_with_id[1]
+ else:
+ return None
+
+ def lookup_with_id(
+ self, prompt: str, llm_string: str
+ ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+ """
+ Look up based on prompt and llm_string.
+ If there are hits, return (document_id, cached_entry) for the top hit
+ """
+ prompt_embedding: List[float] = self._get_embedding(text=prompt)
+ llm_string_hash = _hash(llm_string)
+
+ hit = self.collection.vector_find_one(
+ vector=prompt_embedding,
+ filter={
+ "llm_string_hash": llm_string_hash,
+ },
+ fields=["body_blob", "_id"],
+ include_similarity=True,
+ )
+
+ if hit is None or hit["$similarity"] < self.similarity_threshold:
+ return None
+ else:
+ generations = _loads_generations(hit["body_blob"])
+ if generations is not None:
+ # this protects against malformed cached items:
+ return (hit["_id"], generations)
+ else:
+ return None
+
+ def lookup_with_id_through_llm(
+ self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+ ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+ llm_string = get_prompts(
+ {**llm.dict(), **{"stop": stop}},
+ [],
+ )[1]
+ return self.lookup_with_id(prompt, llm_string=llm_string)
+
+ def delete_by_document_id(self, document_id: str) -> None:
+ """
+ Given this is a "similarity search" cache, an invalidation pattern
+ that makes sense is first a lookup to get an ID, and then deleting
+ with that ID. This is for the second step.
+ """
+ self.collection.delete_one(document_id)
+
+ def clear(self, **kwargs: Any) -> None:
+ """Clear the *whole* semantic cache."""
+ self.astra_db.truncate_collection(self.collection_name)
diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py
new file mode 100644
index 00000000000..6016e8304d7
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/__init__.py
@@ -0,0 +1,66 @@
+"""**Callback handlers** allow listening to events in LangChain.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler
+"""
+
+from langchain_community.callbacks.aim_callback import AimCallbackHandler
+from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
+from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
+from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
+from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
+from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
+from langchain_community.callbacks.context_callback import ContextCallbackHandler
+from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
+from langchain_community.callbacks.human import HumanApprovalCallbackHandler
+from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
+from langchain_community.callbacks.labelstudio_callback import (
+ LabelStudioCallbackHandler,
+)
+from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler
+from langchain_community.callbacks.manager import (
+ get_openai_callback,
+ wandb_tracing_enabled,
+)
+from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler
+from langchain_community.callbacks.openai_info import OpenAICallbackHandler
+from langchain_community.callbacks.promptlayer_callback import (
+ PromptLayerCallbackHandler,
+)
+from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
+from langchain_community.callbacks.streamlit import (
+ LLMThoughtLabeler,
+ StreamlitCallbackHandler,
+)
+from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler
+from langchain_community.callbacks.wandb_callback import WandbCallbackHandler
+from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler
+
+__all__ = [
+ "AimCallbackHandler",
+ "ArgillaCallbackHandler",
+ "ArizeCallbackHandler",
+ "PromptLayerCallbackHandler",
+ "ArthurCallbackHandler",
+ "ClearMLCallbackHandler",
+ "CometCallbackHandler",
+ "ContextCallbackHandler",
+ "HumanApprovalCallbackHandler",
+ "InfinoCallbackHandler",
+ "MlflowCallbackHandler",
+ "LLMonitorCallbackHandler",
+ "OpenAICallbackHandler",
+ "LLMThoughtLabeler",
+ "StreamlitCallbackHandler",
+ "WandbCallbackHandler",
+ "WhyLabsCallbackHandler",
+ "get_openai_callback",
+ "wandb_tracing_enabled",
+ "FlyteCallbackHandler",
+ "SageMakerCallbackHandler",
+ "LabelStudioCallbackHandler",
+ "TrubricsCallbackHandler",
+]
diff --git a/libs/community/langchain_community/callbacks/aim_callback.py b/libs/community/langchain_community/callbacks/aim_callback.py
new file mode 100644
index 00000000000..da194b34b3d
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/aim_callback.py
@@ -0,0 +1,430 @@
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+def import_aim() -> Any:
+ """Import the aim python package and raise an error if it is not installed."""
+ try:
+ import aim
+ except ImportError:
+ raise ImportError(
+ "To use the Aim callback manager you need to have the"
+ " `aim` python package installed."
+ "Please install it with `pip install aim`"
+ )
+ return aim
+
+
+class BaseMetadataCallbackHandler:
+ """This class handles the metadata and associated function states for callbacks.
+
+ Attributes:
+ step (int): The current step.
+ starts (int): The number of times the start method has been called.
+ ends (int): The number of times the end method has been called.
+ errors (int): The number of times the error method has been called.
+ text_ctr (int): The number of times the text method has been called.
+ ignore_llm_ (bool): Whether to ignore llm callbacks.
+ ignore_chain_ (bool): Whether to ignore chain callbacks.
+ ignore_agent_ (bool): Whether to ignore agent callbacks.
+ ignore_retriever_ (bool): Whether to ignore retriever callbacks.
+ always_verbose_ (bool): Whether to always be verbose.
+ chain_starts (int): The number of times the chain start method has been called.
+ chain_ends (int): The number of times the chain end method has been called.
+ llm_starts (int): The number of times the llm start method has been called.
+ llm_ends (int): The number of times the llm end method has been called.
+ llm_streams (int): The number of times the text method has been called.
+ tool_starts (int): The number of times the tool start method has been called.
+ tool_ends (int): The number of times the tool end method has been called.
+ agent_ends (int): The number of times the agent end method has been called.
+ """
+
+ def __init__(self) -> None:
+ self.step = 0
+
+ self.starts = 0
+ self.ends = 0
+ self.errors = 0
+ self.text_ctr = 0
+
+ self.ignore_llm_ = False
+ self.ignore_chain_ = False
+ self.ignore_agent_ = False
+ self.ignore_retriever_ = False
+ self.always_verbose_ = False
+
+ self.chain_starts = 0
+ self.chain_ends = 0
+
+ self.llm_starts = 0
+ self.llm_ends = 0
+ self.llm_streams = 0
+
+ self.tool_starts = 0
+ self.tool_ends = 0
+
+ self.agent_ends = 0
+
+ @property
+ def always_verbose(self) -> bool:
+ """Whether to call verbose callbacks even if verbose is False."""
+ return self.always_verbose_
+
+ @property
+ def ignore_llm(self) -> bool:
+ """Whether to ignore LLM callbacks."""
+ return self.ignore_llm_
+
+ @property
+ def ignore_chain(self) -> bool:
+ """Whether to ignore chain callbacks."""
+ return self.ignore_chain_
+
+ @property
+ def ignore_agent(self) -> bool:
+ """Whether to ignore agent callbacks."""
+ return self.ignore_agent_
+
+ @property
+ def ignore_retriever(self) -> bool:
+ """Whether to ignore retriever callbacks."""
+ return self.ignore_retriever_
+
+ def get_custom_callback_meta(self) -> Dict[str, Any]:
+ return {
+ "step": self.step,
+ "starts": self.starts,
+ "ends": self.ends,
+ "errors": self.errors,
+ "text_ctr": self.text_ctr,
+ "chain_starts": self.chain_starts,
+ "chain_ends": self.chain_ends,
+ "llm_starts": self.llm_starts,
+ "llm_ends": self.llm_ends,
+ "llm_streams": self.llm_streams,
+ "tool_starts": self.tool_starts,
+ "tool_ends": self.tool_ends,
+ "agent_ends": self.agent_ends,
+ }
+
+ def reset_callback_meta(self) -> None:
+ """Reset the callback metadata."""
+ self.step = 0
+
+ self.starts = 0
+ self.ends = 0
+ self.errors = 0
+ self.text_ctr = 0
+
+ self.ignore_llm_ = False
+ self.ignore_chain_ = False
+ self.ignore_agent_ = False
+ self.always_verbose_ = False
+
+ self.chain_starts = 0
+ self.chain_ends = 0
+
+ self.llm_starts = 0
+ self.llm_ends = 0
+ self.llm_streams = 0
+
+ self.tool_starts = 0
+ self.tool_ends = 0
+
+ self.agent_ends = 0
+
+ return None
+
+
+class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+ """Callback Handler that logs to Aim.
+
+ Parameters:
+ repo (:obj:`str`, optional): Aim repository path or Repo object to which
+ Run object is bound. If skipped, default Repo is used.
+ experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
+ 'default' if not specified. Can be used later to query runs/sequences.
+ system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
+ in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
+ to disable system metrics tracking.
+ log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
+ params such as installed packages, git info, environment variables, etc.
+
+ This handler will utilize the associated callback method called and formats
+ the input of each callback function with metadata regarding the state of LLM run
+ and then logs the response to Aim.
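+
+ Example (a minimal sketch; it assumes the `aim` package is installed and
+ an OpenAI API key is configured for the example LLM):
+
+ .. code-block:: python
+
+ from langchain_community.callbacks import AimCallbackHandler
+ from langchain_community.llms import OpenAI
+
+ aim_callback = AimCallbackHandler(repo=".", experiment_name="llm-demo")
+ llm = OpenAI(temperature=0, callbacks=[aim_callback])
+ llm.generate(["Tell me a joke about experiment tracking."])
+ aim_callback.flush_tracker(langchain_asset=llm, finish=True)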
+ """
+
+ def __init__(
+ self,
+ repo: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ system_tracking_interval: Optional[int] = 10,
+ log_system_params: bool = True,
+ ) -> None:
+ """Initialize callback handler."""
+
+ super().__init__()
+
+ aim = import_aim()
+ self.repo = repo
+ self.experiment_name = experiment_name
+ self.system_tracking_interval = system_tracking_interval
+ self.log_system_params = log_system_params
+ self._run = aim.Run(
+ repo=self.repo,
+ experiment=self.experiment_name,
+ system_tracking_interval=self.system_tracking_interval,
+ log_system_params=self.log_system_params,
+ )
+ self._run_hash = self._run.hash
+ self.action_records: list = []
+
+ def setup(self, **kwargs: Any) -> None:
+ aim = import_aim()
+
+ if not self._run:
+ if self._run_hash:
+ self._run = aim.Run(
+ self._run_hash,
+ repo=self.repo,
+ system_tracking_interval=self.system_tracking_interval,
+ )
+ else:
+ self._run = aim.Run(
+ repo=self.repo,
+ experiment=self.experiment_name,
+ system_tracking_interval=self.system_tracking_interval,
+ log_system_params=self.log_system_params,
+ )
+ self._run_hash = self._run.hash
+
+ if kwargs:
+ for key, value in kwargs.items():
+ self._run.set(key, value, strict=False)
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ aim = import_aim()
+
+ self.step += 1
+ self.llm_starts += 1
+ self.starts += 1
+
+ resp = {"action": "on_llm_start"}
+ resp.update(self.get_custom_callback_meta())
+
+ prompts_res = deepcopy(prompts)
+
+ self._run.track(
+ [aim.Text(prompt) for prompt in prompts_res],
+ name="on_llm_start",
+ context=resp,
+ )
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ aim = import_aim()
+ self.step += 1
+ self.llm_ends += 1
+ self.ends += 1
+
+ resp = {"action": "on_llm_end"}
+ resp.update(self.get_custom_callback_meta())
+
+ response_res = deepcopy(response)
+
+ generated = [
+ aim.Text(generation.text)
+ for generations in response_res.generations
+ for generation in generations
+ ]
+ self._run.track(
+ generated,
+ name="on_llm_end",
+ context=resp,
+ )
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.step += 1
+ self.llm_streams += 1
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ aim = import_aim()
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp = {"action": "on_chain_start"}
+ resp.update(self.get_custom_callback_meta())
+
+ inputs_res = deepcopy(inputs)
+
+ self._run.track(
+ aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
+ )
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ aim = import_aim()
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp = {"action": "on_chain_end"}
+ resp.update(self.get_custom_callback_meta())
+
+ outputs_res = deepcopy(outputs)
+
+ self._run.track(
+ aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
+ )
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ aim = import_aim()
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = {"action": "on_tool_start"}
+ resp.update(self.get_custom_callback_meta())
+
+ self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ aim = import_aim()
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp = {"action": "on_tool_end"}
+ resp.update(self.get_custom_callback_meta())
+
+ self._run.track(aim.Text(output), name="on_tool_end", context=resp)
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run when agent is ending.
+ """
+ self.step += 1
+ self.text_ctr += 1
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ aim = import_aim()
+ self.step += 1
+ self.agent_ends += 1
+ self.ends += 1
+
+ resp = {"action": "on_agent_finish"}
+ resp.update(self.get_custom_callback_meta())
+
+ finish_res = deepcopy(finish)
+
+ text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
+ finish_res.return_values["output"], finish_res.log
+ )
+ self._run.track(aim.Text(text), name="on_agent_finish", context=resp)
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ aim = import_aim()
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ }
+ resp.update(self.get_custom_callback_meta())
+
+ action_res = deepcopy(action)
+
+ text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
+ action_res.tool_input, action_res.log
+ )
+ self._run.track(aim.Text(text), name="on_agent_action", context=resp)
+
+ def flush_tracker(
+ self,
+ repo: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ system_tracking_interval: Optional[int] = 10,
+ log_system_params: bool = True,
+ langchain_asset: Any = None,
+ reset: bool = True,
+ finish: bool = False,
+ ) -> None:
+ """Flush the tracker and reset the session.
+
+ Args:
+ repo (:obj:`str`, optional): Aim repository path or Repo object to which
+ Run object is bound. If skipped, default Repo is used.
+ experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
+ 'default' if not specified. Can be used later to query runs/sequences.
+ system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
+ in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
+ to disable system metrics tracking.
+ log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
+ params such as installed packages, git info, environment variables, etc.
+ langchain_asset: The langchain asset to save.
+ reset: Whether to reset the session.
+ finish: Whether to finish the run.
+
+ Returns:
+ None
+ """
+
+ if langchain_asset:
+ try:
+ for key, value in langchain_asset.dict().items():
+ self._run.set(key, value, strict=False)
+ except Exception:
+ pass
+
+ if finish or reset:
+ self._run.close()
+ self.reset_callback_meta()
+ if reset:
+ self.__init__( # type: ignore
+ repo=repo if repo else self.repo,
+ experiment_name=experiment_name
+ if experiment_name
+ else self.experiment_name,
+ system_tracking_interval=system_tracking_interval
+ if system_tracking_interval
+ else self.system_tracking_interval,
+ log_system_params=log_system_params
+ if log_system_params
+ else self.log_system_params,
+ )
diff --git a/libs/community/langchain_community/callbacks/argilla_callback.py b/libs/community/langchain_community/callbacks/argilla_callback.py
new file mode 100644
index 00000000000..b280e0ce1a6
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/argilla_callback.py
@@ -0,0 +1,352 @@
+import os
+import warnings
+from typing import Any, Dict, List, Optional
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from packaging.version import parse
+
+
+class ArgillaCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs into Argilla.
+
+ Args:
+ dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
+ exist in advance. If you need help on how to create a `FeedbackDataset` in
+ Argilla, please visit
+ https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
+ workspace_name: name of the workspace in Argilla where the specified
+ `FeedbackDataset` lives in. Defaults to `None`, which means that the
+ default workspace will be used.
+ api_url: URL of the Argilla Server that we want to use, and where the
+ `FeedbackDataset` lives in. Defaults to `None`, which means that either
+ `ARGILLA_API_URL` environment variable or the default will be used.
+ api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
+ means that either `ARGILLA_API_KEY` environment variable or the default
+ will be used.
+
+ Raises:
+ ImportError: if the `argilla` package is not installed.
+ ConnectionError: if the connection to Argilla fails.
+ FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
+
+ Examples:
+ >>> from langchain_community.llms import OpenAI
+ >>> from langchain_community.callbacks import ArgillaCallbackHandler
+ >>> argilla_callback = ArgillaCallbackHandler(
+ ... dataset_name="my-dataset",
+ ... workspace_name="my-workspace",
+ ... api_url="http://localhost:6900",
+ ... api_key="argilla.apikey",
+ ... )
+ >>> llm = OpenAI(
+ ... temperature=0,
+ ... callbacks=[argilla_callback],
+ ... verbose=True,
+ ... openai_api_key="API_KEY_HERE",
+ ... )
+ >>> llm.generate([
+ ... "What is the best NLP-annotation tool out there? (no bias at all)",
+ ... ])
+ "Argilla, no doubt about it."
+ """
+
+ REPO_URL: str = "https://github.com/argilla-io/argilla"
+ ISSUES_URL: str = f"{REPO_URL}/issues"
+ BLOG_URL: str = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
+
+ DEFAULT_API_URL: str = "http://localhost:6900"
+
+ def __init__(
+ self,
+ dataset_name: str,
+ workspace_name: Optional[str] = None,
+ api_url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ ) -> None:
+ """Initializes the `ArgillaCallbackHandler`.
+
+ Args:
+ dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
+ exist in advance. If you need help on how to create a `FeedbackDataset`
+ in Argilla, please visit
+ https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
+ workspace_name: name of the workspace in Argilla where the specified
+ `FeedbackDataset` lives in. Defaults to `None`, which means that the
+ default workspace will be used.
+ api_url: URL of the Argilla Server that we want to use, and where the
+ `FeedbackDataset` lives in. Defaults to `None`, which means that either
+ `ARGILLA_API_URL` environment variable or the default will be used.
+ api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
+ means that either `ARGILLA_API_KEY` environment variable or the default
+ will be used.
+
+ Raises:
+ ImportError: if the `argilla` package is not installed.
+ ConnectionError: if the connection to Argilla fails.
+ FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
+ """
+
+ super().__init__()
+
+ # Import Argilla (not via `import_argilla` to keep hints in IDEs)
+ try:
+ import argilla as rg # noqa: F401
+
+ self.ARGILLA_VERSION = rg.__version__
+ except ImportError:
+ raise ImportError(
+ "To use the Argilla callback manager you need to have the `argilla` "
+ "Python package installed. Please install it with `pip install argilla`"
+ )
+
+ # Check whether the Argilla version is compatible
+ if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
+ raise ImportError(
+ f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
+ "`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
+ "upgrade `argilla` with `pip install --upgrade argilla`."
+ )
+
+ # Show a warning message if Argilla will assume the default values will be used
+ if api_url is None and os.getenv("ARGILLA_API_URL") is None:
+ warnings.warn(
+ (
+ "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
+ f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
+ " default API URL in Argilla Quickstart."
+ ),
+ )
+ api_url = self.DEFAULT_API_URL
+
+ if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
+ self.DEFAULT_API_KEY = (
+ "admin.apikey"
+ if parse(self.ARGILLA_VERSION) < parse("1.11.0")
+ else "owner.apikey"
+ )
+
+ warnings.warn(
+ (
+ "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
+ f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
+ " default API key in Argilla Quickstart."
+ ),
+ )
+ api_key = self.DEFAULT_API_KEY
+
+ # Connect to Argilla with the provided credentials, if applicable
+ try:
+ rg.init(api_key=api_key, api_url=api_url)
+ except Exception as e:
+ raise ConnectionError(
+ f"Could not connect to Argilla with exception: '{e}'.\n"
+ "Please check your `api_key` and `api_url`, and make sure that "
+ "the Argilla server is up and running. If the problem persists "
+ f"please report it to {self.ISSUES_URL} as an `integration` issue."
+ ) from e
+
+ # Set the Argilla variables
+ self.dataset_name = dataset_name
+ self.workspace_name = workspace_name or rg.get_workspace()
+
+ # Retrieve the `FeedbackDataset` from Argilla (without existing records)
+ try:
+ extra_args = {}
+ if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
+ warnings.warn(
+ f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
+ " higher is recommended.",
+ UserWarning,
+ )
+ extra_args = {"with_records": False}
+ self.dataset = rg.FeedbackDataset.from_argilla(
+ name=self.dataset_name,
+ workspace=self.workspace_name,
+ **extra_args,
+ )
+ except Exception as e:
+ raise FileNotFoundError(
+ f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
+ f"\nPlease check that the dataset with name={self.dataset_name} in the"
+ f" workspace={self.workspace_name} exists in advance. If you need help"
+ " on how to create a `langchain`-compatible `FeedbackDataset` in"
+ f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
+ f" please report it to {self.ISSUES_URL} as an `integration` issue."
+ ) from e
+
+ supported_fields = ["prompt", "response"]
+ if supported_fields != [field.name for field in self.dataset.fields]:
+ raise ValueError(
+ f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
+ f"{self.workspace_name} had fields that are not supported yet for the"
+ f"`langchain` integration. Supported fields are: {supported_fields},"
+ f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
+ " For more information on how to create a `langchain`-compatible"
+ f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
+ )
+
+ self.prompts: Dict[str, List[str]] = {}
+
+ warnings.warn(
+ (
+ "The `ArgillaCallbackHandler` is currently in beta and is subject to"
+ " change based on updates to `langchain`. Please report any issues to"
+ f" {self.ISSUES_URL} as an `integration` issue."
+ ),
+ )
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Save the prompts in memory when an LLM starts."""
+ self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Do nothing when a new token is generated."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Log records to Argilla when an LLM ends."""
+ # Do nothing if there's a parent_run_id, since we will log the records when
+ # the chain ends
+ if kwargs["parent_run_id"]:
+ return
+
+ # Creates the records and adds them to the `FeedbackDataset`
+ prompts = self.prompts[str(kwargs["run_id"])]
+ for prompt, generations in zip(prompts, response.generations):
+ self.dataset.add_records(
+ records=[
+ {
+ "fields": {
+ "prompt": prompt,
+ "response": generation.text.strip(),
+ },
+ }
+ for generation in generations
+ ]
+ )
+
+ # Pop current run from `self.runs`
+ self.prompts.pop(str(kwargs["run_id"]))
+
+ if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
+ # Push the records to Argilla
+ self.dataset.push_to_argilla()
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM outputs an error."""
+ pass
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """If the key `input` is in `inputs`, then save it in `self.prompts` using
+ either the `parent_run_id` or the `run_id` as the key. This is done so that
+ we don't log the same input prompt twice, once when the LLM starts and once
+ when the chain starts.
+ """
+ if "input" in inputs:
+ self.prompts.update(
+ {
+ str(kwargs["parent_run_id"] or kwargs["run_id"]): (
+ inputs["input"]
+ if isinstance(inputs["input"], list)
+ else [inputs["input"]]
+ )
+ }
+ )
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
+ log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
+ differs if the output is a list or not.
+ """
+ if not any(
+ key in self.prompts
+ for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
+ ):
+ return
+ prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
+ str(kwargs["run_id"])
+ )
+ for chain_output_key, chain_output_val in outputs.items():
+ if isinstance(chain_output_val, list):
+ # Creates the records and adds them to the `FeedbackDataset`
+ self.dataset.add_records(
+ records=[
+ {
+ "fields": {
+ "prompt": prompt,
+ "response": output["text"].strip(),
+ },
+ }
+ for prompt, output in zip(
+ prompts, # type: ignore
+ chain_output_val,
+ )
+ ]
+ )
+ else:
+ # Creates the records and adds them to the `FeedbackDataset`
+ self.dataset.add_records(
+ records=[
+ {
+ "fields": {
+ "prompt": " ".join(prompts), # type: ignore
+ "response": chain_output_val.strip(),
+ },
+ }
+ ]
+ )
+
+ # Pop current run from `self.runs`
+ if str(kwargs["parent_run_id"]) in self.prompts:
+ self.prompts.pop(str(kwargs["parent_run_id"]))
+ if str(kwargs["run_id"]) in self.prompts:
+ self.prompts.pop(str(kwargs["run_id"]))
+
+ if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
+ # Push the records to Argilla
+ self.dataset.push_to_argilla()
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM chain outputs an error."""
+ pass
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool starts."""
+ pass
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing when agent takes a specific action."""
+ pass
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool ends."""
+ pass
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when tool outputs an error."""
+ pass
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
diff --git a/libs/community/langchain_community/callbacks/arize_callback.py b/libs/community/langchain_community/callbacks/arize_callback.py
new file mode 100644
index 00000000000..44212b61917
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/arize_callback.py
@@ -0,0 +1,213 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.utils import import_pandas
+
+
+class ArizeCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs to Arize."""
+
+ def __init__(
+ self,
+ model_id: Optional[str] = None,
+ model_version: Optional[str] = None,
+ SPACE_KEY: Optional[str] = None,
+ API_KEY: Optional[str] = None,
+ ) -> None:
+ """Initialize callback handler."""
+
+ super().__init__()
+ self.model_id = model_id
+ self.model_version = model_version
+ self.space_key = SPACE_KEY
+ self.api_key = API_KEY
+ self.prompt_records: List[str] = []
+ self.response_records: List[str] = []
+ self.prediction_ids: List[str] = []
+ self.pred_timestamps: List[int] = []
+ self.response_embeddings: List[float] = []
+ self.prompt_embeddings: List[float] = []
+ self.prompt_tokens = 0
+ self.completion_tokens = 0
+ self.total_tokens = 0
+ self.step = 0
+
+ from arize.pandas.embeddings import EmbeddingGenerator, UseCases
+ from arize.pandas.logger import Client
+
+ self.generator = EmbeddingGenerator.from_use_case(
+ use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
+ model_name="distilbert-base-uncased",
+ tokenizer_max_length=512,
+ batch_size=256,
+ )
+ self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
+ if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
+ raise ValueError("β CHANGE SPACE AND API KEYS")
+ else:
+ print("β Arize client setup done! Now you can start using Arize!")
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ for prompt in prompts:
+ self.prompt_records.append(prompt.replace("\n", ""))
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ pd = import_pandas()
+ from arize.utils.types import (
+ EmbeddingColumnNames,
+ Environments,
+ ModelTypes,
+ Schema,
+ )
+
+ # Safe check if 'llm_output' and 'token_usage' exist
+ if response.llm_output and "token_usage" in response.llm_output:
+ self.prompt_tokens = response.llm_output["token_usage"].get(
+ "prompt_tokens", 0
+ )
+ self.total_tokens = response.llm_output["token_usage"].get(
+ "total_tokens", 0
+ )
+ self.completion_tokens = response.llm_output["token_usage"].get(
+ "completion_tokens", 0
+ )
+ else:
+            # Default to zero when no token usage is reported
+            self.prompt_tokens = self.total_tokens = self.completion_tokens = 0
+
+ for generations in response.generations:
+ for generation in generations:
+ prompt = self.prompt_records[self.step]
+ self.step = self.step + 1
+ prompt_embedding = pd.Series(
+ self.generator.generate_embeddings(
+ text_col=pd.Series(prompt.replace("\n", " "))
+ ).reset_index(drop=True)
+ )
+
+ # Assigning text to response_text instead of response
+ response_text = generation.text.replace("\n", " ")
+ response_embedding = pd.Series(
+ self.generator.generate_embeddings(
+ text_col=pd.Series(generation.text.replace("\n", " "))
+ ).reset_index(drop=True)
+ )
+ pred_timestamp = datetime.now().timestamp()
+
+ # Define the columns and data
+ columns = [
+ "prediction_ts",
+ "response",
+ "prompt",
+ "response_vector",
+ "prompt_vector",
+ "prompt_token",
+ "completion_token",
+ "total_token",
+ ]
+ data = [
+ [
+ pred_timestamp,
+ response_text,
+ prompt,
+ response_embedding[0],
+ prompt_embedding[0],
+                        # Token counts in the same order as the columns above
+                        self.prompt_tokens,
+                        self.completion_tokens,
+                        self.total_tokens,
+ ]
+ ]
+
+ # Create the DataFrame
+ df = pd.DataFrame(data, columns=columns)
+
+ # Declare prompt and response columns
+ prompt_columns = EmbeddingColumnNames(
+ vector_column_name="prompt_vector", data_column_name="prompt"
+ )
+
+ response_columns = EmbeddingColumnNames(
+ vector_column_name="response_vector", data_column_name="response"
+ )
+
+ schema = Schema(
+ timestamp_column_name="prediction_ts",
+ tag_column_names=[
+ "prompt_token",
+ "completion_token",
+ "total_token",
+ ],
+ prompt_column_names=prompt_columns,
+ response_column_names=response_columns,
+ )
+
+ response_from_arize = self.arize_client.log(
+ dataframe=df,
+ schema=schema,
+ model_id=self.model_id,
+ model_version=self.model_version,
+ model_type=ModelTypes.GENERATIVE_LLM,
+ environment=Environments.PRODUCTION,
+ )
+ if response_from_arize.status_code == 200:
+                    print("✅ Successfully logged data to Arize!")
+                else:
+                    print(f'❌ Logging failed "{response_from_arize.text}"')
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ pass
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ pass
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing."""
+ pass
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ pass
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ pass
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ pass
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ pass
diff --git a/libs/community/langchain_community/callbacks/arthur_callback.py b/libs/community/langchain_community/callbacks/arthur_callback.py
new file mode 100644
index 00000000000..a5fce582ed1
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/arthur_callback.py
@@ -0,0 +1,296 @@
+"""ArthurAI's Callback Handler."""
+from __future__ import annotations
+
+import os
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from time import time
+from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
+
+import numpy as np
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+if TYPE_CHECKING:
+ import arthurai
+ from arthurai.core.models import ArthurModel
+
+PROMPT_TOKENS = "prompt_tokens"
+COMPLETION_TOKENS = "completion_tokens"
+TOKEN_USAGE = "token_usage"
+FINISH_REASON = "finish_reason"
+DURATION = "duration"
+
+
+def _lazy_load_arthur() -> arthurai:
+ """Lazy load Arthur."""
+ try:
+ import arthurai
+ except ImportError as e:
+ raise ImportError(
+ "To use the ArthurCallbackHandler you need the"
+ " `arthurai` package. Please install it with"
+ " `pip install arthurai`.",
+ e,
+ )
+
+ return arthurai
+
+
+class ArthurCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs to Arthur platform.
+
+ Arthur helps enterprise teams optimize model operations
+ and performance at scale. The Arthur API tracks model
+ performance, explainability, and fairness across tabular,
+ NLP, and CV models. Our API is model- and platform-agnostic,
+ and continuously scales with complex and dynamic enterprise needs.
+ To learn more about Arthur, visit our website at
+ https://www.arthur.ai/ or read the Arthur docs at
+ https://docs.arthur.ai/
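+
+    Example (an illustrative sketch; the model ID and login below are placeholders
+    and assume an Arthur model registered with unstructured-text input and output
+    attributes):
+        >>> from langchain_community.callbacks import ArthurCallbackHandler
+        >>> arthur_callback = ArthurCallbackHandler.from_credentials(
+        ...     model_id="YOUR_ARTHUR_MODEL_ID",
+        ...     arthur_login="user@example.com",
+        ...     arthur_password="password",
+        ... )
+        >>> # Pass `callbacks=[arthur_callback]` to an LLM so that each
+        >>> # completion is sent to Arthur as an inference in on_llm_end().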
+ """
+
+ def __init__(
+ self,
+ arthur_model: ArthurModel,
+ ) -> None:
+ """Initialize callback handler."""
+ super().__init__()
+ arthurai = _lazy_load_arthur()
+ Stage = arthurai.common.constants.Stage
+ ValueType = arthurai.common.constants.ValueType
+ self.arthur_model = arthur_model
+ # save the attributes of this model to be used when preparing
+ # inferences to log to Arthur in on_llm_end()
+ self.attr_names = set([a.name for a in self.arthur_model.get_attributes()])
+ self.input_attr = [
+ x
+ for x in self.arthur_model.get_attributes()
+ if x.stage == Stage.ModelPipelineInput
+ and x.value_type == ValueType.Unstructured_Text
+ ][0].name
+ self.output_attr = [
+ x
+ for x in self.arthur_model.get_attributes()
+ if x.stage == Stage.PredictedValue
+ and x.value_type == ValueType.Unstructured_Text
+ ][0].name
+ self.token_likelihood_attr = None
+ if (
+ len(
+ [
+ x
+ for x in self.arthur_model.get_attributes()
+ if x.value_type == ValueType.TokenLikelihoods
+ ]
+ )
+ > 0
+ ):
+ self.token_likelihood_attr = [
+ x
+ for x in self.arthur_model.get_attributes()
+ if x.value_type == ValueType.TokenLikelihoods
+ ][0].name
+
+ self.run_map: DefaultDict[str, Any] = defaultdict(dict)
+
+ @classmethod
+ def from_credentials(
+ cls,
+ model_id: str,
+ arthur_url: Optional[str] = "https://app.arthur.ai",
+ arthur_login: Optional[str] = None,
+ arthur_password: Optional[str] = None,
+ ) -> ArthurCallbackHandler:
+ """Initialize callback handler from Arthur credentials.
+
+ Args:
+ model_id (str): The ID of the arthur model to log to.
+ arthur_url (str, optional): The URL of the Arthur instance to log to.
+ Defaults to "https://app.arthur.ai".
+ arthur_login (str, optional): The login to use to connect to Arthur.
+ Defaults to None.
+ arthur_password (str, optional): The password to use to connect to
+ Arthur. Defaults to None.
+
+ Returns:
+ ArthurCallbackHandler: The initialized callback handler.
+ """
+ arthurai = _lazy_load_arthur()
+ ArthurAI = arthurai.ArthurAI
+ ResponseClientError = arthurai.common.exceptions.ResponseClientError
+
+ # connect to Arthur
+ if arthur_login is None:
+ try:
+ arthur_api_key = os.environ["ARTHUR_API_KEY"]
+ except KeyError:
+ raise ValueError(
+ "No Arthur authentication provided. Either give"
+ " a login to the ArthurCallbackHandler"
+ " or set an ARTHUR_API_KEY as an environment variable."
+ )
+ arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
+ else:
+ if arthur_password is None:
+ arthur = ArthurAI(url=arthur_url, login=arthur_login)
+ else:
+ arthur = ArthurAI(
+ url=arthur_url, login=arthur_login, password=arthur_password
+ )
+ # get model from Arthur by the provided model ID
+ try:
+ arthur_model = arthur.get_model(model_id)
+ except ResponseClientError:
+ raise ValueError(
+ f"Was unable to retrieve model with id {model_id} from Arthur."
+ " Make sure the ID corresponds to a model that is currently"
+ " registered with your Arthur account."
+ )
+ return cls(arthur_model)
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """On LLM start, save the input prompts"""
+ run_id = kwargs["run_id"]
+ self.run_map[run_id]["input_texts"] = prompts
+ self.run_map[run_id]["start_time"] = time()
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """On LLM end, send data to Arthur."""
+ try:
+ import pytz # type: ignore[import]
+ except ImportError as e:
+ raise ImportError(
+ "Could not import pytz. Please install it with 'pip install pytz'."
+ ) from e
+
+ run_id = kwargs["run_id"]
+
+ # get the run params from this run ID,
+ # or raise an error if this run ID has no corresponding metadata in self.run_map
+        # self.run_map is a defaultdict, so plain indexing never raises KeyError;
+        # check membership explicitly instead.
+        if run_id not in self.run_map:
+            raise KeyError(
+                "This function has been called with a run_id"
+                " that was never registered in on_llm_start()."
+                " Restart and try running the LLM again"
+            )
+        run_map_data = self.run_map[run_id]
+
+ # mark the duration time between on_llm_start() and on_llm_end()
+ time_from_start_to_end = time() - run_map_data["start_time"]
+
+ # create inferences to log to Arthur
+ inferences = []
+ for i, generations in enumerate(response.generations):
+ for generation in generations:
+ inference = {
+ "partner_inference_id": str(uuid.uuid4()),
+ "inference_timestamp": datetime.now(tz=pytz.UTC),
+ self.input_attr: run_map_data["input_texts"][i],
+ self.output_attr: generation.text,
+ }
+
+ if generation.generation_info is not None:
+ # add finish reason to the inference
+ # if generation info contains a finish reason and
+ # if the ArthurModel was registered to monitor finish_reason
+ if (
+ FINISH_REASON in generation.generation_info
+ and FINISH_REASON in self.attr_names
+ ):
+ inference[FINISH_REASON] = generation.generation_info[
+ FINISH_REASON
+ ]
+
+ # add token likelihoods data to the inference if the ArthurModel
+ # was registered to monitor token likelihoods
+                    logprobs_data = generation.generation_info.get("logprobs")
+ if (
+ logprobs_data is not None
+ and self.token_likelihood_attr is not None
+ ):
+ logprobs = logprobs_data["top_logprobs"]
+ likelihoods = [
+ {k: np.exp(v) for k, v in logprobs[i].items()}
+ for i in range(len(logprobs))
+ ]
+ inference[self.token_likelihood_attr] = likelihoods
+
+ # add token usage counts to the inference if the
+ # ArthurModel was registered to monitor token usage
+ if (
+ isinstance(response.llm_output, dict)
+ and TOKEN_USAGE in response.llm_output
+ ):
+ token_usage = response.llm_output[TOKEN_USAGE]
+ if (
+ PROMPT_TOKENS in token_usage
+ and PROMPT_TOKENS in self.attr_names
+ ):
+ inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
+ if (
+ COMPLETION_TOKENS in token_usage
+ and COMPLETION_TOKENS in self.attr_names
+ ):
+ inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]
+
+ # add inference duration to the inference if the ArthurModel
+ # was registered to monitor inference duration
+ if DURATION in self.attr_names:
+ inference[DURATION] = time_from_start_to_end
+
+ inferences.append(inference)
+
+ # send inferences to arthur
+ self.arthur_model.send_inferences(inferences)
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """On chain start, do nothing."""
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """On chain end, do nothing."""
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM outputs an error."""
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """On new token, pass."""
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM chain outputs an error."""
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool starts."""
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing when agent takes a specific action."""
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool ends."""
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when tool outputs an error."""
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """Do nothing"""
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Do nothing"""
diff --git a/libs/community/langchain_community/callbacks/clearml_callback.py b/libs/community/langchain_community/callbacks/clearml_callback.py
new file mode 100644
index 00000000000..71f30fccdd8
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/clearml_callback.py
@@ -0,0 +1,527 @@
+from __future__ import annotations
+
+import tempfile
+from copy import deepcopy
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.utils import (
+ BaseMetadataCallbackHandler,
+ flatten_dict,
+ hash_string,
+ import_pandas,
+ import_spacy,
+ import_textstat,
+ load_json,
+)
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+
+def import_clearml() -> Any:
+ """Import the clearml python package and raise an error if it is not installed."""
+ try:
+ import clearml # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "To use the clearml callback manager you need to have the `clearml` python "
+ "package installed. Please install it with `pip install clearml`"
+ )
+ return clearml
+
+
+class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+ """Callback Handler that logs to ClearML.
+
+ Parameters:
+        task_type (str): The type of clearml task such as "inference", "testing" or "qc"
+ project_name (str): The clearml project name
+ tags (list): Tags to add to the task
+ task_name (str): Name of the clearml task
+ visualize (bool): Whether to visualize the run.
+ complexity_metrics (bool): Whether to log complexity metrics
+ stream_logs (bool): Whether to stream callback actions to ClearML
+
+    This handler uses the associated callback method, formats the input of each
+    callback function with metadata regarding the state of the LLM run, and adds
+    the response to the list of records for both the {method}_records and
+    action_records attributes. It then logs the response to the ClearML console.
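+
+    Example (an illustrative sketch; assumes the `clearml`, `spacy`, and `textstat`
+    packages plus the `en_core_web_sm` spaCy model are installed):
+        >>> from langchain_community.callbacks import ClearMLCallbackHandler
+        >>> clearml_callback = ClearMLCallbackHandler(
+        ...     task_type="inference",
+        ...     project_name="langchain_callback_demo",
+        ...     task_name="llm",
+        ...     complexity_metrics=True,
+        ... )
+        >>> # Pass `callbacks=[clearml_callback]` to an LLM, then call
+        >>> # clearml_callback.flush_tracker(langchain_asset=llm, finish=True).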
+ """
+
+ def __init__(
+ self,
+ task_type: Optional[str] = "inference",
+ project_name: Optional[str] = "langchain_callback_demo",
+ tags: Optional[Sequence] = None,
+ task_name: Optional[str] = None,
+ visualize: bool = False,
+ complexity_metrics: bool = False,
+ stream_logs: bool = False,
+ ) -> None:
+ """Initialize callback handler."""
+
+ clearml = import_clearml()
+ spacy = import_spacy()
+ super().__init__()
+
+ self.task_type = task_type
+ self.project_name = project_name
+ self.tags = tags
+ self.task_name = task_name
+ self.visualize = visualize
+ self.complexity_metrics = complexity_metrics
+ self.stream_logs = stream_logs
+
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ # Check if ClearML task already exists (e.g. in pipeline)
+ if clearml.Task.current_task():
+ self.task = clearml.Task.current_task()
+ else:
+ self.task = clearml.Task.init( # type: ignore
+ task_type=self.task_type,
+ project_name=self.project_name,
+ tags=self.tags,
+ task_name=self.task_name,
+ output_uri=True,
+ )
+ self.logger = self.task.get_logger()
+ warning = (
+ "The clearml callback is currently in beta and is subject to change "
+ "based on updates to `langchain`. Please report any issues to "
+ "https://github.com/allegroai/clearml/issues with the tag `langchain`."
+ )
+ self.logger.report_text(warning, level=30, print_console=True)
+ self.callback_columns: list = []
+ self.action_records: list = []
+ self.complexity_metrics = complexity_metrics
+ self.visualize = visualize
+ self.nlp = spacy.load("en_core_web_sm")
+
+ def _init_resp(self) -> Dict:
+ return {k: None for k in self.callback_columns}
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ self.step += 1
+ self.llm_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ for prompt in prompts:
+ prompt_resp = deepcopy(resp)
+ prompt_resp["prompts"] = prompt
+ self.on_llm_start_records.append(prompt_resp)
+ self.action_records.append(prompt_resp)
+ if self.stream_logs:
+ self.logger.report_text(prompt_resp)
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.step += 1
+ self.llm_streams += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_new_token", "token": token})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_llm_token_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.step += 1
+ self.llm_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_end"})
+ resp.update(flatten_dict(response.llm_output or {}))
+ resp.update(self.get_custom_callback_meta())
+
+ for generations in response.generations:
+ for generation in generations:
+ generation_resp = deepcopy(resp)
+ generation_resp.update(flatten_dict(generation.dict()))
+ generation_resp.update(self.analyze_text(generation.text))
+ self.on_llm_end_records.append(generation_resp)
+ self.action_records.append(generation_resp)
+ if self.stream_logs:
+ self.logger.report_text(generation_resp)
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ chain_input = inputs.get("input", inputs.get("human_input"))
+
+ if isinstance(chain_input, str):
+ input_resp = deepcopy(resp)
+ input_resp["input"] = chain_input
+ self.on_chain_start_records.append(input_resp)
+ self.action_records.append(input_resp)
+ if self.stream_logs:
+ self.logger.report_text(input_resp)
+ elif isinstance(chain_input, list):
+ for inp in chain_input:
+ input_resp = deepcopy(resp)
+ input_resp.update(inp)
+ self.on_chain_start_records.append(input_resp)
+ self.action_records.append(input_resp)
+ if self.stream_logs:
+ self.logger.report_text(input_resp)
+ else:
+ raise ValueError("Unexpected data format provided!")
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update(
+ {
+ "action": "on_chain_end",
+ "outputs": outputs.get("output", outputs.get("text")),
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_chain_end_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_start", "input_str": input_str})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_tool_start_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_end", "output": output})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_tool_end_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run when agent is ending.
+ """
+ self.step += 1
+ self.text_ctr += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_text", "text": text})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_text_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.step += 1
+ self.agent_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update(
+ {
+ "action": "on_agent_finish",
+ "output": finish.return_values["output"],
+ "log": finish.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_agent_finish_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update(
+ {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ "tool_input": action.tool_input,
+ "log": action.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+ self.on_agent_action_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.logger.report_text(resp)
+
+ def analyze_text(self, text: str) -> dict:
+ """Analyze text using textstat and spacy.
+
+ Parameters:
+ text (str): The text to analyze.
+
+ Returns:
+ (dict): A dictionary containing the complexity metrics.
+ """
+ resp = {}
+ textstat = import_textstat()
+ spacy = import_spacy()
+ if self.complexity_metrics:
+ text_complexity_metrics = {
+ "flesch_reading_ease": textstat.flesch_reading_ease(text),
+ "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+ "smog_index": textstat.smog_index(text),
+ "coleman_liau_index": textstat.coleman_liau_index(text),
+ "automated_readability_index": textstat.automated_readability_index(
+ text
+ ),
+ "dale_chall_readability_score": textstat.dale_chall_readability_score(
+ text
+ ),
+ "difficult_words": textstat.difficult_words(text),
+ "linsear_write_formula": textstat.linsear_write_formula(text),
+ "gunning_fog": textstat.gunning_fog(text),
+ "text_standard": textstat.text_standard(text),
+ "fernandez_huerta": textstat.fernandez_huerta(text),
+ "szigriszt_pazos": textstat.szigriszt_pazos(text),
+ "gutierrez_polini": textstat.gutierrez_polini(text),
+ "crawford": textstat.crawford(text),
+ "gulpease_index": textstat.gulpease_index(text),
+ "osman": textstat.osman(text),
+ }
+ resp.update(text_complexity_metrics)
+
+ if self.visualize and self.nlp and self.temp_dir.name is not None:
+ doc = self.nlp(text)
+
+ dep_out = spacy.displacy.render( # type: ignore
+ doc, style="dep", jupyter=False, page=True
+ )
+ dep_output_path = Path(
+ self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
+ )
+ dep_output_path.open("w", encoding="utf-8").write(dep_out)
+
+ ent_out = spacy.displacy.render( # type: ignore
+ doc, style="ent", jupyter=False, page=True
+ )
+ ent_output_path = Path(
+ self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
+ )
+ ent_output_path.open("w", encoding="utf-8").write(ent_out)
+
+ self.logger.report_media(
+ "Dependencies Plot", text, local_path=dep_output_path
+ )
+ self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
+
+ return resp
+
+ @staticmethod
+ def _build_llm_df(
+ base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
+ ) -> pd.DataFrame:
+ base_df_fields = [field for field in base_df_fields if field in base_df]
+ rename_map = {
+ map_entry_k: map_entry_v
+ for map_entry_k, map_entry_v in rename_map.items()
+ if map_entry_k in base_df_fields
+ }
+ llm_df = base_df[base_df_fields].dropna(axis=1)
+ if rename_map:
+ llm_df = llm_df.rename(rename_map, axis=1)
+ return llm_df
+
+ def _create_session_analysis_df(self) -> Any:
+ """Create a dataframe with all the information from the session."""
+ pd = import_pandas()
+ on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
+
+ llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
+ base_df=on_llm_end_records_df,
+ base_df_fields=["step", "prompts"]
+ + (["name"] if "name" in on_llm_end_records_df else ["id"]),
+ rename_map={"step": "prompt_step"},
+ )
+ complexity_metrics_columns = []
+ visualizations_columns: List = []
+
+ if self.complexity_metrics:
+ complexity_metrics_columns = [
+ "flesch_reading_ease",
+ "flesch_kincaid_grade",
+ "smog_index",
+ "coleman_liau_index",
+ "automated_readability_index",
+ "dale_chall_readability_score",
+ "difficult_words",
+ "linsear_write_formula",
+ "gunning_fog",
+ "text_standard",
+ "fernandez_huerta",
+ "szigriszt_pazos",
+ "gutierrez_polini",
+ "crawford",
+ "gulpease_index",
+ "osman",
+ ]
+
+ llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
+ on_llm_end_records_df,
+ [
+ "step",
+ "text",
+ "token_usage_total_tokens",
+ "token_usage_prompt_tokens",
+ "token_usage_completion_tokens",
+ ]
+ + complexity_metrics_columns
+ + visualizations_columns,
+ {"step": "output_step", "text": "output"},
+ )
+ session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
+ return session_analysis_df
+
+ def flush_tracker(
+ self,
+ name: Optional[str] = None,
+ langchain_asset: Any = None,
+ finish: bool = False,
+ ) -> None:
+ """Flush the tracker and setup the session.
+
+ Everything after this will be a new table.
+
+ Args:
+            name: Name for the session performed so far, so it is identifiable
+ langchain_asset: The langchain asset to save.
+ finish: Whether to finish the run.
+
+ Returns:
+ None
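+
+        Example (illustrative; `llm` stands for any LLM or chain that was run with
+        this callback attached):
+            >>> clearml_callback.flush_tracker(
+            ...     name="run-1", langchain_asset=llm, finish=True
+            ... )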
+ """
+ pd = import_pandas()
+ clearml = import_clearml()
+
+ # Log the action records
+ self.logger.report_table(
+ "Action Records", name, table_plot=pd.DataFrame(self.action_records)
+ )
+
+ # Session analysis
+ session_analysis_df = self._create_session_analysis_df()
+ self.logger.report_table(
+ "Session Analysis", name, table_plot=session_analysis_df
+ )
+
+ if self.stream_logs:
+ self.logger.report_text(
+ {
+ "action_records": pd.DataFrame(self.action_records),
+ "session_analysis": session_analysis_df,
+ }
+ )
+
+ if langchain_asset:
+ langchain_asset_path = Path(self.temp_dir.name, "model.json")
+ try:
+ langchain_asset.save(langchain_asset_path)
+ # Create output model and connect it to the task
+ output_model = clearml.OutputModel(
+ task=self.task, config_text=load_json(langchain_asset_path)
+ )
+ output_model.update_weights(
+ weights_filename=str(langchain_asset_path),
+ auto_delete_file=False,
+ target_filename=name,
+ )
+ except ValueError:
+ langchain_asset.save_agent(langchain_asset_path)
+ output_model = clearml.OutputModel(
+ task=self.task, config_text=load_json(langchain_asset_path)
+ )
+ output_model.update_weights(
+ weights_filename=str(langchain_asset_path),
+ auto_delete_file=False,
+ target_filename=name,
+ )
+ except NotImplementedError as e:
+ print("Could not save model.")
+ print(repr(e))
+ pass
+
+ # Cleanup after adding everything to ClearML
+ self.task.flush(wait_for_uploads=True)
+ self.temp_dir.cleanup()
+ self.temp_dir = tempfile.TemporaryDirectory()
+ self.reset_callback_meta()
+
+ if finish:
+ self.task.close()
diff --git a/libs/community/langchain_community/callbacks/comet_ml_callback.py b/libs/community/langchain_community/callbacks/comet_ml_callback.py
new file mode 100644
index 00000000000..5493c947ae8
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/comet_ml_callback.py
@@ -0,0 +1,645 @@
+import tempfile
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Sequence
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import Generation, LLMResult
+
+import langchain_community
+from langchain_community.callbacks.utils import (
+ BaseMetadataCallbackHandler,
+ flatten_dict,
+ import_pandas,
+ import_spacy,
+ import_textstat,
+)
+
+LANGCHAIN_MODEL_NAME = "langchain-model"
+
+
+def import_comet_ml() -> Any:
+ """Import comet_ml and raise an error if it is not installed."""
+ try:
+ import comet_ml # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "To use the comet_ml callback manager you need to have the "
+ "`comet_ml` python package installed. Please install it with"
+ " `pip install comet_ml`"
+ )
+ return comet_ml
+
+
+def _get_experiment(
+ workspace: Optional[str] = None, project_name: Optional[str] = None
+) -> Any:
+ comet_ml = import_comet_ml()
+
+ experiment = comet_ml.Experiment( # type: ignore
+ workspace=workspace,
+ project_name=project_name,
+ )
+
+ return experiment
+
+
+def _fetch_text_complexity_metrics(text: str) -> dict:
+ textstat = import_textstat()
+ text_complexity_metrics = {
+ "flesch_reading_ease": textstat.flesch_reading_ease(text),
+ "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+ "smog_index": textstat.smog_index(text),
+ "coleman_liau_index": textstat.coleman_liau_index(text),
+ "automated_readability_index": textstat.automated_readability_index(text),
+ "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
+ "difficult_words": textstat.difficult_words(text),
+ "linsear_write_formula": textstat.linsear_write_formula(text),
+ "gunning_fog": textstat.gunning_fog(text),
+ "text_standard": textstat.text_standard(text),
+ "fernandez_huerta": textstat.fernandez_huerta(text),
+ "szigriszt_pazos": textstat.szigriszt_pazos(text),
+ "gutierrez_polini": textstat.gutierrez_polini(text),
+ "crawford": textstat.crawford(text),
+ "gulpease_index": textstat.gulpease_index(text),
+ "osman": textstat.osman(text),
+ }
+ return text_complexity_metrics
+
+
+def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
+ pd = import_pandas()
+ metrics_df = pd.DataFrame(metrics)
+ metrics_summary = metrics_df.describe()
+
+ return metrics_summary.to_dict()
+
+
+class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+ """Callback Handler that logs to Comet.
+
+ Parameters:
+        task_type (str): The type of comet_ml task such as "inference",
+            "testing" or "qc"
+        workspace (str): The comet_ml workspace to log to
+        project_name (str): The comet_ml project name
+        tags (list): Tags to add to the task
+        name (str): Name of the comet_ml experiment
+        visualizations (list): The spaCy visualization styles to log
+            (e.g. "dep", "ent")
+        complexity_metrics (bool): Whether to log complexity metrics
+        custom_metrics (callable): Optional function that computes custom metrics
+            for each generated output
+        stream_logs (bool): Whether to stream callback actions to Comet
+
+    This handler uses the associated callback method, formats the input of each
+    callback function with metadata regarding the state of the LLM run, and adds
+    the response to the list of records for both the {method}_records and
+    action_records attributes. It then logs the response to Comet.
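+
+    Example (an illustrative sketch; Comet credentials are assumed to come from
+    the standard `comet_ml` configuration or environment variables):
+        >>> from langchain_community.callbacks import CometCallbackHandler
+        >>> comet_callback = CometCallbackHandler(
+        ...     project_name="comet-langchain-demo",
+        ...     complexity_metrics=True,
+        ...     stream_logs=True,
+        ... )
+        >>> # Pass `callbacks=[comet_callback]` to an LLM, then call
+        >>> # comet_callback.flush_tracker(langchain_asset=llm, finish=True).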
+ """
+
+ def __init__(
+ self,
+ task_type: Optional[str] = "inference",
+ workspace: Optional[str] = None,
+ project_name: Optional[str] = None,
+ tags: Optional[Sequence] = None,
+ name: Optional[str] = None,
+ visualizations: Optional[List[str]] = None,
+ complexity_metrics: bool = False,
+ custom_metrics: Optional[Callable] = None,
+ stream_logs: bool = True,
+ ) -> None:
+ """Initialize callback handler."""
+
+ self.comet_ml = import_comet_ml()
+ super().__init__()
+
+ self.task_type = task_type
+ self.workspace = workspace
+ self.project_name = project_name
+ self.tags = tags
+ self.visualizations = visualizations
+ self.complexity_metrics = complexity_metrics
+ self.custom_metrics = custom_metrics
+ self.stream_logs = stream_logs
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ self.experiment = _get_experiment(workspace, project_name)
+ self.experiment.log_other("Created from", "langchain")
+ if tags:
+ self.experiment.add_tags(tags)
+ self.name = name
+ if self.name:
+ self.experiment.set_name(self.name)
+
+ warning = (
+ "The comet_ml callback is currently in beta and is subject to change "
+ "based on updates to `langchain`. Please report any issues to "
+ "https://github.com/comet-ml/issue-tracking/issues with the tag "
+ "`langchain`."
+ )
+ self.comet_ml.LOGGER.warning(warning)
+
+ self.callback_columns: list = []
+ self.action_records: list = []
+ self.complexity_metrics = complexity_metrics
+ if self.visualizations:
+ spacy = import_spacy()
+ self.nlp = spacy.load("en_core_web_sm")
+ else:
+ self.nlp = None
+
+ def _init_resp(self) -> Dict:
+ return {k: None for k in self.callback_columns}
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ self.step += 1
+ self.llm_starts += 1
+ self.starts += 1
+
+ metadata = self._init_resp()
+ metadata.update({"action": "on_llm_start"})
+ metadata.update(flatten_dict(serialized))
+ metadata.update(self.get_custom_callback_meta())
+
+ for prompt in prompts:
+ prompt_resp = deepcopy(metadata)
+ prompt_resp["prompts"] = prompt
+ self.on_llm_start_records.append(prompt_resp)
+ self.action_records.append(prompt_resp)
+
+ if self.stream_logs:
+ self._log_stream(prompt, metadata, self.step)
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.step += 1
+ self.llm_streams += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_new_token", "token": token})
+ resp.update(self.get_custom_callback_meta())
+
+ self.action_records.append(resp)
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.step += 1
+ self.llm_ends += 1
+ self.ends += 1
+
+ metadata = self._init_resp()
+ metadata.update({"action": "on_llm_end"})
+ metadata.update(flatten_dict(response.llm_output or {}))
+ metadata.update(self.get_custom_callback_meta())
+
+ output_complexity_metrics = []
+ output_custom_metrics = []
+
+ for prompt_idx, generations in enumerate(response.generations):
+ for gen_idx, generation in enumerate(generations):
+ text = generation.text
+
+ generation_resp = deepcopy(metadata)
+ generation_resp.update(flatten_dict(generation.dict()))
+
+ complexity_metrics = self._get_complexity_metrics(text)
+ if complexity_metrics:
+ output_complexity_metrics.append(complexity_metrics)
+ generation_resp.update(complexity_metrics)
+
+ custom_metrics = self._get_custom_metrics(
+ generation, prompt_idx, gen_idx
+ )
+ if custom_metrics:
+ output_custom_metrics.append(custom_metrics)
+ generation_resp.update(custom_metrics)
+
+ if self.stream_logs:
+ self._log_stream(text, metadata, self.step)
+
+ self.action_records.append(generation_resp)
+ self.on_llm_end_records.append(generation_resp)
+
+ self._log_text_metrics(output_complexity_metrics, step=self.step)
+ self._log_text_metrics(output_custom_metrics, step=self.step)
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ for chain_input_key, chain_input_val in inputs.items():
+ if isinstance(chain_input_val, str):
+ input_resp = deepcopy(resp)
+ if self.stream_logs:
+ self._log_stream(chain_input_val, resp, self.step)
+ input_resp.update({chain_input_key: chain_input_val})
+ self.action_records.append(input_resp)
+
+ else:
+ self.comet_ml.LOGGER.warning(
+ f"Unexpected data format provided! "
+ f"Input Value for {chain_input_key} will not be logged"
+ )
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_end"})
+ resp.update(self.get_custom_callback_meta())
+
+ for chain_output_key, chain_output_val in outputs.items():
+ if isinstance(chain_output_val, str):
+ output_resp = deepcopy(resp)
+ if self.stream_logs:
+ self._log_stream(chain_output_val, resp, self.step)
+ output_resp.update({chain_output_key: chain_output_val})
+ self.action_records.append(output_resp)
+ else:
+ self.comet_ml.LOGGER.warning(
+ f"Unexpected data format provided! "
+ f"Output Value for {chain_output_key} will not be logged"
+ )
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(input_str, resp, self.step)
+
+ resp.update({"input_str": input_str})
+ self.action_records.append(resp)
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_end"})
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(output, resp, self.step)
+
+ resp.update({"output": output})
+ self.action_records.append(resp)
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run when agent is ending.
+ """
+ self.step += 1
+ self.text_ctr += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_text"})
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(text, resp, self.step)
+
+ resp.update({"text": text})
+ self.action_records.append(resp)
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.step += 1
+ self.agent_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ output = finish.return_values["output"]
+ log = finish.log
+
+ resp.update({"action": "on_agent_finish", "log": log})
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(output, resp, self.step)
+
+ resp.update({"output": output})
+ self.action_records.append(resp)
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ tool = action.tool
+ tool_input = str(action.tool_input)
+ log = action.log
+
+ resp = self._init_resp()
+ resp.update({"action": "on_agent_action", "log": log, "tool": tool})
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(tool_input, resp, self.step)
+
+ resp.update({"tool_input": tool_input})
+ self.action_records.append(resp)
+
+ def _get_complexity_metrics(self, text: str) -> dict:
+ """Compute text complexity metrics using textstat.
+
+ Parameters:
+ text (str): The text to analyze.
+
+ Returns:
+ (dict): A dictionary containing the complexity metrics.
+ """
+ resp = {}
+ if self.complexity_metrics:
+ text_complexity_metrics = _fetch_text_complexity_metrics(text)
+ resp.update(text_complexity_metrics)
+
+ return resp
+
+ def _get_custom_metrics(
+ self, generation: Generation, prompt_idx: int, gen_idx: int
+ ) -> dict:
+ """Compute Custom Metrics for an LLM Generated Output
+
+ Args:
+ generation (LLMResult): Output generation from an LLM
+ prompt_idx (int): List index of the input prompt
+ gen_idx (int): List index of the generated output
+
+ Returns:
+ dict: A dictionary containing the custom metrics.
+ """
+
+ resp = {}
+ if self.custom_metrics:
+ custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
+ resp.update(custom_metrics)
+
+ return resp
+
+ def flush_tracker(
+ self,
+ langchain_asset: Any = None,
+ task_type: Optional[str] = "inference",
+ workspace: Optional[str] = None,
+ project_name: Optional[str] = "comet-langchain-demo",
+ tags: Optional[Sequence] = None,
+ name: Optional[str] = None,
+ visualizations: Optional[List[str]] = None,
+ complexity_metrics: bool = False,
+ custom_metrics: Optional[Callable] = None,
+ finish: bool = False,
+ reset: bool = False,
+ ) -> None:
+ """Flush the tracker and setup the session.
+
+ Everything after this will be a new table.
+
+ Args:
+            name: Name for the session performed so far, so it is identifiable
+ langchain_asset: The langchain asset to save.
+ finish: Whether to finish the run.
+
+ Returns:
+ None
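+
+        Example (illustrative; `llm` stands for any LLM or chain that was run with
+        this callback attached):
+            >>> comet_callback.flush_tracker(langchain_asset=llm, finish=True)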
+ """
+ self._log_session(langchain_asset)
+
+ if langchain_asset:
+ try:
+ self._log_model(langchain_asset)
+ except Exception:
+ self.comet_ml.LOGGER.error(
+ "Failed to export agent or LLM to Comet",
+ exc_info=True,
+ extra={"show_traceback": True},
+ )
+
+ if finish:
+ self.experiment.end()
+
+ if reset:
+ self._reset(
+ task_type,
+ workspace,
+ project_name,
+ tags,
+ name,
+ visualizations,
+ complexity_metrics,
+ custom_metrics,
+ )
+
+ def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
+ self.experiment.log_text(prompt, metadata=metadata, step=step)
+
+ def _log_model(self, langchain_asset: Any) -> None:
+ model_parameters = self._get_llm_parameters(langchain_asset)
+ self.experiment.log_parameters(model_parameters, prefix="model")
+
+ langchain_asset_path = Path(self.temp_dir.name, "model.json")
+ model_name = self.name if self.name else LANGCHAIN_MODEL_NAME
+
+ try:
+ if hasattr(langchain_asset, "save"):
+ langchain_asset.save(langchain_asset_path)
+ self.experiment.log_model(model_name, str(langchain_asset_path))
+ except (ValueError, AttributeError, NotImplementedError) as e:
+ if hasattr(langchain_asset, "save_agent"):
+ langchain_asset.save_agent(langchain_asset_path)
+ self.experiment.log_model(model_name, str(langchain_asset_path))
+ else:
+ self.comet_ml.LOGGER.error(
+ f"{e}"
+ " Could not save Langchain Asset "
+ f"for {langchain_asset.__class__.__name__}"
+ )
+
+ def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
+ try:
+ llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
+ # Log the cleaned dataframe as a table
+ self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
+ except Exception:
+ self.comet_ml.LOGGER.warning(
+ "Failed to log session data to Comet",
+ exc_info=True,
+ extra={"show_traceback": True},
+ )
+
+ try:
+ metadata = {"langchain_version": str(langchain_community.__version__)}
+ # Log the langchain low-level records as a JSON file directly
+ self.experiment.log_asset_data(
+ self.action_records, "langchain-action_records.json", metadata=metadata
+ )
+ except Exception:
+ self.comet_ml.LOGGER.warning(
+ "Failed to log session data to Comet",
+ exc_info=True,
+ extra={"show_traceback": True},
+ )
+
+ try:
+ self._log_visualizations(llm_session_df)
+ except Exception:
+ self.comet_ml.LOGGER.warning(
+ "Failed to log visualizations to Comet",
+ exc_info=True,
+ extra={"show_traceback": True},
+ )
+
+ def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
+ if not metrics:
+ return
+
+ metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
+ for key, value in metrics_summary.items():
+ self.experiment.log_metrics(value, prefix=key, step=step)
+
+ def _log_visualizations(self, session_df: Any) -> None:
+ if not (self.visualizations and self.nlp):
+ return
+
+ spacy = import_spacy()
+
+ prompts = session_df["prompts"].tolist()
+ outputs = session_df["text"].tolist()
+
+ for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
+ doc = self.nlp(output)
+ sentence_spans = list(doc.sents)
+
+ for visualization in self.visualizations:
+ try:
+ html = spacy.displacy.render(
+ sentence_spans,
+ style=visualization,
+ options={"compact": True},
+ jupyter=False,
+ page=True,
+ )
+ self.experiment.log_asset_data(
+ html,
+ name=f"langchain-viz-{visualization}-{idx}.html",
+ metadata={"prompt": prompt},
+ step=idx,
+ )
+ except Exception as e:
+ self.comet_ml.LOGGER.warning(
+ e, exc_info=True, extra={"show_traceback": True}
+ )
+
+ return
+
+ def _reset(
+ self,
+ task_type: Optional[str] = None,
+ workspace: Optional[str] = None,
+ project_name: Optional[str] = None,
+ tags: Optional[Sequence] = None,
+ name: Optional[str] = None,
+ visualizations: Optional[List[str]] = None,
+ complexity_metrics: bool = False,
+ custom_metrics: Optional[Callable] = None,
+ ) -> None:
+ _task_type = task_type if task_type else self.task_type
+ _workspace = workspace if workspace else self.workspace
+ _project_name = project_name if project_name else self.project_name
+ _tags = tags if tags else self.tags
+ _name = name if name else self.name
+ _visualizations = visualizations if visualizations else self.visualizations
+ _complexity_metrics = (
+ complexity_metrics if complexity_metrics else self.complexity_metrics
+ )
+ _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
+
+ self.__init__( # type: ignore
+ task_type=_task_type,
+ workspace=_workspace,
+ project_name=_project_name,
+ tags=_tags,
+ name=_name,
+ visualizations=_visualizations,
+ complexity_metrics=_complexity_metrics,
+ custom_metrics=_custom_metrics,
+ )
+
+ self.reset_callback_meta()
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
+ pd = import_pandas()
+
+ llm_parameters = self._get_llm_parameters(langchain_asset)
+ num_generations_per_prompt = llm_parameters.get("n", 1)
+
+ llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
+ # Repeat each input row based on the number of outputs generated per prompt
+ llm_start_records_df = llm_start_records_df.loc[
+ llm_start_records_df.index.repeat(num_generations_per_prompt)
+ ].reset_index(drop=True)
+ llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
+
+ llm_session_df = pd.merge(
+ llm_start_records_df,
+ llm_end_records_df,
+ left_index=True,
+ right_index=True,
+ suffixes=["_llm_start", "_llm_end"],
+ )
+
+ return llm_session_df
+
+ def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
+ if not langchain_asset:
+ return {}
+ try:
+ if hasattr(langchain_asset, "agent"):
+ llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
+ elif hasattr(langchain_asset, "llm_chain"):
+ llm_parameters = langchain_asset.llm_chain.llm.dict()
+ elif hasattr(langchain_asset, "llm"):
+ llm_parameters = langchain_asset.llm.dict()
+ else:
+ llm_parameters = langchain_asset.dict()
+ except Exception:
+ return {}
+
+ return llm_parameters
diff --git a/libs/community/langchain_community/callbacks/confident_callback.py b/libs/community/langchain_community/callbacks/confident_callback.py
new file mode 100644
index 00000000000..d9432e5d567
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/confident_callback.py
@@ -0,0 +1,183 @@
+# flake8: noqa
+import os
+import warnings
+from typing import Any, Dict, List, Optional, Union
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.outputs import LLMResult
+
+
+class DeepEvalCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs into deepeval.
+
+ Args:
+ implementation_name: name of the `implementation` in deepeval
+ metrics: A list of metrics
+
+ Raises:
+ ImportError: if the `deepeval` package is not installed.
+
+ Examples:
+ >>> from langchain_community.llms import OpenAI
+ >>> from langchain_community.callbacks import DeepEvalCallbackHandler
+ >>> from deepeval.metrics import AnswerRelevancy
+ >>> metric = AnswerRelevancy(minimum_score=0.3)
+ >>> deepeval_callback = DeepEvalCallbackHandler(
+ ... implementation_name="exampleImplementation",
+ ... metrics=[metric],
+ ... )
+ >>> llm = OpenAI(
+ ... temperature=0,
+ ... callbacks=[deepeval_callback],
+ ... verbose=True,
+ ... openai_api_key="API_KEY_HERE",
+ ... )
+ >>> llm.generate([
+ ... "What is the best evaluation tool out there? (no bias at all)",
+ ... ])
+ "Deepeval, no doubt about it."
+ """
+
+ REPO_URL: str = "https://github.com/confident-ai/deepeval"
+ ISSUES_URL: str = f"{REPO_URL}/issues"
+ BLOG_URL: str = "https://docs.confident-ai.com" # noqa: E501
+
+ def __init__(
+ self,
+ metrics: List[Any],
+ implementation_name: Optional[str] = None,
+ ) -> None:
+ """Initializes the `deepevalCallbackHandler`.
+
+ Args:
+ implementation_name: Name of the implementation you want.
+            metrics: The deepeval metrics to track.
+
+ Raises:
+ ImportError: if the `deepeval` package is not installed.
+ ConnectionError: if the connection to deepeval fails.
+ """
+
+ super().__init__()
+
+ # Import deepeval (not via `import_deepeval` to keep hints in IDEs)
+ try:
+ import deepeval # ignore: F401,I001
+ except ImportError:
+ raise ImportError(
+ """To use the deepeval callback manager you need to have the
+ `deepeval` Python package installed. Please install it with
+ `pip install deepeval`"""
+ )
+
+ if os.path.exists(".deepeval"):
+ warnings.warn(
+ """You are currently not logging anything to the dashboard, we
+ recommend using `deepeval login`."""
+ )
+
+ # Set the deepeval variables
+ self.implementation_name = implementation_name
+ self.metrics = metrics
+
+ warnings.warn(
+ (
+ "The `DeepEvalCallbackHandler` is currently in beta and is subject to"
+ " change based on updates to `langchain`. Please report any issues to"
+ f" {self.ISSUES_URL} as an `integration` issue."
+ ),
+ )
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Store the prompts"""
+ self.prompts = prompts
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Do nothing when a new token is generated."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Log records to deepeval when an LLM ends."""
+ from deepeval.metrics.answer_relevancy import AnswerRelevancy
+ from deepeval.metrics.bias_classifier import UnBiasedMetric
+ from deepeval.metrics.metric import Metric
+ from deepeval.metrics.toxic_classifier import NonToxicMetric
+
+ for metric in self.metrics:
+ for i, generation in enumerate(response.generations):
+ # Here, we only measure the first generation's output
+ output = generation[0].text
+ query = self.prompts[i]
+ if isinstance(metric, AnswerRelevancy):
+ result = metric.measure(
+ output=output,
+ query=query,
+ )
+ print(f"Answer Relevancy: {result}")
+ elif isinstance(metric, UnBiasedMetric):
+ score = metric.measure(output)
+ print(f"Bias Score: {score}")
+ elif isinstance(metric, NonToxicMetric):
+ score = metric.measure(output)
+ print(f"Toxic Score: {score}")
+ else:
+ raise ValueError(
+ f"""Metric {metric.__name__} is not supported by deepeval
+ callbacks."""
+ )
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM outputs an error."""
+ pass
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Do nothing when chain starts"""
+ pass
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Do nothing when chain ends."""
+ pass
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM chain outputs an error."""
+ pass
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool starts."""
+ pass
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing when agent takes a specific action."""
+ pass
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool ends."""
+ pass
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when tool outputs an error."""
+ pass
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
diff --git a/libs/community/langchain_community/callbacks/context_callback.py b/libs/community/langchain_community/callbacks/context_callback.py
new file mode 100644
index 00000000000..8514976687d
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/context_callback.py
@@ -0,0 +1,194 @@
+"""Callback handler for Context AI"""
+import os
+from typing import Any, Dict, List
+from uuid import UUID
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult
+
+
+def import_context() -> Any:
+ """Import the `getcontext` package."""
+ try:
+ import getcontext # noqa: F401
+ from getcontext.generated.models import (
+ Conversation,
+ Message,
+ MessageRole,
+ Rating,
+ )
+ from getcontext.token import Credential # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "To use the context callback manager you need to have the "
+ "`getcontext` python package installed (version >=0.3.0). "
+ "Please install it with `pip install --upgrade python-context`"
+ )
+ return getcontext, Credential, Conversation, Message, MessageRole, Rating
+
+
+class ContextCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that records transcripts to the Context service.
+
+ (https://context.ai).
+
+ Keyword Args:
+ token (optional): The token with which to authenticate requests to Context.
+ Visit https://with.context.ai/settings to generate a token.
+ If not provided, the value of the `CONTEXT_TOKEN` environment
+ variable will be used.
+
+ Raises:
+ ImportError: if the `context-python` package is not installed.
+
+ Chat Example:
+        >>> from langchain_community.chat_models import ChatOpenAI
+ >>> from langchain_community.callbacks import ContextCallbackHandler
+ >>> context_callback = ContextCallbackHandler(
+ ... token="",
+ ... )
+ >>> chat = ChatOpenAI(
+ ... temperature=0,
+ ... headers={"user_id": "123"},
+ ... callbacks=[context_callback],
+ ... openai_api_key="API_KEY_HERE",
+ ... )
+ >>> messages = [
+ ... SystemMessage(content="You translate English to French."),
+ ... HumanMessage(content="I love programming with LangChain."),
+ ... ]
+ >>> chat(messages)
+
+ Chain Example:
+ >>> from langchain.chains import LLMChain
+ >>> from langchain_community.chat_models import ChatOpenAI
+ >>> from langchain_community.callbacks import ContextCallbackHandler
+ >>> context_callback = ContextCallbackHandler(
+ ... token="",
+ ... )
+ >>> human_message_prompt = HumanMessagePromptTemplate(
+ ... prompt=PromptTemplate(
+ ... template="What is a good name for a company that makes {product}?",
+ ... input_variables=["product"],
+ ... ),
+ ... )
+ >>> chat_prompt_template = ChatPromptTemplate.from_messages(
+ ... [human_message_prompt]
+ ... )
+ >>> callback = ContextCallbackHandler(token)
+ >>> # Note: the same callback object must be shared between the
+        >>> # LLM and the chain.
+ >>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
+ >>> chain = LLMChain(
+ ... llm=chat,
+ ... prompt=chat_prompt_template,
+ ... callbacks=[callback]
+ ... )
+ >>> chain.run("colorful socks")
+ """
+
+ def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
+ (
+ self.context,
+ self.credential,
+ self.conversation_model,
+ self.message_model,
+ self.message_role_model,
+ self.rating_model,
+ ) = import_context()
+
+ token = token or os.environ.get("CONTEXT_TOKEN") or ""
+
+ self.client = self.context.ContextAPI(credential=self.credential(token))
+
+ self.chain_run_id = None
+
+ self.llm_model = None
+
+ self.messages: List[Any] = []
+ self.metadata: Dict[str, str] = {}
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ *,
+ run_id: UUID,
+ **kwargs: Any,
+ ) -> Any:
+ """Run when the chat model is started."""
+ llm_model = kwargs.get("invocation_params", {}).get("model", None)
+ if llm_model is not None:
+ self.metadata["model"] = llm_model
+
+ if len(messages) == 0:
+ return
+
+ for message in messages[0]:
+ role = self.message_role_model.SYSTEM
+ if message.type == "human":
+ role = self.message_role_model.USER
+ elif message.type == "system":
+ role = self.message_role_model.SYSTEM
+ elif message.type == "ai":
+ role = self.message_role_model.ASSISTANT
+
+ self.messages.append(
+ self.message_model(
+ message=message.content,
+ role=role,
+ )
+ )
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends."""
+ if len(response.generations) == 0 or len(response.generations[0]) == 0:
+ return
+
+ if not self.chain_run_id:
+ generation = response.generations[0][0]
+ self.messages.append(
+ self.message_model(
+ message=generation.text,
+ role=self.message_role_model.ASSISTANT,
+ )
+ )
+
+ self._log_conversation()
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts."""
+ self.chain_run_id = kwargs.get("run_id", None)
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends."""
+ self.messages.append(
+ self.message_model(
+ message=outputs["text"],
+ role=self.message_role_model.ASSISTANT,
+ )
+ )
+
+ self._log_conversation()
+
+ self.chain_run_id = None
+
+ def _log_conversation(self) -> None:
+ """Log the conversation to the context API."""
+ if len(self.messages) == 0:
+ return
+
+ self.client.log.conversation_upsert(
+ body={
+ "conversation": self.conversation_model(
+ messages=self.messages,
+ metadata=self.metadata,
+ )
+ }
+ )
+
+ self.messages = []
+ self.metadata = {}
diff --git a/libs/community/langchain_community/callbacks/flyte_callback.py b/libs/community/langchain_community/callbacks/flyte_callback.py
new file mode 100644
index 00000000000..23a8f473430
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/flyte_callback.py
@@ -0,0 +1,371 @@
+"""FlyteKit callback handler."""
+from __future__ import annotations
+
+import logging
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.utils import (
+ BaseMetadataCallbackHandler,
+ flatten_dict,
+ import_pandas,
+ import_spacy,
+ import_textstat,
+)
+
+if TYPE_CHECKING:
+ import flytekit
+ from flytekitplugins.deck import renderer
+
+logger = logging.getLogger(__name__)
+
+
+def import_flytekit() -> Tuple[flytekit, renderer]:
+ """Import flytekit and flytekitplugins-deck-standard."""
+ try:
+ import flytekit # noqa: F401
+ from flytekitplugins.deck import renderer # noqa: F401
+ except ImportError:
+ raise ImportError(
+            "To use the flyte callback manager you need to have the "
+            "`flytekit` and `flytekitplugins-deck-standard` packages installed. "
+            "Please install them with `pip install flytekit` and "
+            "`pip install flytekitplugins-deck-standard`."
+ )
+ return flytekit, renderer
+
+
+def analyze_text(
+ text: str,
+ nlp: Any = None,
+ textstat: Any = None,
+) -> dict:
+ """Analyze text using textstat and spacy.
+
+ Parameters:
+ text (str): The text to analyze.
+        nlp (spacy.lang): The spacy language model to use for visualization.
+        textstat: The textstat module used to compute the text complexity metrics.
+
+ Returns:
+ (dict): A dictionary containing the complexity metrics and visualization
+ files serialized to HTML string.
+ """
+ resp: Dict[str, Any] = {}
+ if textstat is not None:
+ text_complexity_metrics = {
+ "flesch_reading_ease": textstat.flesch_reading_ease(text),
+ "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+ "smog_index": textstat.smog_index(text),
+ "coleman_liau_index": textstat.coleman_liau_index(text),
+ "automated_readability_index": textstat.automated_readability_index(text),
+ "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
+ "difficult_words": textstat.difficult_words(text),
+ "linsear_write_formula": textstat.linsear_write_formula(text),
+ "gunning_fog": textstat.gunning_fog(text),
+ "fernandez_huerta": textstat.fernandez_huerta(text),
+ "szigriszt_pazos": textstat.szigriszt_pazos(text),
+ "gutierrez_polini": textstat.gutierrez_polini(text),
+ "crawford": textstat.crawford(text),
+ "gulpease_index": textstat.gulpease_index(text),
+ "osman": textstat.osman(text),
+ }
+ resp.update({"text_complexity_metrics": text_complexity_metrics})
+ resp.update(text_complexity_metrics)
+
+ if nlp is not None:
+ spacy = import_spacy()
+ doc = nlp(text)
+ dep_out = spacy.displacy.render( # type: ignore
+ doc, style="dep", jupyter=False, page=True
+ )
+ ent_out = spacy.displacy.render( # type: ignore
+ doc, style="ent", jupyter=False, page=True
+ )
+ text_visualizations = {
+ "dependency_tree": dep_out,
+ "entities": ent_out,
+ }
+ resp.update(text_visualizations)
+
+ return resp
+
+
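+# A minimal sketch of calling `analyze_text` directly (illustrative, not executed
+# by the module). Assumes `textstat`, `spacy`, and spacy's `en_core_web_sm` model
+# are installed; the sample sentence is arbitrary.
+#
+#   import spacy
+#   import textstat
+#
+#   nlp = spacy.load("en_core_web_sm")
+#   metrics = analyze_text(
+#       "LangChain helps developers build LLM applications.",
+#       nlp=nlp,
+#       textstat=textstat,
+#   )
+#   print(metrics["flesch_reading_ease"])
+#   print(metrics["entities"][:80])  # displaCy entity visualization as HTML
+
+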
+class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+    """Callback handler that is used within a Flyte task."""
+
+ def __init__(self) -> None:
+ """Initialize callback handler."""
+ flytekit, renderer = import_flytekit()
+ self.pandas = import_pandas()
+
+ self.textstat = None
+ try:
+ self.textstat = import_textstat()
+ except ImportError:
+ logger.warning(
+ "Textstat library is not installed. \
+ It may result in the inability to log \
+ certain metrics that can be captured with Textstat."
+ )
+
+ spacy = None
+ try:
+ spacy = import_spacy()
+ except ImportError:
+ logger.warning(
+ "Spacy library is not installed. \
+ It may result in the inability to log \
+ certain metrics that can be captured with Spacy."
+ )
+
+ super().__init__()
+
+ self.nlp = None
+ if spacy:
+ try:
+ self.nlp = spacy.load("en_core_web_sm")
+ except OSError:
+ logger.warning(
+ "FlyteCallbackHandler uses spacy's en_core_web_sm model"
+ " for certain metrics. To download,"
+ " run the following command in your terminal:"
+ " `python -m spacy download en_core_web_sm`"
+ )
+
+ self.table_renderer = renderer.TableRenderer
+ self.markdown_renderer = renderer.MarkdownRenderer
+
+ self.deck = flytekit.Deck(
+ "LangChain Metrics",
+ self.markdown_renderer().to_html("## LangChain Metrics"),
+ )
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+
+ self.step += 1
+ self.llm_starts += 1
+ self.starts += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ prompt_responses = []
+ for prompt in prompts:
+ prompt_responses.append(prompt)
+
+ resp.update({"prompts": prompt_responses})
+
+ self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.step += 1
+ self.llm_ends += 1
+ self.ends += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_end"})
+ resp.update(flatten_dict(response.llm_output or {}))
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### LLM End"))
+ self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))
+
+ for generations in response.generations:
+ for generation in generations:
+ generation_resp = deepcopy(resp)
+ generation_resp.update(flatten_dict(generation.dict()))
+ if self.nlp or self.textstat:
+ generation_resp.update(
+ analyze_text(
+ generation.text, nlp=self.nlp, textstat=self.textstat
+ )
+ )
+
+ complexity_metrics: Dict[str, float] = generation_resp.pop(
+ "text_complexity_metrics"
+ ) # type: ignore # noqa: E501
+ self.deck.append(
+ self.markdown_renderer().to_html("#### Text Complexity Metrics")
+ )
+ self.deck.append(
+ self.table_renderer().to_html(
+ self.pandas.DataFrame([complexity_metrics])
+ )
+ + "\n"
+ )
+
+ dependency_tree = generation_resp["dependency_tree"]
+ self.deck.append(
+ self.markdown_renderer().to_html("#### Dependency Tree")
+ )
+ self.deck.append(dependency_tree)
+
+ entities = generation_resp["entities"]
+ self.deck.append(self.markdown_renderer().to_html("#### Entities"))
+ self.deck.append(entities)
+ else:
+ self.deck.append(
+ self.markdown_renderer().to_html("#### Generated Response")
+ )
+ self.deck.append(self.markdown_renderer().to_html(generation.text))
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
+ input_resp = deepcopy(resp)
+ input_resp["inputs"] = chain_input
+
+ self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n"
+ )
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp: Dict[str, Any] = {}
+ chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
+ resp.update({"action": "on_chain_end", "outputs": chain_output})
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### Chain End"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_start", "input_str": input_str})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_end", "output": output})
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### Tool End"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on arbitrary text."""
+ self.step += 1
+ self.text_ctr += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_text", "text": text})
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### On Text"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.step += 1
+ self.agent_ends += 1
+ self.ends += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_finish",
+ "output": finish.return_values["output"],
+ "log": finish.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ "tool_input": action.tool_input,
+ "log": action.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+
+ self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
+ self.deck.append(
+ self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
+ )
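+
+
+# Illustrative usage sketch (commented out; not executed by the module): attach the
+# handler to an LLM inside a Flyte task so the metrics and renders appear on the
+# task's deck. Assumes `flytekit`, `flytekitplugins-deck-standard`, and an OpenAI
+# API key are available; the task and prompt below are hypothetical.
+#
+#   from flytekit import task
+#
+#   from langchain_community.llms import OpenAI
+#
+#   @task(disable_deck=False)
+#   def langchain_flyte_task() -> str:
+#       llm = OpenAI(temperature=0, callbacks=[FlyteCallbackHandler()])
+#       return llm.predict("Tell me a joke.")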
diff --git a/libs/community/langchain_community/callbacks/human.py b/libs/community/langchain_community/callbacks/human.py
new file mode 100644
index 00000000000..64ea01f99f5
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/human.py
@@ -0,0 +1,88 @@
+from typing import Any, Awaitable, Callable, Dict, Optional
+from uuid import UUID
+
+from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
+
+
+def _default_approve(_input: str) -> bool:
+ msg = (
+ "Do you approve of the following input? "
+ "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
+ )
+ msg += "\n\n" + _input + "\n"
+ resp = input(msg)
+ return resp.lower() in ("yes", "y")
+
+
+async def _adefault_approve(_input: str) -> bool:
+ msg = (
+ "Do you approve of the following input? "
+ "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
+ )
+ msg += "\n\n" + _input + "\n"
+ resp = input(msg)
+ return resp.lower() in ("yes", "y")
+
+
+def _default_true(_: Dict[str, Any]) -> bool:
+ return True
+
+
+class HumanRejectedException(Exception):
+    """Exception raised when a person manually reviews and rejects a value."""
+
+
+class HumanApprovalCallbackHandler(BaseCallbackHandler):
+ """Callback for manually validating values."""
+
+ raise_error: bool = True
+
+ def __init__(
+ self,
+ approve: Callable[[Any], bool] = _default_approve,
+ should_check: Callable[[Dict[str, Any]], bool] = _default_true,
+ ):
+ self._approve = approve
+ self._should_check = should_check
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self._should_check(serialized) and not self._approve(input_str):
+ raise HumanRejectedException(
+ f"Inputs {input_str} to tool {serialized} were rejected."
+ )
+
+
+class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler):
+ """Asynchronous callback for manually validating values."""
+
+ raise_error: bool = True
+
+ def __init__(
+ self,
+ approve: Callable[[Any], Awaitable[bool]] = _adefault_approve,
+ should_check: Callable[[Dict[str, Any]], bool] = _default_true,
+ ):
+ self._approve = approve
+ self._should_check = should_check
+
+ async def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self._should_check(serialized) and not await self._approve(input_str):
+ raise HumanRejectedException(
+ f"Inputs {input_str} to tool {serialized} were rejected."
+ )
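+
+
+# Illustrative usage sketch (commented out; not executed by the module): require
+# human approval before a tool runs, only for a specific tool. `ShellTool` and the
+# tool name "terminal" are examples, not requirements of this handler.
+#
+#   from langchain_community.tools import ShellTool
+#
+#   def _should_check(serialized_obj: Dict[str, Any]) -> bool:
+#       # Only gate the shell tool behind approval.
+#       return serialized_obj.get("name") == "terminal"
+#
+#   tool = ShellTool(
+#       callbacks=[HumanApprovalCallbackHandler(should_check=_should_check)]
+#   )
+#   try:
+#       print(tool.run("ls /usr"))
+#   except HumanRejectedException:
+#       print("Tool input was rejected by the reviewer.")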
diff --git a/libs/community/langchain_community/callbacks/infino_callback.py b/libs/community/langchain_community/callbacks/infino_callback.py
new file mode 100644
index 00000000000..57d756948ef
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/infino_callback.py
@@ -0,0 +1,266 @@
+import time
+from typing import Any, Dict, List, Optional, cast
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult
+
+
+def import_infino() -> Any:
+ """Import the infino client."""
+ try:
+ from infinopy import InfinoClient
+ except ImportError:
+ raise ImportError(
+            "To use the Infino callback manager you need to have the"
+            " `infinopy` python package installed. "
+            "Please install it with `pip install infinopy`."
+ )
+ return InfinoClient()
+
+
+def import_tiktoken() -> Any:
+ """Import tiktoken for counting tokens for OpenAI models."""
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "To use the ChatOpenAI model with Infino callback manager, you need to "
+            "have the `tiktoken` python package installed. "
+            "Please install it with `pip install tiktoken`."
+ )
+ return tiktoken
+
+
+def get_num_tokens(string: str, openai_model_name: str) -> int:
+ """Calculate num tokens for OpenAI with tiktoken package.
+
+ Official documentation: https://github.com/openai/openai-cookbook/blob/main
+ /examples/How_to_count_tokens_with_tiktoken.ipynb
+ """
+ tiktoken = import_tiktoken()
+
+ encoding = tiktoken.encoding_for_model(openai_model_name)
+ num_tokens = len(encoding.encode(string))
+ return num_tokens
+
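+# For example (illustrative; requires `tiktoken` and a model name that tiktoken
+# recognizes):
+#
+#   get_num_tokens("Hello world", openai_model_name="gpt-3.5-turbo")
+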
+
+class InfinoCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs to Infino."""
+
+ def __init__(
+ self,
+ model_id: Optional[str] = None,
+ model_version: Optional[str] = None,
+ verbose: bool = False,
+ ) -> None:
+ # Set Infino client
+ self.client = import_infino()
+ self.model_id = model_id
+ self.model_version = model_version
+ self.verbose = verbose
+ self.is_chat_openai_model = False
+ self.chat_openai_model_name = "gpt-3.5-turbo"
+
+ def _send_to_infino(
+ self,
+ key: str,
+ value: Any,
+ is_ts: bool = True,
+ ) -> None:
+ """Send the key-value to Infino.
+
+ Parameters:
+ key (str): the key to send to Infino.
+ value (Any): the value to send to Infino.
+ is_ts (bool): if True, the value is part of a time series, else it
+ is sent as a log message.
+ """
+ payload = {
+ "date": int(time.time()),
+ key: value,
+ "labels": {
+ "model_id": self.model_id,
+ "model_version": self.model_version,
+ },
+ }
+ if self.verbose:
+ print(f"Tracking {key} with Infino: {payload}")
+
+ # Append to Infino time series only if is_ts is True, otherwise
+ # append to Infino log.
+ if is_ts:
+ self.client.append_ts(payload)
+ else:
+ self.client.append_log(payload)
+
+ def on_llm_start(
+ self,
+ serialized: Dict[str, Any],
+ prompts: List[str],
+ **kwargs: Any,
+ ) -> None:
+ """Log the prompts to Infino, and set start time and error flag."""
+ for prompt in prompts:
+ self._send_to_infino("prompt", prompt, is_ts=False)
+
+ # Set the error flag to indicate no error (this will get overridden
+ # in on_llm_error if an error occurs).
+ self.error = 0
+
+ # Set the start time (so that we can calculate the request
+ # duration in on_llm_end).
+ self.start_time = time.time()
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Do nothing when a new token is generated."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Log the latency, error, token usage, and response to Infino."""
+ # Calculate and track the request latency.
+ self.end_time = time.time()
+ duration = self.end_time - self.start_time
+ self._send_to_infino("latency", duration)
+
+ # Track success or error flag.
+ self._send_to_infino("error", self.error)
+
+ # Track prompt response.
+ for generations in response.generations:
+ for generation in generations:
+ self._send_to_infino("prompt_response", generation.text, is_ts=False)
+
+ # Track token usage (for non-chat models).
+ if (response.llm_output is not None) and isinstance(response.llm_output, Dict):
+ token_usage = response.llm_output["token_usage"]
+ if token_usage is not None:
+ prompt_tokens = token_usage["prompt_tokens"]
+ total_tokens = token_usage["total_tokens"]
+ completion_tokens = token_usage["completion_tokens"]
+ self._send_to_infino("prompt_tokens", prompt_tokens)
+ self._send_to_infino("total_tokens", total_tokens)
+ self._send_to_infino("completion_tokens", completion_tokens)
+
+ # Track completion token usage (for openai chat models).
+        if self.is_chat_openai_model:
+            messages = " ".join(
+                generation.message.content  # type: ignore[attr-defined]
+                for generations in response.generations
+                for generation in generations
+            )
+            completion_tokens = get_num_tokens(
+                messages, openai_model_name=self.chat_openai_model_name
+            )
+            self._send_to_infino("completion_tokens", completion_tokens)
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Set the error flag."""
+ self.error = 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Do nothing when LLM chain starts."""
+ pass
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Do nothing when LLM chain ends."""
+ pass
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Need to log the error."""
+ pass
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool starts."""
+ pass
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing when agent takes a specific action."""
+ pass
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool ends."""
+ pass
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when tool outputs an error."""
+ pass
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Do nothing."""
+ pass
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ **kwargs: Any,
+ ) -> None:
+ """Run when LLM starts running."""
+
+ # Currently, for chat models, we only support input prompts for ChatOpenAI.
+ # Check if this model is a ChatOpenAI model.
+ values = serialized.get("id")
+ if values:
+ for value in values:
+ if value == "ChatOpenAI":
+ self.is_chat_openai_model = True
+ break
+
+ # Track prompt tokens for ChatOpenAI model.
+ if self.is_chat_openai_model:
+ invocation_params = kwargs.get("invocation_params")
+ if invocation_params:
+ model_name = invocation_params.get("model_name")
+ if model_name:
+ self.chat_openai_model_name = model_name
+ prompt_tokens = 0
+ for message_list in messages:
+ message_string = " ".join(
+ cast(str, msg.content) for msg in message_list
+ )
+ num_tokens = get_num_tokens(
+ message_string,
+ openai_model_name=self.chat_openai_model_name,
+ )
+ prompt_tokens += num_tokens
+
+ self._send_to_infino("prompt_tokens", prompt_tokens)
+
+ if self.verbose:
+ print(
+ f"on_chat_model_start: is_chat_openai_model= \
+ {self.is_chat_openai_model}, \
+ chat_openai_model_name={self.chat_openai_model_name}"
+ )
+
+ # Send the prompt to infino
+ prompt = " ".join(
+ cast(str, msg.content) for sublist in messages for msg in sublist
+ )
+ self._send_to_infino("prompt", prompt, is_ts=False)
+
+ # Set the error flag to indicate no error (this will get overridden
+ # in on_llm_error if an error occurs).
+ self.error = 0
+
+ # Set the start time (so that we can calculate the request
+ # duration in on_llm_end).
+ self.start_time = time.time()
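+
+
+# Illustrative usage sketch (commented out; not executed by the module): log
+# prompts, latency, errors, and token counts to a running Infino server. Assumes
+# the `infinopy` package and an Infino server are available; the model id and
+# prompt below are hypothetical.
+#
+#   from langchain_community.llms import OpenAI
+#
+#   handler = InfinoCallbackHandler(model_id="demo-model", model_version="0.1")
+#   llm = OpenAI(temperature=0, callbacks=[handler])
+#   llm.predict("Tell me a joke.")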
diff --git a/libs/community/langchain_community/callbacks/labelstudio_callback.py b/libs/community/langchain_community/callbacks/labelstudio_callback.py
new file mode 100644
index 00000000000..73954820b23
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/labelstudio_callback.py
@@ -0,0 +1,390 @@
+import os
+import warnings
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+from uuid import UUID
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage, ChatMessage
+from langchain_core.outputs import Generation, LLMResult
+
+
+class LabelStudioMode(Enum):
+ """Label Studio mode enumerator."""
+
+ PROMPT = "prompt"
+ CHAT = "chat"
+
+
+def get_default_label_configs(
+ mode: Union[str, LabelStudioMode],
+) -> Tuple[str, LabelStudioMode]:
+ """Get default Label Studio configs for the given mode.
+
+ Parameters:
+ mode: Label Studio mode ("prompt" or "chat")
+
+ Returns: Tuple of Label Studio config and mode
+ """
+ _default_label_configs = {
+        LabelStudioMode.PROMPT.value: """
+<View>
+<Text name="prompt" value="$prompt"/>
+<TextArea name="response" toName="prompt"
+          maxSubmissions="1" editable="true"
+          required="true"/>
+<Header value="Rate the response:"/>
+<Rating name="rating" toName="prompt"/>
+</View>
+""",
+        LabelStudioMode.CHAT.value: """
+<View>
+<Paragraphs name="prompt" value="$prompt"
+            layout="dialogue"
+            textKey="content"
+            nameKey="role"/>
+<Header value="Final response:"/>
+<TextArea name="response" toName="prompt"
+          maxSubmissions="1" editable="true"
+          required="true"/>
+<Header value="Rate the response:"/>
+<Rating name="rating" toName="prompt"/>
+</View>
+""",
+ }
+
+ if isinstance(mode, str):
+ mode = LabelStudioMode(mode)
+
+ return _default_label_configs[mode.value], mode
+
+
+class LabelStudioCallbackHandler(BaseCallbackHandler):
+    """Label Studio callback handler.
+
+    Provides the ability to send predictions to Label Studio
+    for human evaluation, feedback and annotation.
+
+ Parameters:
+ api_key: Label Studio API key
+ url: Label Studio URL
+ project_id: Label Studio project ID
+ project_name: Label Studio project name
+ project_config: Label Studio project config (XML)
+ mode: Label Studio mode ("prompt" or "chat")
+
+ Examples:
+ >>> from langchain_community.llms import OpenAI
+ >>> from langchain_community.callbacks import LabelStudioCallbackHandler
+ >>> handler = LabelStudioCallbackHandler(
+ ... api_key='',
+ ... url='http://localhost:8080',
+ ... project_name='LangChain-%Y-%m-%d',
+ ... mode='prompt'
+ ... )
+ >>> llm = OpenAI(callbacks=[handler])
+ >>> llm.predict('Tell me a story about a dog.')
+ """
+
+ DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
+
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ url: Optional[str] = None,
+ project_id: Optional[int] = None,
+ project_name: str = DEFAULT_PROJECT_NAME,
+ project_config: Optional[str] = None,
+ mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
+ ):
+ super().__init__()
+
+ # Import LabelStudio SDK
+ try:
+ import label_studio_sdk as ls
+ except ImportError:
+ raise ImportError(
+ f"You're using {self.__class__.__name__} in your code,"
+ f" but you don't have the LabelStudio SDK "
+ f"Python package installed or upgraded to the latest version. "
+ f"Please run `pip install -U label-studio-sdk`"
+ f" before using this callback."
+ )
+
+ # Check if Label Studio API key is provided
+ if not api_key:
+ if os.getenv("LABEL_STUDIO_API_KEY"):
+ api_key = str(os.getenv("LABEL_STUDIO_API_KEY"))
+ else:
+ raise ValueError(
+ f"You're using {self.__class__.__name__} in your code,"
+                    f" but the Label Studio API key is not provided. "
+                    f"Please provide a Label Studio API key: "
+ f"go to the Label Studio instance, navigate to "
+ f"Account & Settings -> Access Token and copy the key. "
+ f"Use the key as a parameter for the callback: "
+ f"{self.__class__.__name__}"
+ f"(label_studio_api_key='', ...) or "
+ f"set the environment variable LABEL_STUDIO_API_KEY="
+ )
+ self.api_key = api_key
+
+ if not url:
+ if os.getenv("LABEL_STUDIO_URL"):
+ url = os.getenv("LABEL_STUDIO_URL")
+ else:
+ warnings.warn(
+ f"Label Studio URL is not provided, "
+                    f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}. "
+ f"If you want to provide your own URL, use the parameter: "
+ f"{self.__class__.__name__}"
+ f"(label_studio_url='', ...) "
+ f"or set the environment variable LABEL_STUDIO_URL="
+ )
+ url = ls.LABEL_STUDIO_DEFAULT_URL
+ self.url = url
+
+ # Maps run_id to prompts
+ self.payload: Dict[str, Dict] = {}
+
+ self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
+ self.project_name = project_name
+ if project_config:
+ self.project_config = project_config
+ self.mode = None
+ else:
+ self.project_config, self.mode = get_default_label_configs(mode)
+
+ self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
+ if self.project_id is not None:
+ self.ls_project = self.ls_client.get_project(int(self.project_id))
+ else:
+ project_title = datetime.today().strftime(self.project_name)
+ existing_projects = self.ls_client.get_projects(title=project_title)
+ if existing_projects:
+ self.ls_project = existing_projects[0]
+ self.project_id = self.ls_project.id
+ else:
+ self.ls_project = self.ls_client.create_project(
+ title=project_title, label_config=self.project_config
+ )
+ self.project_id = self.ls_project.id
+ self.parsed_label_config = self.ls_project.parsed_label_config
+
+ # Find the first TextArea tag
+ # "from_name", "to_name", "value" will be used to create predictions
+ self.from_name, self.to_name, self.value, self.input_type = (
+ None,
+ None,
+ None,
+ None,
+ )
+ for tag_name, tag_info in self.parsed_label_config.items():
+ if tag_info["type"] == "TextArea":
+ self.from_name = tag_name
+ self.to_name = tag_info["to_name"][0]
+ self.value = tag_info["inputs"][0]["value"]
+ self.input_type = tag_info["inputs"][0]["type"]
+ break
+ if not self.from_name:
+ error_message = (
+ f'Label Studio project "{self.project_name}" '
+ f"does not have a TextArea tag. "
+ f"Please add a TextArea tag to the project."
+ )
+ if self.mode == LabelStudioMode.PROMPT:
+ error_message += (
+ "\nHINT: go to project Settings -> "
+ "Labeling Interface -> Browse Templates"
+ ' and select "Generative AI -> '
+ 'Supervised Language Model Fine-tuning" template.'
+ )
+ else:
+ error_message += (
+ "\nHINT: go to project Settings -> "
+ "Labeling Interface -> Browse Templates"
+ " and check available templates under "
+ '"Generative AI" section.'
+ )
+ raise ValueError(error_message)
+
+ def add_prompts_generations(
+ self, run_id: str, generations: List[List[Generation]]
+ ) -> None:
+ # Create tasks in Label Studio
+ tasks = []
+ prompts = self.payload[run_id]["prompts"]
+ model_version = (
+ self.payload[run_id]["kwargs"]
+ .get("invocation_params", {})
+ .get("model_name")
+ )
+ for prompt, generation in zip(prompts, generations):
+ tasks.append(
+ {
+ "data": {
+ self.value: prompt,
+ "run_id": run_id,
+ },
+ "predictions": [
+ {
+ "result": [
+ {
+ "from_name": self.from_name,
+ "to_name": self.to_name,
+ "type": "textarea",
+ "value": {"text": [g.text for g in generation]},
+ }
+ ],
+ "model_version": model_version,
+ }
+ ],
+ }
+ )
+ self.ls_project.import_tasks(tasks)
+
+ def on_llm_start(
+ self,
+ serialized: Dict[str, Any],
+ prompts: List[str],
+ **kwargs: Any,
+ ) -> None:
+ """Save the prompts in memory when an LLM starts."""
+ if self.input_type != "Text":
+ raise ValueError(
+ f'\nLabel Studio project "{self.project_name}" '
+ f"has an input type <{self.input_type}>. "
+                f'To make it work with the mode="prompt", '
+                f"the input type should be <Text>.\n"
+ f"Read more here https://labelstud.io/tags/text"
+ )
+ run_id = str(kwargs["run_id"])
+ self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}
+
+ def _get_message_role(self, message: BaseMessage) -> str:
+ """Get the role of the message."""
+ if isinstance(message, ChatMessage):
+ return message.role
+ else:
+ return message.__class__.__name__
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> Any:
+ """Save the prompts in memory when an LLM starts."""
+ if self.input_type != "Paragraphs":
+ raise ValueError(
+ f'\nLabel Studio project "{self.project_name}" '
+ f"has an input type <{self.input_type}>. "
+ f'To make it work with the mode="chat", '
+                f"the input type should be <Paragraphs>.\n"
+ f"Read more here https://labelstud.io/tags/paragraphs"
+ )
+
+ prompts = []
+ for message_list in messages:
+ dialog = []
+ for message in message_list:
+ dialog.append(
+ {
+ "role": self._get_message_role(message),
+ "content": message.content,
+ }
+ )
+ prompts.append(dialog)
+ self.payload[str(run_id)] = {
+ "prompts": prompts,
+ "tags": tags,
+ "metadata": metadata,
+ "run_id": run_id,
+ "parent_run_id": parent_run_id,
+ "kwargs": kwargs,
+ }
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Do nothing when a new token is generated."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Create a new Label Studio task for each prompt and generation."""
+ run_id = str(kwargs["run_id"])
+
+ # Submit results to Label Studio
+ self.add_prompts_generations(run_id, response.generations)
+
+        # Pop current run from `self.payload`
+ self.payload.pop(run_id)
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM outputs an error."""
+ pass
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ pass
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ pass
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when LLM chain outputs an error."""
+ pass
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool starts."""
+ pass
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Do nothing when agent takes a specific action."""
+ pass
+
+ def on_tool_end(
+ self,
+ output: str,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Do nothing when tool ends."""
+ pass
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Do nothing when tool outputs an error."""
+ pass
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Do nothing"""
+ pass
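+
+
+# Illustrative chat-mode sketch (commented out; not executed by the module). The
+# class docstring shows mode="prompt"; chat models need a Paragraphs-based project
+# and mode="chat". The URL and project name below are hypothetical.
+#
+#   from langchain_community.chat_models import ChatOpenAI
+#
+#   handler = LabelStudioCallbackHandler(
+#       url="http://localhost:8080",
+#       project_name="LangChain-chat-%Y-%m-%d",
+#       mode="chat",
+#   )
+#   chat = ChatOpenAI(callbacks=[handler])
+#   chat.predict("Tell me a joke.")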
diff --git a/libs/community/langchain_community/callbacks/llmonitor_callback.py b/libs/community/langchain_community/callbacks/llmonitor_callback.py
new file mode 100644
index 00000000000..f4f2882dac2
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/llmonitor_callback.py
@@ -0,0 +1,680 @@
+import importlib.metadata
+import logging
+import os
+import traceback
+import warnings
+from contextvars import ContextVar
+from typing import Any, Dict, List, Union, cast
+from uuid import UUID
+
+import requests
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult
+from packaging.version import parse
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_API_URL = "https://app.llmonitor.com"
+
+user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
+user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
+
+PARAMS_TO_CAPTURE = [
+ "temperature",
+ "top_p",
+ "top_k",
+ "stop",
+ "presence_penalty",
+    "frequency_penalty",
+ "seed",
+ "function_call",
+ "functions",
+ "tools",
+ "tool_choice",
+ "response_format",
+ "max_tokens",
+ "logit_bias",
+]
+
+
+class UserContextManager:
+ """Context manager for LLMonitor user context."""
+
+ def __init__(self, user_id: str, user_props: Any = None) -> None:
+ user_ctx.set(user_id)
+ user_props_ctx.set(user_props)
+
+ def __enter__(self) -> Any:
+ pass
+
+ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
+ user_ctx.set(None)
+ user_props_ctx.set(None)
+
+
+def identify(user_id: str, user_props: Any = None) -> UserContextManager:
+ """Builds an LLMonitor UserContextManager
+
+ Parameters:
+ - `user_id`: The user id.
+ - `user_props`: The user properties.
+
+ Returns:
+ A context manager that sets the user context.
+ """
+ return UserContextManager(user_id, user_props)
+
+
+def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
+ if hasattr(obj, "to_json"):
+ return obj.to_json()
+
+ if isinstance(obj, dict):
+ return {key: _serialize(value) for key, value in obj.items()}
+
+ if isinstance(obj, list):
+ return [_serialize(element) for element in obj]
+
+ return obj
+
+
+def _parse_input(raw_input: Any) -> Any:
+ if not raw_input:
+ return None
+
+ # if it's an array of 1, just parse the first element
+ if isinstance(raw_input, list) and len(raw_input) == 1:
+ return _parse_input(raw_input[0])
+
+ if not isinstance(raw_input, dict):
+ return _serialize(raw_input)
+
+ input_value = raw_input.get("input")
+ inputs_value = raw_input.get("inputs")
+ question_value = raw_input.get("question")
+ query_value = raw_input.get("query")
+
+ if input_value:
+ return input_value
+ if inputs_value:
+ return inputs_value
+ if question_value:
+ return question_value
+ if query_value:
+ return query_value
+
+ return _serialize(raw_input)
+
+
+def _parse_output(raw_output: dict) -> Any:
+ if not raw_output:
+ return None
+
+ if not isinstance(raw_output, dict):
+ return _serialize(raw_output)
+
+ text_value = raw_output.get("text")
+ output_value = raw_output.get("output")
+ output_text_value = raw_output.get("output_text")
+ answer_value = raw_output.get("answer")
+ result_value = raw_output.get("result")
+
+ if text_value:
+ return text_value
+ if answer_value:
+ return answer_value
+ if output_value:
+ return output_value
+ if output_text_value:
+ return output_text_value
+ if result_value:
+ return result_value
+
+ return _serialize(raw_output)
+
+
+def _parse_lc_role(
+ role: str,
+) -> str:
+ if role == "human":
+ return "user"
+ else:
+ return role
+
+
+def _get_user_id(metadata: Any) -> Any:
+ if user_ctx.get() is not None:
+ return user_ctx.get()
+
+ metadata = metadata or {}
+ user_id = metadata.get("user_id")
+ if user_id is None:
+ user_id = metadata.get("userId") # legacy, to delete in the future
+ return user_id
+
+
+def _get_user_props(metadata: Any) -> Any:
+ if user_props_ctx.get() is not None:
+ return user_props_ctx.get()
+
+ metadata = metadata or {}
+ return metadata.get("user_props", None)
+
+
+def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
+ keys = ["function_call", "tool_calls", "tool_call_id", "name"]
+ parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
+ parsed.update(
+ {
+ key: cast(Any, message.additional_kwargs.get(key))
+ for key in keys
+ if message.additional_kwargs.get(key) is not None
+ }
+ )
+ return parsed
+
+
+def _parse_lc_messages(messages: Union[List[BaseMessage], Any]) -> List[Dict[str, Any]]:
+ return [_parse_lc_message(message) for message in messages]
+
+
+class LLMonitorCallbackHandler(BaseCallbackHandler):
+    """Callback Handler for LLMonitor.
+
+ #### Parameters:
+ - `app_id`: The app id of the app you want to report to. Defaults to
+    `None`, which means that the `LLMONITOR_APP_ID` environment
+    variable will be used.
+ - `api_url`: The url of the LLMonitor API. Defaults to `None`,
+ which means that either `LLMONITOR_API_URL` environment variable
+ or `https://app.llmonitor.com` will be used.
+
+ #### Raises:
+ - `ValueError`: if `app_id` is not provided either as an
+ argument or as an environment variable.
+ - `ConnectionError`: if the connection to the API fails.
+
+
+ #### Example:
+ ```python
+ from langchain_community.llms import OpenAI
+ from langchain_community.callbacks import LLMonitorCallbackHandler
+
+ llmonitor_callback = LLMonitorCallbackHandler()
+ llm = OpenAI(callbacks=[llmonitor_callback],
+ metadata={"userId": "user-123"})
+ llm.predict("Hello, how are you?")
+ ```
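+
+    Users can also be attached to runs with the `identify` context manager; a
+    short illustrative sketch (the user id and props are hypothetical):
+    ```python
+    from langchain_community.callbacks.llmonitor_callback import identify
+
+    with identify("user-123", user_props={"name": "Jane Doe"}):
+        llm.predict("Tell me a joke")
+    ```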
+ """
+
+ __api_url: str
+ __app_id: str
+ __verbose: bool
+ __llmonitor_version: str
+ __has_valid_config: bool
+
+ def __init__(
+ self,
+ app_id: Union[str, None] = None,
+ api_url: Union[str, None] = None,
+ verbose: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.__has_valid_config = True
+
+ try:
+ import llmonitor
+
+ self.__llmonitor_version = importlib.metadata.version("llmonitor")
+ self.__track_event = llmonitor.track_event
+
+ except ImportError:
+ logger.warning(
+ """[LLMonitor] To use the LLMonitor callback handler you need to
+ have the `llmonitor` Python package installed. Please install it
+ with `pip install llmonitor`"""
+ )
+ self.__has_valid_config = False
+ return
+
+ if parse(self.__llmonitor_version) < parse("0.0.32"):
+ logger.warning(
+ f"""[LLMonitor] The installed `llmonitor` version is
+ {self.__llmonitor_version}
+                but `LLMonitorCallbackHandler` requires at least version 0.0.32.
+                Upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
+ )
+ self.__has_valid_config = False
+
+ self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
+ self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
+
+ _app_id = app_id or os.getenv("LLMONITOR_APP_ID")
+ if _app_id is None:
+ logger.warning(
+ """[LLMonitor] app_id must be provided either as an argument or
+ as an environment variable"""
+ )
+ self.__has_valid_config = False
+ else:
+ self.__app_id = _app_id
+
+ if self.__has_valid_config is False:
+ return None
+
+ try:
+ res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
+ if not res.ok:
+ raise ConnectionError()
+ except Exception:
+ logger.warning(
+ f"""[LLMonitor] Could not connect to the LLMonitor API at
+ {self.__api_url}"""
+ )
+
+ def on_llm_start(
+ self,
+ serialized: Dict[str, Any],
+ prompts: List[str],
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ tags: Union[List[str], None] = None,
+ metadata: Union[Dict[str, Any], None] = None,
+ **kwargs: Any,
+ ) -> None:
+ if self.__has_valid_config is False:
+ return
+ try:
+ user_id = _get_user_id(metadata)
+ user_props = _get_user_props(metadata)
+
+ params = kwargs.get("invocation_params", {})
+ params.update(
+ serialized.get("kwargs", {})
+ ) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
+
+ name = (
+ params.get("model")
+ or params.get("model_name")
+ or params.get("model_id")
+ )
+
+            if not name and "anthropic" in params.get("_type", ""):
+ name = "claude-2"
+
+ extra = {
+ param: params.get(param)
+ for param in PARAMS_TO_CAPTURE
+ if params.get(param) is not None
+ }
+
+ input = _parse_input(prompts)
+
+ self.__track_event(
+ "llm",
+ "start",
+ user_id=user_id,
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ name=name,
+ input=input,
+ tags=tags,
+ extra=extra,
+ metadata=metadata,
+ user_props=user_props,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ tags: Union[List[str], None] = None,
+ metadata: Union[Dict[str, Any], None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+
+ try:
+ user_id = _get_user_id(metadata)
+ user_props = _get_user_props(metadata)
+
+ params = kwargs.get("invocation_params", {})
+ params.update(
+ serialized.get("kwargs", {})
+ ) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
+
+ name = (
+ params.get("model")
+ or params.get("model_name")
+ or params.get("model_id")
+ )
+
+            if not name and "anthropic" in params.get("_type", ""):
+ name = "claude-2"
+
+ extra = {
+ param: params.get(param)
+ for param in PARAMS_TO_CAPTURE
+ if params.get(param) is not None
+ }
+
+ input = _parse_lc_messages(messages[0])
+
+ self.__track_event(
+ "llm",
+ "start",
+ user_id=user_id,
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ name=name,
+ input=input,
+ tags=tags,
+ extra=extra,
+ metadata=metadata,
+ user_props=user_props,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
+
+ def on_llm_end(
+ self,
+ response: LLMResult,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> None:
+ if self.__has_valid_config is False:
+ return
+
+ try:
+ token_usage = (response.llm_output or {}).get("token_usage", {})
+
+ parsed_output: Any = [
+ _parse_lc_message(generation.message)
+ if hasattr(generation, "message")
+ else generation.text
+ for generation in response.generations[0]
+ ]
+
+ # if it's an array of 1, just parse the first element
+ if len(parsed_output) == 1:
+ parsed_output = parsed_output[0]
+
+ self.__track_event(
+ "llm",
+ "end",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ output=parsed_output,
+ token_usage={
+ "prompt": token_usage.get("prompt_tokens"),
+ "completion": token_usage.get("completion_tokens"),
+ },
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
+
+ def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ tags: Union[List[str], None] = None,
+ metadata: Union[Dict[str, Any], None] = None,
+ **kwargs: Any,
+ ) -> None:
+ if self.__has_valid_config is False:
+ return
+ try:
+ user_id = _get_user_id(metadata)
+ user_props = _get_user_props(metadata)
+ name = serialized.get("name")
+
+ self.__track_event(
+ "tool",
+ "start",
+ user_id=user_id,
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ name=name,
+ input=input_str,
+ tags=tags,
+ metadata=metadata,
+ user_props=user_props,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
+
+ def on_tool_end(
+ self,
+ output: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ tags: Union[List[str], None] = None,
+ **kwargs: Any,
+ ) -> None:
+ if self.__has_valid_config is False:
+ return
+ try:
+ self.__track_event(
+ "tool",
+ "end",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ output=output,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
+
+ def on_chain_start(
+ self,
+ serialized: Dict[str, Any],
+ inputs: Dict[str, Any],
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ tags: Union[List[str], None] = None,
+ metadata: Union[Dict[str, Any], None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ name = serialized.get("id", [None, None, None, None])[3]
+ type = "chain"
+ metadata = metadata or {}
+
+ agentName = metadata.get("agent_name")
+ if agentName is None:
+ agentName = metadata.get("agentName")
+
+ if name == "AgentExecutor" or name == "PlanAndExecute":
+ type = "agent"
+ if agentName is not None:
+ type = "agent"
+ name = agentName
+ if parent_run_id is not None:
+ type = "chain"
+
+ user_id = _get_user_id(metadata)
+ user_props = _get_user_props(metadata)
+ input = _parse_input(inputs)
+
+ self.__track_event(
+ type,
+ "start",
+ user_id=user_id,
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ name=name,
+ input=input,
+ tags=tags,
+ metadata=metadata,
+ user_props=user_props,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
+
+ def on_chain_end(
+ self,
+ outputs: Dict[str, Any],
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ output = _parse_output(outputs)
+
+ self.__track_event(
+ "chain",
+ "end",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ output=output,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
+
+ def on_agent_action(
+ self,
+ action: AgentAction,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ name = action.tool
+ input = _parse_input(action.tool_input)
+
+ self.__track_event(
+ "tool",
+ "start",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ name=name,
+ input=input,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
+
+ def on_agent_finish(
+ self,
+ finish: AgentFinish,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ output = _parse_output(finish.return_values)
+
+ self.__track_event(
+ "agent",
+ "end",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ output=output,
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
+
+ def on_chain_error(
+ self,
+ error: BaseException,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ self.__track_event(
+ "chain",
+ "error",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ error={"message": str(error), "stack": traceback.format_exc()},
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
+
+ def on_tool_error(
+ self,
+ error: BaseException,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ self.__track_event(
+ "tool",
+ "error",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ error={"message": str(error), "stack": traceback.format_exc()},
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
+
+ def on_llm_error(
+ self,
+ error: BaseException,
+ *,
+ run_id: UUID,
+ parent_run_id: Union[UUID, None] = None,
+ **kwargs: Any,
+ ) -> Any:
+ if self.__has_valid_config is False:
+ return
+ try:
+ self.__track_event(
+ "llm",
+ "error",
+ run_id=str(run_id),
+ parent_run_id=str(parent_run_id) if parent_run_id else None,
+ error={"message": str(error), "stack": traceback.format_exc()},
+ app_id=self.__app_id,
+ )
+ except Exception as e:
+ logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
+
+
+__all__ = ["LLMonitorCallbackHandler", "identify"]
diff --git a/libs/community/langchain_community/callbacks/manager.py b/libs/community/langchain_community/callbacks/manager.py
new file mode 100644
index 00000000000..196afed3d0f
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/manager.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import (
+ Generator,
+ Optional,
+)
+
+from langchain_core.tracers.context import register_configure_hook
+
+from langchain_community.callbacks.openai_info import OpenAICallbackHandler
+from langchain_community.callbacks.tracers.wandb import WandbTracer
+
+logger = logging.getLogger(__name__)
+
+openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
+ "openai_callback", default=None
+)
+wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( # noqa: E501
+ "tracing_wandb_callback", default=None
+)
+
+register_configure_hook(openai_callback_var, True)
+register_configure_hook(
+ wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
+)
+
+
+@contextmanager
+def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
+    """Get the OpenAI callback handler in a context manager,
+    which conveniently exposes token and cost information.
+
+ Returns:
+ OpenAICallbackHandler: The OpenAI callback handler.
+
+ Example:
+ >>> with get_openai_callback() as cb:
+ ... # Use the OpenAI callback handler
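+
+        A fuller sketch (illustrative; assumes an OpenAI API key is configured):
+
+        >>> from langchain_community.llms import OpenAI  # doctest: +SKIP
+        >>> llm = OpenAI(temperature=0)  # doctest: +SKIP
+        >>> with get_openai_callback() as cb:  # doctest: +SKIP
+        ...     llm.invoke("Tell me a joke")
+        ...     print(cb.total_tokens, cb.total_cost)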
+ """
+ cb = OpenAICallbackHandler()
+ openai_callback_var.set(cb)
+ yield cb
+ openai_callback_var.set(None)
+
+
+@contextmanager
+def wandb_tracing_enabled(
+ session_name: str = "default",
+) -> Generator[None, None, None]:
+ """Get the WandbTracer in a context manager.
+
+ Args:
+ session_name (str, optional): The name of the session.
+ Defaults to "default".
+
+ Returns:
+ None
+
+ Example:
+ >>> with wandb_tracing_enabled() as session:
+ ... # Use the WandbTracer session
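+
+        A fuller sketch (illustrative; assumes `wandb` is installed and an agent
+        executor `agent` was built elsewhere):
+
+        >>> with wandb_tracing_enabled():  # doctest: +SKIP
+        ...     agent.run("What is 2 raised to the 0.123 power?")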
+ """
+ cb = WandbTracer()
+ wandb_tracing_callback_var.set(cb)
+ yield None
+ wandb_tracing_callback_var.set(None)
diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py
new file mode 100644
index 00000000000..6d93125d564
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/mlflow_callback.py
@@ -0,0 +1,660 @@
+import os
+import random
+import string
+import tempfile
+import traceback
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.callbacks.utils import (
+ BaseMetadataCallbackHandler,
+ flatten_dict,
+ hash_string,
+ import_pandas,
+ import_spacy,
+ import_textstat,
+)
+
+
+def import_mlflow() -> Any:
+ """Import the mlflow python package and raise an error if it is not installed."""
+ try:
+ import mlflow
+ except ImportError:
+ raise ImportError(
+ "To use the mlflow callback manager you need to have the `mlflow` python "
+ "package installed. Please install it with `pip install mlflow>=2.3.0`"
+ )
+ return mlflow
+
+
+def analyze_text(
+ text: str,
+ nlp: Any = None,
+) -> dict:
+ """Analyze text using textstat and spacy.
+
+ Parameters:
+ text (str): The text to analyze.
+ nlp (spacy.lang): The spacy language model to use for visualization.
+
+ Returns:
+ (dict): A dictionary containing the complexity metrics and visualization
+ files serialized to HTML string.
+ """
+ resp: Dict[str, Any] = {}
+ textstat = import_textstat()
+ spacy = import_spacy()
+ text_complexity_metrics = {
+ "flesch_reading_ease": textstat.flesch_reading_ease(text),
+ "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+ "smog_index": textstat.smog_index(text),
+ "coleman_liau_index": textstat.coleman_liau_index(text),
+ "automated_readability_index": textstat.automated_readability_index(text),
+ "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
+ "difficult_words": textstat.difficult_words(text),
+ "linsear_write_formula": textstat.linsear_write_formula(text),
+ "gunning_fog": textstat.gunning_fog(text),
+ # "text_standard": textstat.text_standard(text),
+ "fernandez_huerta": textstat.fernandez_huerta(text),
+ "szigriszt_pazos": textstat.szigriszt_pazos(text),
+ "gutierrez_polini": textstat.gutierrez_polini(text),
+ "crawford": textstat.crawford(text),
+ "gulpease_index": textstat.gulpease_index(text),
+ "osman": textstat.osman(text),
+ }
+ resp.update({"text_complexity_metrics": text_complexity_metrics})
+ resp.update(text_complexity_metrics)
+
+ if nlp is not None:
+ doc = nlp(text)
+
+ dep_out = spacy.displacy.render( # type: ignore
+ doc, style="dep", jupyter=False, page=True
+ )
+
+ ent_out = spacy.displacy.render( # type: ignore
+ doc, style="ent", jupyter=False, page=True
+ )
+
+ text_visualizations = {
+ "dependency_tree": dep_out,
+ "entities": ent_out,
+ }
+
+ resp.update(text_visualizations)
+
+ return resp
+
+
+def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
+ """Construct an html element from a prompt and a generation.
+
+ Parameters:
+ prompt (str): The prompt.
+ generation (str): The generation.
+
+ Returns:
+ (str): The html string."""
+ formatted_prompt = prompt.replace("\n", " ")
+ formatted_generation = generation.replace("\n", " ")
+
+    return f"""
+    <p style="color:black;">{formatted_prompt}:</p>
+    <p style="color:green;">
+    {formatted_generation}
+    </p>
+    """
+
+
+class MlflowLogger:
+    """Helper class for logging metrics and artifacts to an mlflow server.
+
+ Parameters:
+ name (str): Name of the run.
+ experiment (str): Name of the experiment.
+ tags (dict): Tags to be attached for the run.
+ tracking_uri (str): MLflow tracking server uri.
+
+ This handler implements the helper functions to initialize,
+ log metrics and artifacts to the mlflow server.
+ """
+
+ def __init__(self, **kwargs: Any):
+ self.mlflow = import_mlflow()
+ if "DATABRICKS_RUNTIME_VERSION" in os.environ:
+ self.mlflow.set_tracking_uri("databricks")
+ self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
+ self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
+ else:
+ tracking_uri = get_from_dict_or_env(
+ kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
+ )
+ self.mlflow.set_tracking_uri(tracking_uri)
+
+ # User can set other env variables described here
+ # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
+
+ experiment_name = get_from_dict_or_env(
+ kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
+ )
+ self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
+ if self.mlf_exp is not None:
+ self.mlf_expid = self.mlf_exp.experiment_id
+ else:
+ self.mlf_expid = self.mlflow.create_experiment(experiment_name)
+
+ self.start_run(kwargs["run_name"], kwargs["run_tags"])
+
+ def start_run(self, name: str, tags: Dict[str, str]) -> None:
+ """To start a new run, auto generates the random suffix for name"""
+ if name.endswith("-%"):
+ rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
+ name = name.replace("%", rname)
+ self.run = self.mlflow.MlflowClient().create_run(
+ self.mlf_expid, run_name=name, tags=tags
+ )
+
+ def finish_run(self) -> None:
+ """To finish the run."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.end_run()
+
+ def metric(self, key: str, value: float) -> None:
+ """To log metric to mlflow server."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.log_metric(key, value)
+
+ def metrics(
+ self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
+ ) -> None:
+ """To log all metrics in the input dict."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+            self.mlflow.log_metrics(data, step=step)
+
+ def jsonf(self, data: Dict[str, Any], filename: str) -> None:
+ """To log the input data as json file artifact."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.log_dict(data, f"{filename}.json")
+
+ def table(self, name: str, dataframe) -> None: # type: ignore
+ """To log the input pandas dataframe as a html table"""
+ self.html(dataframe.to_html(), f"table_{name}")
+
+ def html(self, html: str, filename: str) -> None:
+ """To log the input html string as html file artifact."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.log_text(html, f"{filename}.html")
+
+ def text(self, text: str, filename: str) -> None:
+ """To log the input text as text file artifact."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.log_text(text, f"{filename}.txt")
+
+ def artifact(self, path: str) -> None:
+ """To upload the file from given path as artifact."""
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.log_artifact(path)
+
+ def langchain_artifact(self, chain: Any) -> None:
+ with self.mlflow.start_run(
+ run_id=self.run.info.run_id, experiment_id=self.mlf_expid
+ ):
+ self.mlflow.langchain.log_model(chain, "langchain-model")
+
+
+class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+ """Callback Handler that logs metrics and artifacts to mlflow server.
+
+ Parameters:
+ name (str): Name of the run.
+ experiment (str): Name of the experiment.
+ tags (dict): Tags to be attached for the run.
+ tracking_uri (str): MLflow tracking server uri.
+
+    For each callback method that fires, this handler formats the callback's
+    input with metadata about the current state of the LLM run, appends the
+    result to both the corresponding {method}_records list and the action
+    records, and logs it to the MLflow server.
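+
+    Example (illustrative sketch; the LLM class, tracking URI, and prompt are
+    placeholders, and the spaCy `en_core_web_sm` model must be installed):
+
+        from langchain_community.llms import OpenAI
+
+        mlflow_callback = MlflowCallbackHandler(
+            name="langchainrun-%",
+            experiment="langchain",
+            tracking_uri="http://localhost:5000",
+        )
+        llm = OpenAI(callbacks=[mlflow_callback])
+        llm.invoke("Tell me a joke")
+        mlflow_callback.flush_tracker(llm, finish=True)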
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = "langchainrun-%",
+ experiment: Optional[str] = "langchain",
+ tags: Optional[Dict] = None,
+ tracking_uri: Optional[str] = None,
+ ) -> None:
+ """Initialize callback handler."""
+ import_pandas()
+ import_textstat()
+ import_mlflow()
+ spacy = import_spacy()
+ super().__init__()
+
+ self.name = name
+ self.experiment = experiment
+ self.tags = tags or {}
+ self.tracking_uri = tracking_uri
+
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ self.mlflg = MlflowLogger(
+ tracking_uri=self.tracking_uri,
+ experiment_name=self.experiment,
+ run_name=self.name,
+ run_tags=self.tags,
+ )
+
+ self.action_records: list = []
+ self.nlp = spacy.load("en_core_web_sm")
+
+ self.metrics = {
+ "step": 0,
+ "starts": 0,
+ "ends": 0,
+ "errors": 0,
+ "text_ctr": 0,
+ "chain_starts": 0,
+ "chain_ends": 0,
+ "llm_starts": 0,
+ "llm_ends": 0,
+ "llm_streams": 0,
+ "tool_starts": 0,
+ "tool_ends": 0,
+ "agent_ends": 0,
+ }
+
+ self.records: Dict[str, Any] = {
+ "on_llm_start_records": [],
+ "on_llm_token_records": [],
+ "on_llm_end_records": [],
+ "on_chain_start_records": [],
+ "on_chain_end_records": [],
+ "on_tool_start_records": [],
+ "on_tool_end_records": [],
+ "on_text_records": [],
+ "on_agent_finish_records": [],
+ "on_agent_action_records": [],
+ "action_records": [],
+ }
+
+ def _reset(self) -> None:
+ for k, v in self.metrics.items():
+ self.metrics[k] = 0
+ for k, v in self.records.items():
+ self.records[k] = []
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ self.metrics["step"] += 1
+ self.metrics["llm_starts"] += 1
+ self.metrics["starts"] += 1
+
+ llm_starts = self.metrics["llm_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ for idx, prompt in enumerate(prompts):
+ prompt_resp = deepcopy(resp)
+ prompt_resp["prompt"] = prompt
+ self.records["on_llm_start_records"].append(prompt_resp)
+ self.records["action_records"].append(prompt_resp)
+ self.mlflg.jsonf(prompt_resp, f"llm_start_{llm_starts}_prompt_{idx}")
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.metrics["step"] += 1
+ self.metrics["llm_streams"] += 1
+
+ llm_streams = self.metrics["llm_streams"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_new_token", "token": token})
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_llm_token_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"llm_new_tokens_{llm_streams}")
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.metrics["step"] += 1
+ self.metrics["llm_ends"] += 1
+ self.metrics["ends"] += 1
+
+ llm_ends = self.metrics["llm_ends"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_end"})
+ resp.update(flatten_dict(response.llm_output or {}))
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ for generations in response.generations:
+ for idx, generation in enumerate(generations):
+ generation_resp = deepcopy(resp)
+ generation_resp.update(flatten_dict(generation.dict()))
+ generation_resp.update(
+ analyze_text(
+ generation.text,
+ nlp=self.nlp,
+ )
+ )
+ complexity_metrics: Dict[str, float] = generation_resp.pop(
+ "text_complexity_metrics"
+ ) # type: ignore # noqa: E501
+ self.mlflg.metrics(
+ complexity_metrics,
+ step=self.metrics["step"],
+ )
+ self.records["on_llm_end_records"].append(generation_resp)
+ self.records["action_records"].append(generation_resp)
+ self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
+ dependency_tree = generation_resp["dependency_tree"]
+ entities = generation_resp["entities"]
+ self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
+ self.mlflg.html(entities, "ent-" + hash_string(generation.text))
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.metrics["step"] += 1
+ self.metrics["chain_starts"] += 1
+ self.metrics["starts"] += 1
+
+ chain_starts = self.metrics["chain_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
+ input_resp = deepcopy(resp)
+ input_resp["inputs"] = chain_input
+ self.records["on_chain_start_records"].append(input_resp)
+ self.records["action_records"].append(input_resp)
+ self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.metrics["step"] += 1
+ self.metrics["chain_ends"] += 1
+ self.metrics["ends"] += 1
+
+ chain_ends = self.metrics["chain_ends"]
+
+ resp: Dict[str, Any] = {}
+ chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
+ resp.update({"action": "on_chain_end", "outputs": chain_output})
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_chain_end_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.metrics["step"] += 1
+ self.metrics["tool_starts"] += 1
+ self.metrics["starts"] += 1
+
+ tool_starts = self.metrics["tool_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_start", "input_str": input_str})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_tool_start_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.metrics["step"] += 1
+ self.metrics["tool_ends"] += 1
+ self.metrics["ends"] += 1
+
+ tool_ends = self.metrics["tool_ends"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_end", "output": output})
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_tool_end_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run when agent is ending.
+ """
+ self.metrics["step"] += 1
+ self.metrics["text_ctr"] += 1
+
+ text_ctr = self.metrics["text_ctr"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_text", "text": text})
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_text_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"on_text_{text_ctr}")
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.metrics["step"] += 1
+ self.metrics["agent_ends"] += 1
+ self.metrics["ends"] += 1
+
+ agent_ends = self.metrics["agent_ends"]
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_finish",
+ "output": finish.return_values["output"],
+ "log": finish.log,
+ }
+ )
+ resp.update(self.metrics)
+
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+
+ self.records["on_agent_finish_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"agent_finish_{agent_ends}")
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.metrics["step"] += 1
+ self.metrics["tool_starts"] += 1
+ self.metrics["starts"] += 1
+
+ tool_starts = self.metrics["tool_starts"]
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ "tool_input": action.tool_input,
+ "log": action.log,
+ }
+ )
+ resp.update(self.metrics)
+ self.mlflg.metrics(self.metrics, step=self.metrics["step"])
+ self.records["on_agent_action_records"].append(resp)
+ self.records["action_records"].append(resp)
+ self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
+
+ def _create_session_analysis_df(self) -> Any:
+ """Create a dataframe with all the information from the session."""
+ pd = import_pandas()
+ on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
+ on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
+
+ llm_input_columns = ["step", "prompt"]
+ if "name" in on_llm_start_records_df.columns:
+ llm_input_columns.append("name")
+ elif "id" in on_llm_start_records_df.columns:
+ # id is llm class's full import path. For example:
+ # ["langchain", "llms", "openai", "AzureOpenAI"]
+ on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
+ lambda id_: id_[-1]
+ )
+ llm_input_columns.append("name")
+ llm_input_prompts_df = (
+ on_llm_start_records_df[llm_input_columns]
+ .dropna(axis=1)
+ .rename({"step": "prompt_step"}, axis=1)
+ )
+ complexity_metrics_columns = []
+ visualizations_columns = []
+
+ complexity_metrics_columns = [
+ "flesch_reading_ease",
+ "flesch_kincaid_grade",
+ "smog_index",
+ "coleman_liau_index",
+ "automated_readability_index",
+ "dale_chall_readability_score",
+ "difficult_words",
+ "linsear_write_formula",
+ "gunning_fog",
+ # "text_standard",
+ "fernandez_huerta",
+ "szigriszt_pazos",
+ "gutierrez_polini",
+ "crawford",
+ "gulpease_index",
+ "osman",
+ ]
+
+ visualizations_columns = ["dependency_tree", "entities"]
+
+ llm_outputs_df = (
+ on_llm_end_records_df[
+ [
+ "step",
+ "text",
+ "token_usage_total_tokens",
+ "token_usage_prompt_tokens",
+ "token_usage_completion_tokens",
+ ]
+ + complexity_metrics_columns
+ + visualizations_columns
+ ]
+ .dropna(axis=1)
+ .rename({"step": "output_step", "text": "output"}, axis=1)
+ )
+ session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
+ session_analysis_df["chat_html"] = session_analysis_df[
+ ["prompt", "output"]
+ ].apply(
+ lambda row: construct_html_from_prompt_and_generation(
+ row["prompt"], row["output"]
+ ),
+ axis=1,
+ )
+ return session_analysis_df
+
+    def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
+        """Log the accumulated records and the session analysis table, optionally
+        log the LangChain asset, and finish the run when `finish` is True."""
+        pd = import_pandas()
+        self.mlflg.table(
+            "action_records", pd.DataFrame(self.records["action_records"])
+        )
+ session_analysis_df = self._create_session_analysis_df()
+ chat_html = session_analysis_df.pop("chat_html")
+ chat_html = chat_html.replace("\n", "", regex=True)
+ self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
+ self.mlflg.html("".join(chat_html.tolist()), "chat_html")
+
+ if langchain_asset:
+ # To avoid circular import error
+ # mlflow only supports LLMChain asset
+ if "langchain.chains.llm.LLMChain" in str(type(langchain_asset)):
+ self.mlflg.langchain_artifact(langchain_asset)
+ else:
+ langchain_asset_path = str(Path(self.temp_dir.name, "model.json"))
+ try:
+ langchain_asset.save(langchain_asset_path)
+ self.mlflg.artifact(langchain_asset_path)
+ except ValueError:
+ try:
+ langchain_asset.save_agent(langchain_asset_path)
+ self.mlflg.artifact(langchain_asset_path)
+ except AttributeError:
+ print("Could not save model.")
+ traceback.print_exc()
+ pass
+ except NotImplementedError:
+ print("Could not save model.")
+ traceback.print_exc()
+ pass
+ except NotImplementedError:
+ print("Could not save model.")
+ traceback.print_exc()
+ pass
+ if finish:
+ self.mlflg.finish_run()
+ self._reset()
diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py
new file mode 100644
index 00000000000..bf0c59b746e
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/openai_info.py
@@ -0,0 +1,208 @@
+"""Callback Handler that prints to std out."""
+from typing import Any, Dict, List
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+MODEL_COST_PER_1K_TOKENS = {
+ # GPT-4 input
+ "gpt-4": 0.03,
+ "gpt-4-0314": 0.03,
+ "gpt-4-0613": 0.03,
+ "gpt-4-32k": 0.06,
+ "gpt-4-32k-0314": 0.06,
+ "gpt-4-32k-0613": 0.06,
+ "gpt-4-vision-preview": 0.01,
+ "gpt-4-1106-preview": 0.01,
+ # GPT-4 output
+ "gpt-4-completion": 0.06,
+ "gpt-4-0314-completion": 0.06,
+ "gpt-4-0613-completion": 0.06,
+ "gpt-4-32k-completion": 0.12,
+ "gpt-4-32k-0314-completion": 0.12,
+ "gpt-4-32k-0613-completion": 0.12,
+ "gpt-4-vision-preview-completion": 0.03,
+ "gpt-4-1106-preview-completion": 0.03,
+ # GPT-3.5 input
+ "gpt-3.5-turbo": 0.0015,
+ "gpt-3.5-turbo-0301": 0.0015,
+ "gpt-3.5-turbo-0613": 0.0015,
+ "gpt-3.5-turbo-1106": 0.001,
+ "gpt-3.5-turbo-instruct": 0.0015,
+ "gpt-3.5-turbo-16k": 0.003,
+ "gpt-3.5-turbo-16k-0613": 0.003,
+ # GPT-3.5 output
+ "gpt-3.5-turbo-completion": 0.002,
+ "gpt-3.5-turbo-0301-completion": 0.002,
+ "gpt-3.5-turbo-0613-completion": 0.002,
+ "gpt-3.5-turbo-1106-completion": 0.002,
+ "gpt-3.5-turbo-instruct-completion": 0.002,
+ "gpt-3.5-turbo-16k-completion": 0.004,
+ "gpt-3.5-turbo-16k-0613-completion": 0.004,
+ # Azure GPT-35 input
+ "gpt-35-turbo": 0.0015, # Azure OpenAI version of ChatGPT
+ "gpt-35-turbo-0301": 0.0015, # Azure OpenAI version of ChatGPT
+ "gpt-35-turbo-0613": 0.0015,
+ "gpt-35-turbo-instruct": 0.0015,
+ "gpt-35-turbo-16k": 0.003,
+ "gpt-35-turbo-16k-0613": 0.003,
+ # Azure GPT-35 output
+ "gpt-35-turbo-completion": 0.002, # Azure OpenAI version of ChatGPT
+ "gpt-35-turbo-0301-completion": 0.002, # Azure OpenAI version of ChatGPT
+ "gpt-35-turbo-0613-completion": 0.002,
+ "gpt-35-turbo-instruct-completion": 0.002,
+ "gpt-35-turbo-16k-completion": 0.004,
+ "gpt-35-turbo-16k-0613-completion": 0.004,
+ # Others
+ "text-ada-001": 0.0004,
+ "ada": 0.0004,
+ "text-babbage-001": 0.0005,
+ "babbage": 0.0005,
+ "text-curie-001": 0.002,
+ "curie": 0.002,
+ "text-davinci-003": 0.02,
+ "text-davinci-002": 0.02,
+ "code-davinci-002": 0.02,
+ # Fine Tuned input
+ "babbage-002-finetuned": 0.0016,
+ "davinci-002-finetuned": 0.012,
+ "gpt-3.5-turbo-0613-finetuned": 0.012,
+ # Fine Tuned output
+ "babbage-002-finetuned-completion": 0.0016,
+ "davinci-002-finetuned-completion": 0.012,
+ "gpt-3.5-turbo-0613-finetuned-completion": 0.016,
+ # Azure Fine Tuned input
+ "babbage-002-azure-finetuned": 0.0004,
+ "davinci-002-azure-finetuned": 0.002,
+ "gpt-35-turbo-0613-azure-finetuned": 0.0015,
+ # Azure Fine Tuned output
+ "babbage-002-azure-finetuned-completion": 0.0004,
+ "davinci-002-azure-finetuned-completion": 0.002,
+ "gpt-35-turbo-0613-azure-finetuned-completion": 0.002,
+ # Legacy fine-tuned models
+ "ada-finetuned-legacy": 0.0016,
+ "babbage-finetuned-legacy": 0.0024,
+ "curie-finetuned-legacy": 0.012,
+ "davinci-finetuned-legacy": 0.12,
+}
+
+
+def standardize_model_name(
+ model_name: str,
+ is_completion: bool = False,
+) -> str:
+ """
+    Standardize the model name to the format used as keys in MODEL_COST_PER_1K_TOKENS.
+
+ Args:
+ model_name: Model name to standardize.
+ is_completion: Whether the model is used for completion or not.
+ Defaults to False.
+
+ Returns:
+ Standardized model name.
+
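+    Example (illustrative; shows how the suffix rules above are applied):
+
+        standardize_model_name("gpt-4")                      # "gpt-4"
+        standardize_model_name("gpt-4", is_completion=True)  # "gpt-4-completion"
+        standardize_model_name("ft:gpt-3.5-turbo-0613:org::id")
+        # "gpt-3.5-turbo-0613-finetuned"
+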
+ """
+ model_name = model_name.lower()
+ if ".ft-" in model_name:
+ model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
+ if ":ft-" in model_name:
+ model_name = model_name.split(":")[0] + "-finetuned-legacy"
+ if "ft:" in model_name:
+ model_name = model_name.split(":")[1] + "-finetuned"
+ if is_completion and (
+ model_name.startswith("gpt-4")
+ or model_name.startswith("gpt-3.5")
+ or model_name.startswith("gpt-35")
+ or ("finetuned" in model_name and "legacy" not in model_name)
+ ):
+ return model_name + "-completion"
+ else:
+ return model_name
+
+
+def get_openai_token_cost_for_model(
+ model_name: str, num_tokens: int, is_completion: bool = False
+) -> float:
+ """
+ Get the cost in USD for a given model and number of tokens.
+
+ Args:
+ model_name: Name of the model
+ num_tokens: Number of tokens.
+ is_completion: Whether the model is used for completion or not.
+ Defaults to False.
+
+ Returns:
+ Cost in USD.
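+
+    Example (illustrative; dollar values follow MODEL_COST_PER_1K_TOKENS and may
+    change as prices are updated):
+
+        get_openai_token_cost_for_model("gpt-4", 1000)                      # 0.03
+        get_openai_token_cost_for_model("gpt-4", 1000, is_completion=True)  # 0.06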
+ """
+ model_name = standardize_model_name(model_name, is_completion=is_completion)
+ if model_name not in MODEL_COST_PER_1K_TOKENS:
+ raise ValueError(
+ f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
+ "Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
+ )
+ return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
+
+
+class OpenAICallbackHandler(BaseCallbackHandler):
+ """Callback Handler that tracks OpenAI info."""
+
+ total_tokens: int = 0
+ prompt_tokens: int = 0
+ completion_tokens: int = 0
+ successful_requests: int = 0
+ total_cost: float = 0.0
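+
+    # Illustrative usage sketch (model and handler wiring are placeholders; the
+    # handler only accumulates usage when it receives `on_llm_end` callbacks):
+    #
+    #     from langchain_community.chat_models import ChatOpenAI
+    #
+    #     cb = OpenAICallbackHandler()
+    #     llm = ChatOpenAI(callbacks=[cb])
+    #     llm.invoke("Hi")
+    #     print(cb)  # token counts and total cost so far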
+
+ def __repr__(self) -> str:
+ return (
+ f"Tokens Used: {self.total_tokens}\n"
+ f"\tPrompt Tokens: {self.prompt_tokens}\n"
+ f"\tCompletion Tokens: {self.completion_tokens}\n"
+ f"Successful Requests: {self.successful_requests}\n"
+ f"Total Cost (USD): ${self.total_cost}"
+ )
+
+ @property
+ def always_verbose(self) -> bool:
+ """Whether to call verbose callbacks even if verbose is False."""
+ return True
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Print out the prompts."""
+ pass
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Print out the token."""
+ pass
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Collect token usage."""
+ if response.llm_output is None:
+ return None
+ self.successful_requests += 1
+ if "token_usage" not in response.llm_output:
+ return None
+ token_usage = response.llm_output["token_usage"]
+ completion_tokens = token_usage.get("completion_tokens", 0)
+ prompt_tokens = token_usage.get("prompt_tokens", 0)
+ model_name = standardize_model_name(response.llm_output.get("model_name", ""))
+ if model_name in MODEL_COST_PER_1K_TOKENS:
+ completion_cost = get_openai_token_cost_for_model(
+ model_name, completion_tokens, is_completion=True
+ )
+ prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
+ self.total_cost += prompt_cost + completion_cost
+ self.total_tokens += token_usage.get("total_tokens", 0)
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ def __copy__(self) -> "OpenAICallbackHandler":
+ """Return a copy of the callback handler."""
+ return self
+
+ def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
+ """Return a deep copy of the callback handler."""
+ return self
diff --git a/libs/community/langchain_community/callbacks/promptlayer_callback.py b/libs/community/langchain_community/callbacks/promptlayer_callback.py
new file mode 100644
index 00000000000..f9431681246
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/promptlayer_callback.py
@@ -0,0 +1,162 @@
+"""Callback handler for promptlayer."""
+from __future__ import annotations
+
+import datetime
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from uuid import UUID
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ LLMResult,
+)
+
+if TYPE_CHECKING:
+ import promptlayer
+
+
+def _lazy_import_promptlayer() -> promptlayer:
+ """Lazy import promptlayer to avoid circular imports."""
+ try:
+ import promptlayer
+ except ImportError:
+ raise ImportError(
+ "The PromptLayerCallbackHandler requires the promptlayer package. "
+ " Please install it with `pip install promptlayer`."
+ )
+ return promptlayer
+
+
+class PromptLayerCallbackHandler(BaseCallbackHandler):
+ """Callback handler for promptlayer."""
+
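+    # Illustrative usage sketch (assumes a PromptLayer API key is configured,
+    # e.g. via the PROMPTLAYER_API_KEY environment variable):
+    #
+    #     from langchain_community.chat_models import ChatOpenAI
+    #
+    #     handler = PromptLayerCallbackHandler(pl_tags=["langchain"])
+    #     llm = ChatOpenAI(callbacks=[handler])
+    #     llm.invoke("What is the capital of France?")
+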
+ def __init__(
+ self,
+ pl_id_callback: Optional[Callable[..., Any]] = None,
+ pl_tags: Optional[List[str]] = None,
+ ) -> None:
+ """Initialize the PromptLayerCallbackHandler."""
+ _lazy_import_promptlayer()
+ self.pl_id_callback = pl_id_callback
+ self.pl_tags = pl_tags or []
+ self.runs: Dict[UUID, Dict[str, Any]] = {}
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Any:
+ self.runs[run_id] = {
+ "messages": [self._create_message_dicts(m)[0] for m in messages],
+ "invocation_params": kwargs.get("invocation_params", {}),
+ "name": ".".join(serialized["id"]),
+ "request_start_time": datetime.datetime.now().timestamp(),
+ "tags": tags,
+ }
+
+ def on_llm_start(
+ self,
+ serialized: Dict[str, Any],
+ prompts: List[str],
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Any:
+ self.runs[run_id] = {
+ "prompts": prompts,
+ "invocation_params": kwargs.get("invocation_params", {}),
+ "name": ".".join(serialized["id"]),
+ "request_start_time": datetime.datetime.now().timestamp(),
+ "tags": tags,
+ }
+
+ def on_llm_end(
+ self,
+ response: LLMResult,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ **kwargs: Any,
+ ) -> None:
+ from promptlayer.utils import get_api_key, promptlayer_api_request
+
+ run_info = self.runs.get(run_id, {})
+ if not run_info:
+ return
+ run_info["request_end_time"] = datetime.datetime.now().timestamp()
+ for i in range(len(response.generations)):
+ generation = response.generations[i][0]
+
+ resp = {
+ "text": generation.text,
+ "llm_output": response.llm_output,
+ }
+ model_params = run_info.get("invocation_params", {})
+ is_chat_model = run_info.get("messages", None) is not None
+ model_input = (
+ run_info.get("messages", [])[i]
+ if is_chat_model
+ else [run_info.get("prompts", [])[i]]
+ )
+ model_response = (
+ [self._convert_message_to_dict(generation.message)]
+ if is_chat_model and isinstance(generation, ChatGeneration)
+ else resp
+ )
+
+ pl_request_id = promptlayer_api_request(
+ run_info.get("name"),
+ "langchain",
+ model_input,
+ model_params,
+ self.pl_tags,
+ model_response,
+ run_info.get("request_start_time"),
+ run_info.get("request_end_time"),
+ get_api_key(),
+ return_pl_id=bool(self.pl_id_callback is not None),
+ metadata={
+ "_langchain_run_id": str(run_id),
+ "_langchain_parent_run_id": str(parent_run_id),
+ "_langchain_tags": str(run_info.get("tags", [])),
+ },
+ )
+
+ if self.pl_id_callback:
+ self.pl_id_callback(pl_request_id)
+
+ def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
+ if isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params: Dict[str, Any] = {}
+ message_dicts = [self._convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
diff --git a/libs/community/langchain_community/callbacks/sagemaker_callback.py b/libs/community/langchain_community/callbacks/sagemaker_callback.py
new file mode 100644
index 00000000000..b791425ff00
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/sagemaker_callback.py
@@ -0,0 +1,276 @@
+import json
+import os
+import shutil
+import tempfile
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.utils import (
+ flatten_dict,
+)
+
+
+def save_json(data: dict, file_path: str) -> None:
+ """Save dict to local file path.
+
+ Parameters:
+ data (dict): The dictionary to be saved.
+ file_path (str): Local file path.
+ """
+ with open(file_path, "w") as outfile:
+ json.dump(data, outfile)
+
+
+class SageMakerCallbackHandler(BaseCallbackHandler):
+ """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
+
+ Parameters:
+ run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
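+
+    Example (illustrative sketch; experiment and run names are placeholders and
+    assume a configured SageMaker session):
+
+        from sagemaker.experiments.run import Run
+
+        with Run(experiment_name="langchain", run_name="llm-run") as run:
+            handler = SageMakerCallbackHandler(run)
+            # pass `handler` via callbacks=[handler] to an LLM or chain, then:
+            handler.flush_tracker()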
+ """
+
+ def __init__(self, run: Any) -> None:
+ """Initialize callback handler."""
+ super().__init__()
+
+ self.run = run
+
+ self.metrics = {
+ "step": 0,
+ "starts": 0,
+ "ends": 0,
+ "errors": 0,
+ "text_ctr": 0,
+ "chain_starts": 0,
+ "chain_ends": 0,
+ "llm_starts": 0,
+ "llm_ends": 0,
+ "llm_streams": 0,
+ "tool_starts": 0,
+ "tool_ends": 0,
+ "agent_ends": 0,
+ }
+
+ # Create a temporary directory
+ self.temp_dir = tempfile.mkdtemp()
+
+ def _reset(self) -> None:
+ for k, v in self.metrics.items():
+ self.metrics[k] = 0
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ self.metrics["step"] += 1
+ self.metrics["llm_starts"] += 1
+ self.metrics["starts"] += 1
+
+ llm_starts = self.metrics["llm_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ for idx, prompt in enumerate(prompts):
+ prompt_resp = deepcopy(resp)
+ prompt_resp["prompt"] = prompt
+ self.jsonf(
+ prompt_resp,
+ self.temp_dir,
+ f"llm_start_{llm_starts}_prompt_{idx}",
+ )
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.metrics["step"] += 1
+ self.metrics["llm_streams"] += 1
+
+ llm_streams = self.metrics["llm_streams"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_new_token", "token": token})
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.metrics["step"] += 1
+ self.metrics["llm_ends"] += 1
+ self.metrics["ends"] += 1
+
+ llm_ends = self.metrics["llm_ends"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_llm_end"})
+ resp.update(flatten_dict(response.llm_output or {}))
+
+ resp.update(self.metrics)
+
+ for generations in response.generations:
+ for idx, generation in enumerate(generations):
+ generation_resp = deepcopy(resp)
+ generation_resp.update(flatten_dict(generation.dict()))
+
+ self.jsonf(
+                    generation_resp,
+ self.temp_dir,
+ f"llm_end_{llm_ends}_generation_{idx}",
+ )
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.metrics["step"] += 1
+ self.metrics["chain_starts"] += 1
+ self.metrics["starts"] += 1
+
+ chain_starts = self.metrics["chain_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
+ input_resp = deepcopy(resp)
+ input_resp["inputs"] = chain_input
+
+ self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.metrics["step"] += 1
+ self.metrics["chain_ends"] += 1
+ self.metrics["ends"] += 1
+
+ chain_ends = self.metrics["chain_ends"]
+
+ resp: Dict[str, Any] = {}
+ chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
+ resp.update({"action": "on_chain_end", "outputs": chain_output})
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.metrics["step"] += 1
+ self.metrics["tool_starts"] += 1
+ self.metrics["starts"] += 1
+
+ tool_starts = self.metrics["tool_starts"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_start", "input_str": input_str})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.metrics["step"] += 1
+ self.metrics["tool_ends"] += 1
+ self.metrics["ends"] += 1
+
+ tool_ends = self.metrics["tool_ends"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_tool_end", "output": output})
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.metrics["step"] += 1
+ self.metrics["errors"] += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run when agent is ending.
+ """
+ self.metrics["step"] += 1
+ self.metrics["text_ctr"] += 1
+
+ text_ctr = self.metrics["text_ctr"]
+
+ resp: Dict[str, Any] = {}
+ resp.update({"action": "on_text", "text": text})
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.metrics["step"] += 1
+ self.metrics["agent_ends"] += 1
+ self.metrics["ends"] += 1
+
+ agent_ends = self.metrics["agent_ends"]
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_finish",
+ "output": finish.return_values["output"],
+ "log": finish.log,
+ }
+ )
+ resp.update(self.metrics)
+
+ self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.metrics["step"] += 1
+ self.metrics["tool_starts"] += 1
+ self.metrics["starts"] += 1
+
+ tool_starts = self.metrics["tool_starts"]
+ resp: Dict[str, Any] = {}
+ resp.update(
+ {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ "tool_input": action.tool_input,
+ "log": action.log,
+ }
+ )
+ resp.update(self.metrics)
+ self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
+
+ def jsonf(
+ self,
+ data: Dict[str, Any],
+ data_dir: str,
+ filename: str,
+ is_output: Optional[bool] = True,
+ ) -> None:
+ """To log the input data as json file artifact."""
+ file_path = os.path.join(data_dir, f"{filename}.json")
+ save_json(data, file_path)
+ self.run.log_file(file_path, name=filename, is_output=is_output)
+
+ def flush_tracker(self) -> None:
+ """Reset the steps and delete the temporary local directory."""
+ self._reset()
+ shutil.rmtree(self.temp_dir)
diff --git a/libs/community/langchain_community/callbacks/streamlit/__init__.py b/libs/community/langchain_community/callbacks/streamlit/__init__.py
new file mode 100644
index 00000000000..7a0fadb059d
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/streamlit/__init__.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+from langchain_core.callbacks import BaseCallbackHandler
+
+from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
+ LLMThoughtLabeler as LLMThoughtLabeler,
+)
+from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
+ StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
+)
+
+if TYPE_CHECKING:
+ from streamlit.delta_generator import DeltaGenerator
+
+
+def StreamlitCallbackHandler(
+ parent_container: DeltaGenerator,
+ *,
+ max_thought_containers: int = 4,
+ expand_new_thoughts: bool = True,
+ collapse_completed_thoughts: bool = True,
+ thought_labeler: Optional[LLMThoughtLabeler] = None,
+) -> BaseCallbackHandler:
+ """Callback Handler that writes to a Streamlit app.
+
+ This CallbackHandler is geared towards
+ use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
+ inside a series of Streamlit expanders.
+
+ Parameters
+ ----------
+ parent_container
+ The `st.container` that will contain all the Streamlit elements that the
+ Handler creates.
+ max_thought_containers
+ The max number of completed LLM thought containers to show at once. When this
+ threshold is reached, a new thought will cause the oldest thoughts to be
+ collapsed into a "History" expander. Defaults to 4.
+ expand_new_thoughts
+ Each LLM "thought" gets its own `st.expander`. This param controls whether that
+ expander is expanded by default. Defaults to True.
+ collapse_completed_thoughts
+ If True, LLM thought expanders will be collapsed when completed.
+ Defaults to True.
+ thought_labeler
+ An optional custom LLMThoughtLabeler instance. If unspecified, the handler
+ will use the default thought labeling logic. Defaults to None.
+
+ Returns
+ -------
+ A new StreamlitCallbackHandler instance.
+
+ Note that this is an "auto-updating" API: if the installed version of Streamlit
+ has a more recent StreamlitCallbackHandler implementation, an instance of that class
+ will be used.
+
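+    Example (illustrative sketch; the agent setup is a placeholder):
+
+        import streamlit as st
+
+        st_callback = StreamlitCallbackHandler(st.container())
+        # response = agent_executor.run(prompt, callbacks=[st_callback])
+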
+ """
+ # If we're using a version of Streamlit that implements StreamlitCallbackHandler,
+ # delegate to it instead of using our built-in handler. The official handler is
+ # guaranteed to support the same set of kwargs.
+ try:
+ from streamlit.external.langchain import (
+ StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501
+ )
+
+ return OfficialStreamlitCallbackHandler(
+ parent_container,
+ max_thought_containers=max_thought_containers,
+ expand_new_thoughts=expand_new_thoughts,
+ collapse_completed_thoughts=collapse_completed_thoughts,
+ thought_labeler=thought_labeler,
+ )
+ except ImportError:
+ return _InternalStreamlitCallbackHandler(
+ parent_container,
+ max_thought_containers=max_thought_containers,
+ expand_new_thoughts=expand_new_thoughts,
+ collapse_completed_thoughts=collapse_completed_thoughts,
+ thought_labeler=thought_labeler,
+ )
diff --git a/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py b/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py
new file mode 100644
index 00000000000..7de1e9873fe
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
+
+if TYPE_CHECKING:
+ from streamlit.delta_generator import DeltaGenerator
+ from streamlit.type_util import SupportsStr
+
+
+class ChildType(Enum):
+ """The enumerator of the child type."""
+
+ MARKDOWN = "MARKDOWN"
+ EXCEPTION = "EXCEPTION"
+
+
+class ChildRecord(NamedTuple):
+ """The child record as a NamedTuple."""
+
+ type: ChildType
+ kwargs: Dict[str, Any]
+ dg: DeltaGenerator
+
+
+class MutableExpander:
+ """A Streamlit expander that can be renamed and dynamically expanded/collapsed."""
+
+ def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
+ """Create a new MutableExpander.
+
+ Parameters
+ ----------
+ parent_container
+ The `st.container` that the expander will be created inside.
+
+ The expander transparently deletes and recreates its underlying
+ `st.expander` instance when its label changes, and it uses
+ `parent_container` to ensure it recreates this underlying expander in the
+ same location onscreen.
+ label
+ The expander's initial label.
+ expanded
+ The expander's initial `expanded` value.
+ """
+ self._label = label
+ self._expanded = expanded
+ self._parent_cursor = parent_container.empty()
+ self._container = self._parent_cursor.expander(label, expanded)
+ self._child_records: List[ChildRecord] = []
+
+ @property
+ def label(self) -> str:
+ """The expander's label string."""
+ return self._label
+
+ @property
+ def expanded(self) -> bool:
+ """True if the expander was created with `expanded=True`."""
+ return self._expanded
+
+ def clear(self) -> None:
+ """Remove the container and its contents entirely. A cleared container can't
+ be reused.
+ """
+ self._container = self._parent_cursor.empty()
+ self._child_records.clear()
+
+ def append_copy(self, other: MutableExpander) -> None:
+ """Append a copy of another MutableExpander's children to this
+ MutableExpander.
+ """
+ other_records = other._child_records.copy()
+ for record in other_records:
+ self._create_child(record.type, record.kwargs)
+
+ def update(
+ self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
+ ) -> None:
+ """Change the expander's label and expanded state"""
+ if new_label is None:
+ new_label = self._label
+ if new_expanded is None:
+ new_expanded = self._expanded
+
+ if self._label == new_label and self._expanded == new_expanded:
+ # No change!
+ return
+
+ self._label = new_label
+ self._expanded = new_expanded
+ self._container = self._parent_cursor.expander(new_label, new_expanded)
+
+ prev_records = self._child_records
+ self._child_records = []
+
+ # Replay all children into the new container
+ for record in prev_records:
+ self._create_child(record.type, record.kwargs)
+
+ def markdown(
+ self,
+ body: SupportsStr,
+ unsafe_allow_html: bool = False,
+ *,
+ help: Optional[str] = None,
+ index: Optional[int] = None,
+ ) -> int:
+ """Add a Markdown element to the container and return its index."""
+ kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
+ new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
+ record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
+ return self._add_record(record, index)
+
+ def exception(
+ self, exception: BaseException, *, index: Optional[int] = None
+ ) -> int:
+ """Add an Exception element to the container and return its index."""
+ kwargs = {"exception": exception}
+ new_dg = self._get_dg(index).exception(**kwargs)
+ record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
+ return self._add_record(record, index)
+
+ def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
+ """Create a new child with the given params"""
+ if type == ChildType.MARKDOWN:
+ self.markdown(**kwargs)
+ elif type == ChildType.EXCEPTION:
+ self.exception(**kwargs)
+ else:
+ raise RuntimeError(f"Unexpected child type {type}")
+
+ def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
+ """Add a ChildRecord to self._children. If `index` is specified, replace
+ the existing record at that index. Otherwise, append the record to the
+ end of the list.
+
+ Return the index of the added record.
+ """
+ if index is not None:
+ # Replace existing child
+ self._child_records[index] = record
+ return index
+
+ # Append new child
+ self._child_records.append(record)
+ return len(self._child_records) - 1
+
+ def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
+ if index is not None:
+ # Existing index: reuse child's DeltaGenerator
+ assert 0 <= index < len(self._child_records), f"Bad index: {index}"
+ return self._child_records[index].dg
+
+ # No index: use container's DeltaGenerator
+ return self._container
diff --git a/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py b/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py
new file mode 100644
index 00000000000..b336a09a3a6
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py
@@ -0,0 +1,414 @@
+"""Callback Handler that prints to streamlit."""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
+
+if TYPE_CHECKING:
+ from streamlit.delta_generator import DeltaGenerator
+
+
+def _convert_newlines(text: str) -> str:
+ """Convert newline characters to markdown newline sequences
+ (space, space, newline).
+ """
+ return text.replace("\n", " \n")
+
+
+CHECKMARK_EMOJI = "✅"
+THINKING_EMOJI = ":thinking_face:"
+HISTORY_EMOJI = ":books:"
+EXCEPTION_EMOJI = "⚠️"
+
+
+class LLMThoughtState(Enum):
+ """Enumerator of the LLMThought state."""
+
+ # The LLM is thinking about what to do next. We don't know which tool we'll run.
+ THINKING = "THINKING"
+ # The LLM has decided to run a tool. We don't have results from the tool yet.
+ RUNNING_TOOL = "RUNNING_TOOL"
+ # We have results from the tool.
+ COMPLETE = "COMPLETE"
+
+
+class ToolRecord(NamedTuple):
+ """The tool record as a NamedTuple."""
+
+ name: str
+ input_str: str
+
+
+class LLMThoughtLabeler:
+ """
+ Generates markdown labels for LLMThought containers. Pass a custom
+ subclass of this to StreamlitCallbackHandler to override its default
+ labeling logic.
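+
+    Example (illustrative sketch of a custom subclass):
+
+        class EmojiFreeLabeler(LLMThoughtLabeler):
+            def get_initial_label(self) -> str:
+                return "Thinking..."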
+ """
+
+ def get_initial_label(self) -> str:
+ """Return the markdown label for a new LLMThought that doesn't have
+ an associated tool yet.
+ """
+ return f"{THINKING_EMOJI} **Thinking...**"
+
+ def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
+ """Return the label for an LLMThought that has an associated
+ tool.
+
+ Parameters
+ ----------
+ tool
+ The tool's ToolRecord
+
+ is_complete
+ True if the thought is complete; False if the thought
+ is still receiving input.
+
+ Returns
+ -------
+ The markdown label for the thought's container.
+
+ """
+ input = tool.input_str
+ name = tool.name
+ emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
+ if name == "_Exception":
+ emoji = EXCEPTION_EMOJI
+ name = "Parsing error"
+ idx = min([60, len(input)])
+ input = input[0:idx]
+ if len(tool.input_str) > idx:
+ input = input + "..."
+ input = input.replace("\n", " ")
+ label = f"{emoji} **{name}:** {input}"
+ return label
+
+ def get_history_label(self) -> str:
+ """Return a markdown label for the special 'history' container
+ that contains overflow thoughts.
+ """
+ return f"{HISTORY_EMOJI} **History**"
+
+ def get_final_agent_thought_label(self) -> str:
+ """Return the markdown label for the agent's final thought -
+ the "Now I have the answer" thought, that doesn't involve
+ a tool.
+ """
+ return f"{CHECKMARK_EMOJI} **Complete!**"
+
+
+class LLMThought:
+ """A thought in the LLM's thought stream."""
+
+ def __init__(
+ self,
+ parent_container: DeltaGenerator,
+ labeler: LLMThoughtLabeler,
+ expanded: bool,
+ collapse_on_complete: bool,
+ ):
+ """Initialize the LLMThought.
+
+ Args:
+ parent_container: The container we're writing into.
+ labeler: The labeler to use for this thought.
+ expanded: Whether the thought should be expanded by default.
+ collapse_on_complete: Whether the thought should be collapsed.
+ """
+ self._container = MutableExpander(
+ parent_container=parent_container,
+ label=labeler.get_initial_label(),
+ expanded=expanded,
+ )
+ self._state = LLMThoughtState.THINKING
+ self._llm_token_stream = ""
+ self._llm_token_writer_idx: Optional[int] = None
+ self._last_tool: Optional[ToolRecord] = None
+ self._collapse_on_complete = collapse_on_complete
+ self._labeler = labeler
+
+ @property
+ def container(self) -> MutableExpander:
+ """The container we're writing into."""
+ return self._container
+
+ @property
+ def last_tool(self) -> Optional[ToolRecord]:
+ """The last tool executed by this thought"""
+ return self._last_tool
+
+ def _reset_llm_token_stream(self) -> None:
+ self._llm_token_stream = ""
+ self._llm_token_writer_idx = None
+
+ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
+ self._reset_llm_token_stream()
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ # This is only called when the LLM is initialized with `streaming=True`
+ self._llm_token_stream += _convert_newlines(token)
+ self._llm_token_writer_idx = self._container.markdown(
+ self._llm_token_stream, index=self._llm_token_writer_idx
+ )
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ # `response` is the concatenation of all the tokens received by the LLM.
+ # If we're receiving streaming tokens from `on_llm_new_token`, this response
+ # data is redundant
+ self._reset_llm_token_stream()
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ self._container.markdown("**LLM encountered an error...**")
+ self._container.exception(error)
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ # Called with the name of the tool we're about to run (in `serialized[name]`),
+ # and its input. We change our container's label to be the tool name.
+ self._state = LLMThoughtState.RUNNING_TOOL
+ tool_name = serialized["name"]
+ self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
+ self._container.update(
+ new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
+ )
+
+ def on_tool_end(
+ self,
+ output: str,
+ color: Optional[str] = None,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._container.markdown(f"**{output}**")
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ self._container.markdown("**Tool encountered an error...**")
+ self._container.exception(error)
+
+ def on_agent_action(
+ self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
+ ) -> Any:
+ # Called when we're about to kick off a new tool. The `action` data
+ # tells us the tool we're about to use, and the input we'll give it.
+ # We don't output anything here, because we'll receive this same data
+ # when `on_tool_start` is called immediately after.
+ pass
+
+ def complete(self, final_label: Optional[str] = None) -> None:
+ """Finish the thought."""
+ if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
+ assert (
+ self._last_tool is not None
+ ), "_last_tool should never be null when _state == RUNNING_TOOL"
+ final_label = self._labeler.get_tool_label(
+ self._last_tool, is_complete=True
+ )
+ self._state = LLMThoughtState.COMPLETE
+ if self._collapse_on_complete:
+ self._container.update(new_label=final_label, new_expanded=False)
+ else:
+ self._container.update(new_label=final_label)
+
+ def clear(self) -> None:
+ """Remove the thought from the screen. A cleared thought can't be reused."""
+ self._container.clear()
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+ """A callback handler that writes to a Streamlit app."""
+
+ def __init__(
+ self,
+ parent_container: DeltaGenerator,
+ *,
+ max_thought_containers: int = 4,
+ expand_new_thoughts: bool = True,
+ collapse_completed_thoughts: bool = True,
+ thought_labeler: Optional[LLMThoughtLabeler] = None,
+ ):
+ """Create a StreamlitCallbackHandler instance.
+
+ Parameters
+ ----------
+ parent_container
+ The `st.container` that will contain all the Streamlit elements that the
+ Handler creates.
+ max_thought_containers
+ The max number of completed LLM thought containers to show at once. When
+ this threshold is reached, a new thought will cause the oldest thoughts to
+ be collapsed into a "History" expander. Defaults to 4.
+ expand_new_thoughts
+ Each LLM "thought" gets its own `st.expander`. This param controls whether
+ that expander is expanded by default. Defaults to True.
+ collapse_completed_thoughts
+ If True, LLM thought expanders will be collapsed when completed.
+ Defaults to True.
+ thought_labeler
+ An optional custom LLMThoughtLabeler instance. If unspecified, the handler
+ will use the default thought labeling logic. Defaults to None.
+ """
+ self._parent_container = parent_container
+ self._history_parent = parent_container.container()
+ self._history_container: Optional[MutableExpander] = None
+ self._current_thought: Optional[LLMThought] = None
+ self._completed_thoughts: List[LLMThought] = []
+ self._max_thought_containers = max(max_thought_containers, 1)
+ self._expand_new_thoughts = expand_new_thoughts
+ self._collapse_completed_thoughts = collapse_completed_thoughts
+ self._thought_labeler = thought_labeler or LLMThoughtLabeler()
+
+ def _require_current_thought(self) -> LLMThought:
+ """Return our current LLMThought. Raise an error if we have no current
+ thought.
+ """
+ if self._current_thought is None:
+ raise RuntimeError("Current LLMThought is unexpectedly None!")
+ return self._current_thought
+
+ def _get_last_completed_thought(self) -> Optional[LLMThought]:
+ """Return our most recent completed LLMThought, or None if we don't have one."""
+ if len(self._completed_thoughts) > 0:
+            return self._completed_thoughts[-1]
+ return None
+
+ @property
+ def _num_thought_containers(self) -> int:
+ """The number of 'thought containers' we're currently showing: the
+ number of completed thought containers, the history container (if it exists),
+ and the current thought container (if it exists).
+ """
+ count = len(self._completed_thoughts)
+ if self._history_container is not None:
+ count += 1
+ if self._current_thought is not None:
+ count += 1
+ return count
+
+ def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
+ """Complete the current thought, optionally assigning it a new label.
+ Add it to our _completed_thoughts list.
+ """
+ thought = self._require_current_thought()
+ thought.complete(final_label)
+ self._completed_thoughts.append(thought)
+ self._current_thought = None
+
+ def _prune_old_thought_containers(self) -> None:
+ """If we have too many thoughts onscreen, move older thoughts to the
+ 'history container.'
+ """
+ while (
+ self._num_thought_containers > self._max_thought_containers
+ and len(self._completed_thoughts) > 0
+ ):
+ # Create our history container if it doesn't exist, and if
+ # max_thought_containers is > 1. (if max_thought_containers is 1, we don't
+ # have room to show history.)
+ if self._history_container is None and self._max_thought_containers > 1:
+ self._history_container = MutableExpander(
+ self._history_parent,
+ label=self._thought_labeler.get_history_label(),
+ expanded=False,
+ )
+
+ oldest_thought = self._completed_thoughts.pop(0)
+ if self._history_container is not None:
+ self._history_container.markdown(oldest_thought.container.label)
+ self._history_container.append_copy(oldest_thought.container)
+ oldest_thought.clear()
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ if self._current_thought is None:
+ self._current_thought = LLMThought(
+ parent_container=self._parent_container,
+ expanded=self._expand_new_thoughts,
+ collapse_on_complete=self._collapse_completed_thoughts,
+ labeler=self._thought_labeler,
+ )
+
+ self._current_thought.on_llm_start(serialized, prompts)
+
+ # We don't prune_old_thought_containers here, because our container won't
+ # be visible until it has a child.
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ self._require_current_thought().on_llm_new_token(token, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ self._require_current_thought().on_llm_end(response, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ self._require_current_thought().on_llm_error(error, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_tool_end(
+ self,
+ output: str,
+ color: Optional[str] = None,
+ observation_prefix: Optional[str] = None,
+ llm_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._require_current_thought().on_tool_end(
+ output, color, observation_prefix, llm_prefix, **kwargs
+ )
+ self._complete_current_thought()
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ self._require_current_thought().on_tool_error(error, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_text(
+ self,
+ text: str,
+ color: Optional[str] = None,
+ end: str = "",
+ **kwargs: Any,
+ ) -> None:
+ pass
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ pass
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ pass
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ pass
+
+ def on_agent_action(
+ self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
+ ) -> Any:
+ self._require_current_thought().on_agent_action(action, color, **kwargs)
+ self._prune_old_thought_containers()
+
+ def on_agent_finish(
+ self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+ ) -> None:
+ if self._current_thought is not None:
+ self._current_thought.complete(
+ self._thought_labeler.get_final_agent_thought_label()
+ )
+ self._current_thought = None
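For orientation, a minimal usage sketch of the Streamlit callback handler defined above. The import path and handler name (`StreamlitCallbackHandler` in `langchain_community.callbacks.streamlit`) and the agent wiring are assumptions, not part of this diff:

```python
import streamlit as st
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler  # assumed path

# Each agent "thought" (an LLM call plus any tool use) renders in its own expander;
# older thoughts are pruned into a collapsed history container as the run grows.
handler = StreamlitCallbackHandler(st.container())
# agent.run("What's the weather in Paris?", callbacks=[handler])  # illustrative
```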
diff --git a/libs/community/langchain_community/callbacks/tracers/__init__.py b/libs/community/langchain_community/callbacks/tracers/__init__.py
new file mode 100644
index 00000000000..8af691585a6
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/tracers/__init__.py
@@ -0,0 +1,18 @@
+"""Tracers that record execution of LangChain runs."""
+
+from langchain_core.tracers.langchain import LangChainTracer
+from langchain_core.tracers.langchain_v1 import LangChainTracerV1
+from langchain_core.tracers.stdout import (
+ ConsoleCallbackHandler,
+ FunctionCallbackHandler,
+)
+
+from langchain_community.callbacks.tracers.wandb import WandbTracer
+
+__all__ = [
+ "ConsoleCallbackHandler",
+ "FunctionCallbackHandler",
+ "LangChainTracer",
+ "LangChainTracerV1",
+ "WandbTracer",
+]
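The re-exports above let downstream code import the bundled tracers from one place; a small illustrative sketch (the chain invocation is an assumption about caller code, not part of this diff):

```python
from langchain_community.callbacks.tracers import ConsoleCallbackHandler

handler = ConsoleCallbackHandler()  # prints run starts/ends to stdout
# Any runnable accepts callbacks via its config, e.g.:
# chain.invoke({"question": "..."}, config={"callbacks": [handler]})
```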
diff --git a/libs/community/langchain_community/callbacks/tracers/comet.py b/libs/community/langchain_community/callbacks/tracers/comet.py
new file mode 100644
index 00000000000..a55972b5937
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/tracers/comet.py
@@ -0,0 +1,138 @@
+from types import ModuleType, SimpleNamespace
+from typing import TYPE_CHECKING, Any, Callable, Dict
+
+from langchain_core.tracers import BaseTracer
+
+if TYPE_CHECKING:
+ from uuid import UUID
+
+ from comet_llm import Span
+ from comet_llm.chains.chain import Chain
+
+ from langchain_community.callbacks.tracers.schemas import Run
+
+
+def _get_run_type(run: "Run") -> str:
+ if isinstance(run.run_type, str):
+ return run.run_type
+ elif hasattr(run.run_type, "value"):
+ return run.run_type.value
+ else:
+ return str(run.run_type)
+
+
+def import_comet_llm_api() -> SimpleNamespace:
+ """Import comet_llm api and raise an error if it is not installed."""
+ try:
+ from comet_llm import (
+ experiment_info, # noqa: F401
+ flush, # noqa: F401
+ )
+ from comet_llm.chains import api as chain_api # noqa: F401
+ from comet_llm.chains import (
+ chain, # noqa: F401
+ span, # noqa: F401
+ )
+
+ except ImportError:
+ raise ImportError(
+ "To use the CometTracer you need to have the "
+ "`comet_llm>=2.0.0` python package installed. Please install it with"
+ " `pip install -U comet_llm`"
+ )
+ return SimpleNamespace(
+ chain=chain,
+ span=span,
+ chain_api=chain_api,
+ experiment_info=experiment_info,
+ flush=flush,
+ )
+
+
+class CometTracer(BaseTracer):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self._span_map: Dict["UUID", "Span"] = {}
+ self._chains_map: Dict["UUID", "Chain"] = {}
+ self._initialize_comet_modules()
+
+ def _initialize_comet_modules(self) -> None:
+ comet_llm_api = import_comet_llm_api()
+ self._chain: ModuleType = comet_llm_api.chain
+ self._span: ModuleType = comet_llm_api.span
+ self._chain_api: ModuleType = comet_llm_api.chain_api
+ self._experiment_info: ModuleType = comet_llm_api.experiment_info
+ self._flush: Callable[[], None] = comet_llm_api.flush
+
+ def _persist_run(self, run: "Run") -> None:
+ chain_ = self._chains_map[run.id]
+ chain_.set_outputs(outputs=run.outputs)
+ self._chain_api.log_chain(chain_)
+
+ def _process_start_trace(self, run: "Run") -> None:
+ if not run.parent_run_id:
+ # This is the first run, which maps to a chain
+ chain_: "Chain" = self._chain.Chain(
+ inputs=run.inputs,
+ metadata=None,
+ experiment_info=self._experiment_info.get(),
+ )
+ self._chains_map[run.id] = chain_
+ else:
+ span: "Span" = self._span.Span(
+ inputs=run.inputs,
+ category=_get_run_type(run),
+ metadata=run.extra,
+ name=run.name,
+ )
+ span.__api__start__(self._chains_map[run.parent_run_id])
+ self._chains_map[run.id] = self._chains_map[run.parent_run_id]
+ self._span_map[run.id] = span
+
+ def _process_end_trace(self, run: "Run") -> None:
+ if not run.parent_run_id:
+ # LangChain will call _persist_run for the root run, so nothing to do here.
+ pass
+ # LangChain will call _persist_run for the root run, so nothing to do here.
+ pass
+ else:
+ span = self._span_map[run.id]
+ span.set_outputs(outputs=run.outputs)
+ span.__api__end__()
+
+ def flush(self) -> None:
+ self._flush()
+
+ def _on_llm_start(self, run: "Run") -> None:
+ """Process the LLM Run upon start."""
+ self._process_start_trace(run)
+
+ def _on_llm_end(self, run: "Run") -> None:
+ """Process the LLM Run."""
+ self._process_end_trace(run)
+
+ def _on_llm_error(self, run: "Run") -> None:
+ """Process the LLM Run upon error."""
+ self._process_end_trace(run)
+
+ def _on_chain_start(self, run: "Run") -> None:
+ """Process the Chain Run upon start."""
+ self._process_start_trace(run)
+
+ def _on_chain_end(self, run: "Run") -> None:
+ """Process the Chain Run."""
+ self._process_end_trace(run)
+
+ def _on_chain_error(self, run: "Run") -> None:
+ """Process the Chain Run upon error."""
+ self._process_end_trace(run)
+
+ def _on_tool_start(self, run: "Run") -> None:
+ """Process the Tool Run upon start."""
+ self._process_start_trace(run)
+
+ def _on_tool_end(self, run: "Run") -> None:
+ """Process the Tool Run."""
+ self._process_end_trace(run)
+
+ def _on_tool_error(self, run: "Run") -> None:
+ """Process the Tool Run upon error."""
+ self._process_end_trace(run)
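A minimal usage sketch for the tracer above, assuming `comet_llm` is installed and configured with an API key; the LLM wiring is illustrative:

```python
from langchain_community.callbacks.tracers.comet import CometTracer

tracer = CometTracer()
# Attach it like any other callback, e.g. llm = SomeLLM(callbacks=[tracer]),
# run your chain, then force any buffered chains/spans to be sent to Comet:
tracer.flush()
```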
diff --git a/libs/community/langchain_community/callbacks/tracers/wandb.py b/libs/community/langchain_community/callbacks/tracers/wandb.py
new file mode 100644
index 00000000000..31df3352734
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/tracers/wandb.py
@@ -0,0 +1,514 @@
+"""A Tracer Implementation that records activity to Weights & Biases."""
+from __future__ import annotations
+
+import json
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypedDict,
+ Union,
+)
+
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
+
+if TYPE_CHECKING:
+ from wandb import Settings as WBSettings
+ from wandb.sdk.data_types.trace_tree import Span
+ from wandb.sdk.lib.paths import StrPath
+ from wandb.wandb_run import Run as WBRun
+
+
+PRINT_WARNINGS = True
+
+
+def _serialize_io(run_inputs: Optional[dict]) -> dict:
+ if not run_inputs:
+ return {}
+ from google.protobuf.json_format import MessageToJson
+ from google.protobuf.message import Message
+
+ serialized_inputs = {}
+ for key, value in run_inputs.items():
+ if isinstance(value, Message):
+ serialized_inputs[key] = MessageToJson(value)
+ elif key == "input_documents":
+ serialized_inputs.update(
+ {f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
+ )
+ else:
+ serialized_inputs[key] = value
+ return serialized_inputs
+
+
+class RunProcessor:
+ """Handles the conversion of a LangChain Runs into a WBTraceTree."""
+
+ def __init__(self, wandb_module: Any, trace_module: Any):
+ self.wandb = wandb_module
+ self.trace_tree = trace_module
+
+ def process_span(self, run: Run) -> Optional["Span"]:
+ """Converts a LangChain Run into a W&B Trace Span.
+ :param run: The LangChain Run to convert.
+ :return: The converted W&B Trace Span.
+ """
+ try:
+ span = self._convert_lc_run_to_wb_span(run)
+ return span
+ except Exception as e:
+ if PRINT_WARNINGS:
+ self.wandb.termwarn(
+ f"Skipping trace saving - unable to safely convert LangChain Run "
+ f"into W&B Trace due to: {e}"
+ )
+ return None
+
+ def _convert_run_to_wb_span(self, run: Run) -> "Span":
+ """Base utility to create a span from a run.
+ :param run: The run to convert.
+ :return: The converted Span.
+ """
+ attributes = {**run.extra} if run.extra else {}
+ attributes["execution_order"] = run.execution_order
+
+ return self.trace_tree.Span(
+ span_id=str(run.id) if run.id is not None else None,
+ name=run.name,
+ start_time_ms=int(run.start_time.timestamp() * 1000),
+ end_time_ms=int(run.end_time.timestamp() * 1000)
+ if run.end_time is not None
+ else None,
+ status_code=self.trace_tree.StatusCode.SUCCESS
+ if run.error is None
+ else self.trace_tree.StatusCode.ERROR,
+ status_message=run.error,
+ attributes=attributes,
+ )
+
+ def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
+ """Converts a LangChain LLM Run into a W&B Trace Span.
+ :param run: The LangChain LLM Run to convert.
+ :return: The converted W&B Trace Span.
+ """
+ base_span = self._convert_run_to_wb_span(run)
+ if base_span.attributes is None:
+ base_span.attributes = {}
+ base_span.attributes["llm_output"] = (run.outputs or {}).get("llm_output", {})
+
+ base_span.results = [
+ self.trace_tree.Result(
+ inputs={"prompt": prompt},
+ outputs={
+ f"gen_{g_i}": gen["text"]
+ for g_i, gen in enumerate(run.outputs["generations"][ndx])
+ }
+ if (
+ run.outputs is not None
+ and len(run.outputs["generations"]) > ndx
+ and len(run.outputs["generations"][ndx]) > 0
+ )
+ else None,
+ )
+ for ndx, prompt in enumerate(run.inputs["prompts"] or [])
+ ]
+ base_span.span_kind = self.trace_tree.SpanKind.LLM
+
+ return base_span
+
+ def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
+ """Converts a LangChain Chain Run into a W&B Trace Span.
+ :param run: The LangChain Chain Run to convert.
+ :return: The converted W&B Trace Span.
+ """
+ base_span = self._convert_run_to_wb_span(run)
+
+ base_span.results = [
+ self.trace_tree.Result(
+ inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
+ )
+ ]
+ base_span.child_spans = [
+ self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
+ ]
+ base_span.span_kind = (
+ self.trace_tree.SpanKind.AGENT
+ if "agent" in run.name.lower()
+ else self.trace_tree.SpanKind.CHAIN
+ )
+
+ return base_span
+
+ def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
+ """Converts a LangChain Tool Run into a W&B Trace Span.
+ :param run: The LangChain Tool Run to convert.
+ :return: The converted W&B Trace Span.
+ """
+ base_span = self._convert_run_to_wb_span(run)
+ base_span.results = [
+ self.trace_tree.Result(
+ inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
+ )
+ ]
+ base_span.child_spans = [
+ self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
+ ]
+ base_span.span_kind = self.trace_tree.SpanKind.TOOL
+
+ return base_span
+
+ def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
+ """Utility to convert any generic LangChain Run into a W&B Trace Span.
+ :param run: The LangChain Run to convert.
+ :return: The converted W&B Trace Span.
+ """
+ if run.run_type == "llm":
+ return self._convert_llm_run_to_wb_span(run)
+ elif run.run_type == "chain":
+ return self._convert_chain_run_to_wb_span(run)
+ elif run.run_type == "tool":
+ return self._convert_tool_run_to_wb_span(run)
+ else:
+ return self._convert_run_to_wb_span(run)
+
+ def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
+ """Utility to process a run for wandb model_dict serialization.
+ :param run: The run to process.
+ :return: The converted model_dict to pass to WBTraceTree.
+ """
+ try:
+ data = json.loads(run.json())
+ processed = self.flatten_run(data)
+ keep_keys = (
+ "id",
+ "name",
+ "serialized",
+ "inputs",
+ "outputs",
+ "parent_run_id",
+ "execution_order",
+ )
+ processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
+ exact_keys, partial_keys = ("lc", "type"), ("api_key",)
+ processed = self.modify_serialized_iterative(
+ processed, exact_keys=exact_keys, partial_keys=partial_keys
+ )
+ output = self.build_tree(processed)
+ return output
+ except Exception as e:
+ if PRINT_WARNINGS:
+ self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
+ return None
+
+ def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """Utility to flatten a nest run object into a list of runs.
+ :param run: The base run to flatten.
+ :return: The flattened list of runs.
+ """
+
+ def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Utility to recursively flatten a list of child runs in a run.
+ :param child_runs: The list of child runs to flatten.
+ :return: The flattened list of runs.
+ """
+ if child_runs is None:
+ return []
+
+ result = []
+ for item in child_runs:
+ child_runs = item.pop("child_runs", [])
+ result.append(item)
+ result.extend(flatten(child_runs))
+
+ return result
+
+ return flatten([run])
+
+ def truncate_run_iterative(
+ self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
+ ) -> List[Dict[str, Any]]:
+ """Utility to truncate a list of runs dictionaries to only keep the specified
+ keys in each run.
+ :param runs: The list of runs to truncate.
+ :param keep_keys: The keys to keep in each run.
+ :return: The truncated list of runs.
+ """
+
+ def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
+ """Utility to truncate a single run dictionary to only keep the specified
+ keys.
+ :param run: The run dictionary to truncate.
+ :return: The truncated run dictionary
+ """
+ new_dict = {}
+ for key in run:
+ if key in keep_keys:
+ new_dict[key] = run.get(key)
+ return new_dict
+
+ return list(map(truncate_single, runs))
+
+ def modify_serialized_iterative(
+ self,
+ runs: List[Dict[str, Any]],
+ exact_keys: Tuple[str, ...] = (),
+ partial_keys: Tuple[str, ...] = (),
+ ) -> List[Dict[str, Any]]:
+ """Utility to modify the serialized field of a list of runs dictionaries.
+ removes any keys that match the exact_keys and any keys that contain any of the
+ partial_keys.
+ recursively moves the dictionaries under the kwargs key to the top level.
+ changes the "id" field to a string "_kind" field that tells WBTraceTree how to
+ visualize the run. promotes the "serialized" field to the top level.
+
+ :param runs: The list of runs to modify.
+ :param exact_keys: A tuple of keys to remove from the serialized field.
+ :param partial_keys: A tuple of partial keys to remove from the serialized
+ field.
+ :return: The modified list of runs.
+ """
+
+ def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
+ """Recursively removes exact and partial keys from a dictionary.
+ :param obj: The dictionary to remove keys from.
+ :return: The modified dictionary.
+ """
+ if isinstance(obj, dict):
+ obj = {
+ k: v
+ for k, v in obj.items()
+ if k not in exact_keys
+ and not any(partial in k for partial in partial_keys)
+ }
+ for k, v in obj.items():
+ obj[k] = remove_exact_and_partial_keys(v)
+ elif isinstance(obj, list):
+ obj = [remove_exact_and_partial_keys(x) for x in obj]
+ return obj
+
+ def handle_id_and_kwargs(
+ obj: Dict[str, Any], root: bool = False
+ ) -> Dict[str, Any]:
+ """Recursively handles the id and kwargs fields of a dictionary.
+ changes the id field to a string "_kind" field that tells WBTraceTree how
+ to visualize the run. recursively moves the dictionaries under the kwargs
+ key to the top level.
+ :param obj: a run dictionary with id and kwargs fields.
+ :param root: whether this is the root dictionary or the serialized
+ dictionary.
+ :return: The modified dictionary.
+ """
+ if isinstance(obj, dict):
+ if ("id" in obj or "name" in obj) and not root:
+ _kind = obj.get("id")
+ if not _kind:
+ _kind = [obj.get("name")]
+ obj["_kind"] = _kind[-1]
+ obj.pop("id", None)
+ obj.pop("name", None)
+ if "kwargs" in obj:
+ kwargs = obj.pop("kwargs")
+ for k, v in kwargs.items():
+ obj[k] = v
+ for k, v in obj.items():
+ obj[k] = handle_id_and_kwargs(v)
+ elif isinstance(obj, list):
+ obj = [handle_id_and_kwargs(x) for x in obj]
+ return obj
+
+ def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
+ """Transforms the serialized field of a run dictionary to be compatible
+ with WBTraceTree.
+ :param serialized: The serialized field of a run dictionary.
+ :return: The transformed serialized field.
+ """
+ serialized = handle_id_and_kwargs(serialized, root=True)
+ serialized = remove_exact_and_partial_keys(serialized)
+ return serialized
+
+ def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
+ """Transforms a run dictionary to be compatible with WBTraceTree.
+ :param run: The run dictionary to transform.
+ :return: The transformed run dictionary.
+ """
+ transformed_dict = transform_serialized(run)
+
+ serialized = transformed_dict.pop("serialized")
+ for k, v in serialized.items():
+ transformed_dict[k] = v
+
+ _kind = transformed_dict.get("_kind", None)
+ name = transformed_dict.pop("name", None)
+ exec_ord = transformed_dict.pop("execution_order", None)
+
+ if not name:
+ name = _kind
+
+ output_dict = {
+ f"{exec_ord}_{name}": transformed_dict,
+ }
+ return output_dict
+
+ return list(map(transform_run, runs))
+
+ def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """Builds a nested dictionary from a list of runs.
+ :param runs: The list of runs to build the tree from.
+ :return: The nested dictionary representing the langchain Run in a tree
+ structure compatible with WBTraceTree.
+ """
+ id_to_data = {}
+ child_to_parent = {}
+
+ for entity in runs:
+ for key, data in entity.items():
+ id_val = data.pop("id", None)
+ parent_run_id = data.pop("parent_run_id", None)
+ id_to_data[id_val] = {key: data}
+ if parent_run_id:
+ child_to_parent[id_val] = parent_run_id
+
+ for child_id, parent_id in child_to_parent.items():
+ parent_dict = id_to_data[parent_id]
+ parent_dict[next(iter(parent_dict))][
+ next(iter(id_to_data[child_id]))
+ ] = id_to_data[child_id][next(iter(id_to_data[child_id]))]
+
+ root_dict = next(
+ data for id_val, data in id_to_data.items() if id_val not in child_to_parent
+ )
+
+ return root_dict
+
+
+class WandbRunArgs(TypedDict):
+ """Arguments for the WandbTracer."""
+
+ job_type: Optional[str]
+ dir: Optional[StrPath]
+ config: Union[Dict, str, None]
+ project: Optional[str]
+ entity: Optional[str]
+ reinit: Optional[bool]
+ tags: Optional[Sequence]
+ group: Optional[str]
+ name: Optional[str]
+ notes: Optional[str]
+ magic: Optional[Union[dict, str, bool]]
+ config_exclude_keys: Optional[List[str]]
+ config_include_keys: Optional[List[str]]
+ anonymous: Optional[str]
+ mode: Optional[str]
+ allow_val_change: Optional[bool]
+ resume: Optional[Union[bool, str]]
+ force: Optional[bool]
+ tensorboard: Optional[bool]
+ sync_tensorboard: Optional[bool]
+ monitor_gym: Optional[bool]
+ save_code: Optional[bool]
+ id: Optional[str]
+ settings: Union[WBSettings, Dict[str, Any], None]
+
+
+class WandbTracer(BaseTracer):
+ """Callback Handler that logs to Weights and Biases.
+
+ This handler will log the model architecture and run traces to Weights and Biases.
+ This will ensure that all LangChain activity is logged to W&B.
+ """
+
+ _run: Optional[WBRun] = None
+ _run_args: Optional[WandbRunArgs] = None
+
+ def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
+ """Initializes the WandbTracer.
+
+ Parameters:
+ run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
+ provided, `wandb.init()` will be called with no arguments. Please
+ refer to the `wandb.init` documentation for more details.
+
+ To use W&B to monitor all LangChain activity, add this tracer like any other
+ LangChain callback:
+ ```
+ from wandb.integration.langchain import WandbTracer
+
+ tracer = WandbTracer()
+ chain = LLMChain(llm, callbacks=[tracer])
+ # ...end of notebook / script:
+ tracer.finish()
+ ```
+ """
+ super().__init__(**kwargs)
+ try:
+ import wandb
+ from wandb.sdk.data_types import trace_tree
+ except ImportError as e:
+ raise ImportError(
+ "Could not import wandb python package."
+ "Please install it with `pip install -U wandb`."
+ ) from e
+ self._wandb = wandb
+ self._trace_tree = trace_tree
+ self._run_args = run_args
+ self._ensure_run(should_print_url=(wandb.run is None))
+ self.run_processor = RunProcessor(self._wandb, self._trace_tree)
+
+ def finish(self) -> None:
+ """Waits for all asynchronous processes to finish and data to upload.
+
+ Proxy for `wandb.finish()`.
+ """
+ self._wandb.finish()
+
+ def _log_trace_from_run(self, run: Run) -> None:
+ """Logs a LangChain Run to W*B as a W&B Trace."""
+ self._ensure_run()
+
+ root_span = self.run_processor.process_span(run)
+ model_dict = self.run_processor.process_model(run)
+
+ if root_span is None:
+ return
+
+ model_trace = self._trace_tree.WBTraceTree(
+ root_span=root_span,
+ model_dict=model_dict,
+ )
+ if self._wandb.run is not None:
+ self._wandb.run.log({"langchain_trace": model_trace})
+
+ def _ensure_run(self, should_print_url: bool = False) -> None:
+ """Ensures an active W&B run exists.
+
+ If not, will start a new run with the provided run_args.
+ """
+ if self._wandb.run is None:
+ run_args = self._run_args or {} # type: ignore
+ run_args: dict = {**run_args} # type: ignore
+
+ if "settings" not in run_args: # type: ignore
+ run_args["settings"] = {"silent": True} # type: ignore
+
+ self._wandb.init(**run_args)
+ if self._wandb.run is not None:
+ if should_print_url:
+ run_url = self._wandb.run.settings.run_url
+ self._wandb.termlog(
+ f"Streaming LangChain activity to W&B at {run_url}\n"
+ "`WandbTracer` is currently in beta.\n"
+ "Please report any issues to "
+ "https://github.com/wandb/wandb/issues with the tag "
+ "`langchain`."
+ )
+
+ self._wandb.run._label(repo="langchain")
+
+ def _persist_run(self, run: "Run") -> None:
+ """Persist a run."""
+ self._log_trace_from_run(run)
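Beyond the docstring example above, `run_args` lets callers control the underlying `wandb.init()` call; a hedged sketch (the project and run names are placeholders):

```python
from langchain_community.callbacks.tracers.wandb import WandbTracer

tracer = WandbTracer(run_args={"project": "langchain-traces", "name": "demo-run"})
# ... attach `tracer` to chains/LLMs via callbacks and run them ...
tracer.finish()  # proxy for wandb.finish(); waits for uploads to complete
```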
diff --git a/libs/community/langchain_community/callbacks/trubrics_callback.py b/libs/community/langchain_community/callbacks/trubrics_callback.py
new file mode 100644
index 00000000000..fa697a756ed
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/trubrics_callback.py
@@ -0,0 +1,125 @@
+import os
+from typing import Any, Dict, List, Optional
+from uuid import UUID
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import LLMResult
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ if "function_call" in message.additional_kwargs:
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
+ # If function call only, content is None not empty string
+ if message_dict["content"] == "":
+ message_dict["content"] = None
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.name,
+ }
+ else:
+ raise TypeError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
+
+
+class TrubricsCallbackHandler(BaseCallbackHandler):
+ """
+ Callback handler for Trubrics.
+
+ Args:
+ project: a Trubrics project; defaults to "default"
+ email: a Trubrics account email; can also be set via the TRUBRICS_EMAIL
+ environment variable
+ password: a Trubrics account password; can also be set via the
+ TRUBRICS_PASSWORD environment variable
+ **kwargs: all other kwargs are parsed and set as Trubrics prompt variables,
+ or added to the `metadata` dict
+ """
+
+ def __init__(
+ self,
+ project: str = "default",
+ email: Optional[str] = None,
+ password: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+ try:
+ from trubrics import Trubrics
+ except ImportError:
+ raise ImportError(
+ "The TrubricsCallbackHandler requires installation of "
+ "the trubrics package. "
+ "Please install it with `pip install trubrics`."
+ )
+
+ self.trubrics = Trubrics(
+ project=project,
+ email=email or os.environ["TRUBRICS_EMAIL"],
+ password=password or os.environ["TRUBRICS_PASSWORD"],
+ )
+ self.config_model: dict = {}
+ self.prompt: Optional[str] = None
+ self.messages: Optional[list] = None
+ self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ self.prompt = prompts[0]
+
+ def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List[BaseMessage]],
+ **kwargs: Any,
+ ) -> None:
+ self.messages = [_convert_message_to_dict(message) for message in messages[0]]
+ self.prompt = self.messages[-1]["content"]
+
+ def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
+ tags = ["langchain"]
+ user_id = None
+ session_id = None
+ metadata: dict = {"langchain_run_id": run_id}
+ if self.messages:
+ metadata["messages"] = self.messages
+ if self.trubrics_kwargs:
+ if self.trubrics_kwargs.get("tags"):
+ tags.extend(self.trubrics_kwargs.pop("tags"))
+ user_id = self.trubrics_kwargs.pop("user_id", None)
+ session_id = self.trubrics_kwargs.pop("session_id", None)
+ metadata.update(self.trubrics_kwargs)
+
+ for generation in response.generations:
+ self.trubrics.log_prompt(
+ config_model={
+ "model": response.llm_output.get("model_name")
+ if response.llm_output
+ else "NA"
+ },
+ prompt=self.prompt,
+ generation=generation[0].text,
+ user_id=user_id,
+ session_id=session_id,
+ tags=tags,
+ metadata=metadata,
+ )
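A hedged usage sketch for the handler above. Credentials fall back to `TRUBRICS_EMAIL`/`TRUBRICS_PASSWORD` when not passed explicitly, and extra kwargs such as `tags`, `user_id`, or `session_id` flow into the logged prompt metadata; the chat-model wiring is illustrative:

```python
from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler

handler = TrubricsCallbackHandler(
    project="default",
    tags=["demo"],        # appended to the default "langchain" tag
    user_id="user-123",   # forwarded to trubrics.log_prompt
)
# llm = ChatOpenAI(callbacks=[handler])  # illustrative; any chat model works
```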
diff --git a/libs/community/langchain_community/callbacks/utils.py b/libs/community/langchain_community/callbacks/utils.py
new file mode 100644
index 00000000000..b83bace6016
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/utils.py
@@ -0,0 +1,258 @@
+import hashlib
+from pathlib import Path
+from typing import Any, Dict, Iterable, Tuple, Union
+
+
+def import_spacy() -> Any:
+ """Import the spacy python package and raise an error if it is not installed."""
+ try:
+ import spacy
+ except ImportError:
+ raise ImportError(
+ "This callback manager requires the `spacy` python "
+ "package installed. Please install it with `pip install spacy`"
+ )
+ return spacy
+
+
+def import_pandas() -> Any:
+ """Import the pandas python package and raise an error if it is not installed."""
+ try:
+ import pandas
+ except ImportError:
+ raise ImportError(
+ "This callback manager requires the `pandas` python "
+ "package installed. Please install it with `pip install pandas`"
+ )
+ return pandas
+
+
+def import_textstat() -> Any:
+ """Import the textstat python package and raise an error if it is not installed."""
+ try:
+ import textstat
+ except ImportError:
+ raise ImportError(
+ "This callback manager requires the `textstat` python "
+ "package installed. Please install it with `pip install textstat`"
+ )
+ return textstat
+
+
+def _flatten_dict(
+ nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
+) -> Iterable[Tuple[str, Any]]:
+ """
+ Generator that yields flattened key/value pairs from a nested dictionary.
+
+ Parameters:
+ nested_dict (dict): The nested dictionary to flatten.
+ parent_key (str): The prefix to prepend to the keys of the flattened dict.
+ sep (str): The separator to use between the parent key and the key of the
+ flattened dictionary.
+
+ Yields:
+ (str, any): A key-value pair from the flattened dictionary.
+ """
+ for key, value in nested_dict.items():
+ new_key = parent_key + sep + key if parent_key else key
+ if isinstance(value, dict):
+ yield from _flatten_dict(value, new_key, sep)
+ else:
+ yield new_key, value
+
+
+def flatten_dict(
+ nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
+) -> Dict[str, Any]:
+ """Flattens a nested dictionary into a flat dictionary.
+
+ Parameters:
+ nested_dict (dict): The nested dictionary to flatten.
+ parent_key (str): The prefix to prepend to the keys of the flattened dict.
+ sep (str): The separator to use between the parent key and the key of the
+ flattened dictionary.
+
+ Returns:
+ (dict): A flat dictionary.
+
+ """
+ flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
+ return flat_dict
+
+
+def hash_string(s: str) -> str:
+ """Hash a string using sha1.
+
+ Parameters:
+ s (str): The string to hash.
+
+ Returns:
+ (str): The hashed string.
+ """
+ return hashlib.sha1(s.encode("utf-8")).hexdigest()
+
+
+def load_json(json_path: Union[str, Path]) -> str:
+ """Load json file to a string.
+
+ Parameters:
+ json_path (str): The path to the json file.
+
+ Returns:
+ (str): The string representation of the json file.
+ """
+ with open(json_path, "r") as f:
+ data = f.read()
+ return data
+
+
+class BaseMetadataCallbackHandler:
+ """This class handles the metadata and associated function states for callbacks.
+
+ Attributes:
+ step (int): The current step.
+ starts (int): The number of times the start method has been called.
+ ends (int): The number of times the end method has been called.
+ errors (int): The number of times the error method has been called.
+ text_ctr (int): The number of times the text method has been called.
+ ignore_llm_ (bool): Whether to ignore llm callbacks.
+ ignore_chain_ (bool): Whether to ignore chain callbacks.
+ ignore_agent_ (bool): Whether to ignore agent callbacks.
+ ignore_retriever_ (bool): Whether to ignore retriever callbacks.
+ always_verbose_ (bool): Whether to always be verbose.
+ chain_starts (int): The number of times the chain start method has been called.
+ chain_ends (int): The number of times the chain end method has been called.
+ llm_starts (int): The number of times the llm start method has been called.
+ llm_ends (int): The number of times the llm end method has been called.
+ llm_streams (int): The number of times the text method has been called.
+ tool_starts (int): The number of times the tool start method has been called.
+ tool_ends (int): The number of times the tool end method has been called.
+ agent_ends (int): The number of times the agent end method has been called.
+ on_llm_start_records (list): A list of records of the on_llm_start method.
+ on_llm_token_records (list): A list of records of the on_llm_token method.
+ on_llm_end_records (list): A list of records of the on_llm_end method.
+ on_chain_start_records (list): A list of records of the on_chain_start method.
+ on_chain_end_records (list): A list of records of the on_chain_end method.
+ on_tool_start_records (list): A list of records of the on_tool_start method.
+ on_tool_end_records (list): A list of records of the on_tool_end method.
+ on_agent_finish_records (list): A list of records of the on_agent_end method.
+ """
+
+ def __init__(self) -> None:
+ self.step = 0
+
+ self.starts = 0
+ self.ends = 0
+ self.errors = 0
+ self.text_ctr = 0
+
+ self.ignore_llm_ = False
+ self.ignore_chain_ = False
+ self.ignore_agent_ = False
+ self.ignore_retriever_ = False
+ self.always_verbose_ = False
+
+ self.chain_starts = 0
+ self.chain_ends = 0
+
+ self.llm_starts = 0
+ self.llm_ends = 0
+ self.llm_streams = 0
+
+ self.tool_starts = 0
+ self.tool_ends = 0
+
+ self.agent_ends = 0
+
+ self.on_llm_start_records: list = []
+ self.on_llm_token_records: list = []
+ self.on_llm_end_records: list = []
+
+ self.on_chain_start_records: list = []
+ self.on_chain_end_records: list = []
+
+ self.on_tool_start_records: list = []
+ self.on_tool_end_records: list = []
+
+ self.on_text_records: list = []
+ self.on_agent_finish_records: list = []
+ self.on_agent_action_records: list = []
+
+ @property
+ def always_verbose(self) -> bool:
+ """Whether to call verbose callbacks even if verbose is False."""
+ return self.always_verbose_
+
+ @property
+ def ignore_llm(self) -> bool:
+ """Whether to ignore LLM callbacks."""
+ return self.ignore_llm_
+
+ @property
+ def ignore_chain(self) -> bool:
+ """Whether to ignore chain callbacks."""
+ return self.ignore_chain_
+
+ @property
+ def ignore_agent(self) -> bool:
+ """Whether to ignore agent callbacks."""
+ return self.ignore_agent_
+
+ def get_custom_callback_meta(self) -> Dict[str, Any]:
+ return {
+ "step": self.step,
+ "starts": self.starts,
+ "ends": self.ends,
+ "errors": self.errors,
+ "text_ctr": self.text_ctr,
+ "chain_starts": self.chain_starts,
+ "chain_ends": self.chain_ends,
+ "llm_starts": self.llm_starts,
+ "llm_ends": self.llm_ends,
+ "llm_streams": self.llm_streams,
+ "tool_starts": self.tool_starts,
+ "tool_ends": self.tool_ends,
+ "agent_ends": self.agent_ends,
+ }
+
+ def reset_callback_meta(self) -> None:
+ """Reset the callback metadata."""
+ self.step = 0
+
+ self.starts = 0
+ self.ends = 0
+ self.errors = 0
+ self.text_ctr = 0
+
+ self.ignore_llm_ = False
+ self.ignore_chain_ = False
+ self.ignore_agent_ = False
+ self.always_verbose_ = False
+
+ self.chain_starts = 0
+ self.chain_ends = 0
+
+ self.llm_starts = 0
+ self.llm_ends = 0
+ self.llm_streams = 0
+
+ self.tool_starts = 0
+ self.tool_ends = 0
+
+ self.agent_ends = 0
+
+ self.on_llm_start_records = []
+ self.on_llm_token_records = []
+ self.on_llm_end_records = []
+
+ self.on_chain_start_records = []
+ self.on_chain_end_records = []
+
+ self.on_tool_start_records = []
+ self.on_tool_end_records = []
+
+ self.on_text_records = []
+ self.on_agent_finish_records = []
+ self.on_agent_action_records = []
+ return None
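For reference, the helpers above behave as follows (the values shown are what the code computes):

```python
from langchain_community.callbacks.utils import flatten_dict, hash_string

flatten_dict({"llm": {"model": "x", "params": {"temperature": 0.2}}})
# -> {"llm_model": "x", "llm_params_temperature": 0.2}

hash_string("hello")  # 40-character sha1 hex digest of the UTF-8 bytes
```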
diff --git a/libs/community/langchain_community/callbacks/wandb_callback.py b/libs/community/langchain_community/callbacks/wandb_callback.py
new file mode 100644
index 00000000000..35559ea539c
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/wandb_callback.py
@@ -0,0 +1,587 @@
+import json
+import tempfile
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks.utils import (
+ BaseMetadataCallbackHandler,
+ flatten_dict,
+ hash_string,
+ import_pandas,
+ import_spacy,
+ import_textstat,
+)
+
+
+def import_wandb() -> Any:
+ """Import the wandb python package and raise an error if it is not installed."""
+ try:
+ import wandb # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "To use the wandb callback manager you need to have the `wandb` python "
+ "package installed. Please install it with `pip install wandb`"
+ )
+ return wandb
+
+
+def load_json_to_dict(json_path: Union[str, Path]) -> dict:
+ """Load json file to a dictionary.
+
+ Parameters:
+ json_path (str): The path to the json file.
+
+ Returns:
+ (dict): The dictionary representation of the json file.
+ """
+ with open(json_path, "r") as f:
+ data = json.load(f)
+ return data
+
+
+def analyze_text(
+ text: str,
+ complexity_metrics: bool = True,
+ visualize: bool = True,
+ nlp: Any = None,
+ output_dir: Optional[Union[str, Path]] = None,
+) -> dict:
+ """Analyze text using textstat and spacy.
+
+ Parameters:
+ text (str): The text to analyze.
+ complexity_metrics (bool): Whether to compute complexity metrics.
+ visualize (bool): Whether to visualize the text.
+ nlp (spacy.lang): The spacy language model to use for visualization.
+ output_dir (str): The directory to save the visualization files to.
+
+ Returns:
+ (dict): A dictionary containing the complexity metrics and visualization
+ files serialized in a wandb.Html element.
+ """
+ resp = {}
+ textstat = import_textstat()
+ wandb = import_wandb()
+ spacy = import_spacy()
+ if complexity_metrics:
+ text_complexity_metrics = {
+ "flesch_reading_ease": textstat.flesch_reading_ease(text),
+ "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+ "smog_index": textstat.smog_index(text),
+ "coleman_liau_index": textstat.coleman_liau_index(text),
+ "automated_readability_index": textstat.automated_readability_index(text),
+ "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
+ "difficult_words": textstat.difficult_words(text),
+ "linsear_write_formula": textstat.linsear_write_formula(text),
+ "gunning_fog": textstat.gunning_fog(text),
+ "text_standard": textstat.text_standard(text),
+ "fernandez_huerta": textstat.fernandez_huerta(text),
+ "szigriszt_pazos": textstat.szigriszt_pazos(text),
+ "gutierrez_polini": textstat.gutierrez_polini(text),
+ "crawford": textstat.crawford(text),
+ "gulpease_index": textstat.gulpease_index(text),
+ "osman": textstat.osman(text),
+ }
+ resp.update(text_complexity_metrics)
+
+ if visualize and nlp and output_dir is not None:
+ doc = nlp(text)
+
+ dep_out = spacy.displacy.render( # type: ignore
+ doc, style="dep", jupyter=False, page=True
+ )
+ dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
+ dep_output_path.open("w", encoding="utf-8").write(dep_out)
+
+ ent_out = spacy.displacy.render( # type: ignore
+ doc, style="ent", jupyter=False, page=True
+ )
+ ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
+ ent_output_path.open("w", encoding="utf-8").write(ent_out)
+
+ text_visualizations = {
+ "dependency_tree": wandb.Html(str(dep_output_path)),
+ "entities": wandb.Html(str(ent_output_path)),
+ }
+ resp.update(text_visualizations)
+
+ return resp
+
+
+def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
+ """Construct an html element from a prompt and a generation.
+
+ Parameters:
+ prompt (str): The prompt.
+ generation (str): The generation.
+
+ Returns:
+ (wandb.Html): The html element."""
+ wandb = import_wandb()
+ formatted_prompt = prompt.replace("\n", " ")
+ formatted_generation = generation.replace("\n", " ")
+
+ return wandb.Html(
+ f"""
+ <p style="color:black;">{formatted_prompt}:</p>
+ <blockquote>
+ <p style="color:green;">
+ {formatted_generation}
+ </p>
+ </blockquote>
+ """,
+ inject=False,
+ )
+
+
+class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
+ """Callback Handler that logs to Weights and Biases.
+
+ Parameters:
+ job_type (str): The type of job.
+ project (str): The project to log to.
+ entity (str): The entity to log to.
+ tags (list): The tags to log.
+ group (str): The group to log to.
+ name (str): The name of the run.
+ notes (str): The notes to log.
+ visualize (bool): Whether to visualize the run.
+ complexity_metrics (bool): Whether to log complexity metrics.
+ stream_logs (bool): Whether to stream callback actions to W&B
+
+ This handler uses the associated callback method, formats the input of each
+ callback function with metadata regarding the state of the LLM run, and adds
+ the response to the list of records for both the {method}_records and
+ action_records attributes. It then logs the response to Weights and Biases
+ using the run.log() method.
+ """
+
+ def __init__(
+ self,
+ job_type: Optional[str] = None,
+ project: Optional[str] = "langchain_callback_demo",
+ entity: Optional[str] = None,
+ tags: Optional[Sequence] = None,
+ group: Optional[str] = None,
+ name: Optional[str] = None,
+ notes: Optional[str] = None,
+ visualize: bool = False,
+ complexity_metrics: bool = False,
+ stream_logs: bool = False,
+ ) -> None:
+ """Initialize callback handler."""
+
+ wandb = import_wandb()
+ import_pandas()
+ import_textstat()
+ spacy = import_spacy()
+ super().__init__()
+
+ self.job_type = job_type
+ self.project = project
+ self.entity = entity
+ self.tags = tags
+ self.group = group
+ self.name = name
+ self.notes = notes
+ self.visualize = visualize
+ self.complexity_metrics = complexity_metrics
+ self.stream_logs = stream_logs
+
+ self.temp_dir = tempfile.TemporaryDirectory()
+ self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore
+ job_type=self.job_type,
+ project=self.project,
+ entity=self.entity,
+ tags=self.tags,
+ group=self.group,
+ name=self.name,
+ notes=self.notes,
+ )
+ warning = (
+ "DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
+ "of the `WandbTracer`. Please update your code to use the `WandbTracer` "
+ "instead."
+ )
+ wandb.termwarn(
+ warning,
+ repeat=False,
+ )
+ self.callback_columns: list = []
+ self.action_records: list = []
+ self.complexity_metrics = complexity_metrics
+ self.visualize = visualize
+ self.nlp = spacy.load("en_core_web_sm")
+
+ def _init_resp(self) -> Dict:
+ return {k: None for k in self.callback_columns}
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts."""
+ self.step += 1
+ self.llm_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ for prompt in prompts:
+ prompt_resp = deepcopy(resp)
+ prompt_resp["prompts"] = prompt
+ self.on_llm_start_records.append(prompt_resp)
+ self.action_records.append(prompt_resp)
+ if self.stream_logs:
+ self.run.log(prompt_resp)
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run when LLM generates a new token."""
+ self.step += 1
+ self.llm_streams += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_new_token", "token": token})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_llm_token_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.step += 1
+ self.llm_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_llm_end"})
+ resp.update(flatten_dict(response.llm_output or {}))
+ resp.update(self.get_custom_callback_meta())
+
+ for generations in response.generations:
+ for generation in generations:
+ generation_resp = deepcopy(resp)
+ generation_resp.update(flatten_dict(generation.dict()))
+ generation_resp.update(
+ analyze_text(
+ generation.text,
+ complexity_metrics=self.complexity_metrics,
+ visualize=self.visualize,
+ nlp=self.nlp,
+ output_dir=self.temp_dir.name,
+ )
+ )
+ self.on_llm_end_records.append(generation_resp)
+ self.action_records.append(generation_resp)
+ if self.stream_logs:
+ self.run.log(generation_resp)
+
+ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ chain_input = inputs["input"]
+
+ if isinstance(chain_input, str):
+ input_resp = deepcopy(resp)
+ input_resp["input"] = chain_input
+ self.on_chain_start_records.append(input_resp)
+ self.action_records.append(input_resp)
+ if self.stream_logs:
+ self.run.log(input_resp)
+ elif isinstance(chain_input, list):
+ for inp in chain_input:
+ input_resp = deepcopy(resp)
+ input_resp.update(inp)
+ self.on_chain_start_records.append(input_resp)
+ self.action_records.append(input_resp)
+ if self.stream_logs:
+ self.run.log(input_resp)
+ else:
+ raise ValueError("Unexpected data format provided!")
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_chain_end_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_start", "input_str": input_str})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_tool_start_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_end", "output": output})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_tool_end_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run on arbitrary text (e.g. intermediate agent output).
+ """
+ self.step += 1
+ self.text_ctr += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_text", "text": text})
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_text_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+ """Run when agent ends running."""
+ self.step += 1
+ self.agent_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update(
+ {
+ "action": "on_agent_finish",
+ "output": finish.return_values["output"],
+ "log": finish.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+
+ self.on_agent_finish_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+ """Run on agent action."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update(
+ {
+ "action": "on_agent_action",
+ "tool": action.tool,
+ "tool_input": action.tool_input,
+ "log": action.log,
+ }
+ )
+ resp.update(self.get_custom_callback_meta())
+ self.on_agent_action_records.append(resp)
+ self.action_records.append(resp)
+ if self.stream_logs:
+ self.run.log(resp)
+
+ def _create_session_analysis_df(self) -> Any:
+ """Create a dataframe with all the information from the session."""
+ pd = import_pandas()
+ on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
+ on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
+
+ llm_input_prompts_df = (
+ on_llm_start_records_df[["step", "prompts", "name"]]
+ .dropna(axis=1)
+ .rename({"step": "prompt_step"}, axis=1)
+ )
+ complexity_metrics_columns = []
+ visualizations_columns = []
+
+ if self.complexity_metrics:
+ complexity_metrics_columns = [
+ "flesch_reading_ease",
+ "flesch_kincaid_grade",
+ "smog_index",
+ "coleman_liau_index",
+ "automated_readability_index",
+ "dale_chall_readability_score",
+ "difficult_words",
+ "linsear_write_formula",
+ "gunning_fog",
+ "text_standard",
+ "fernandez_huerta",
+ "szigriszt_pazos",
+ "gutierrez_polini",
+ "crawford",
+ "gulpease_index",
+ "osman",
+ ]
+
+ if self.visualize:
+ visualizations_columns = ["dependency_tree", "entities"]
+
+ llm_outputs_df = (
+ on_llm_end_records_df[
+ [
+ "step",
+ "text",
+ "token_usage_total_tokens",
+ "token_usage_prompt_tokens",
+ "token_usage_completion_tokens",
+ ]
+ + complexity_metrics_columns
+ + visualizations_columns
+ ]
+ .dropna(axis=1)
+ .rename({"step": "output_step", "text": "output"}, axis=1)
+ )
+ session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
+ session_analysis_df["chat_html"] = session_analysis_df[
+ ["prompts", "output"]
+ ].apply(
+ lambda row: construct_html_from_prompt_and_generation(
+ row["prompts"], row["output"]
+ ),
+ axis=1,
+ )
+ return session_analysis_df
+
+ def flush_tracker(
+ self,
+ langchain_asset: Any = None,
+ reset: bool = True,
+ finish: bool = False,
+ job_type: Optional[str] = None,
+ project: Optional[str] = None,
+ entity: Optional[str] = None,
+ tags: Optional[Sequence] = None,
+ group: Optional[str] = None,
+ name: Optional[str] = None,
+ notes: Optional[str] = None,
+ visualize: Optional[bool] = None,
+ complexity_metrics: Optional[bool] = None,
+ ) -> None:
+ """Flush the tracker and reset the session.
+
+ Args:
+ langchain_asset: The langchain asset to save.
+ reset: Whether to reset the session.
+ finish: Whether to finish the run.
+ job_type: The job type.
+ project: The project.
+ entity: The entity.
+ tags: The tags.
+ group: The group.
+ name: The name.
+ notes: The notes.
+ visualize: Whether to visualize.
+ complexity_metrics: Whether to compute complexity metrics.
+
+ Returns:
+ None
+ """
+ pd = import_pandas()
+ wandb = import_wandb()
+ action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
+ session_analysis_table = wandb.Table(
+ dataframe=self._create_session_analysis_df()
+ )
+ self.run.log(
+ {
+ "action_records": action_records_table,
+ "session_analysis": session_analysis_table,
+ }
+ )
+
+ if langchain_asset:
+ langchain_asset_path = Path(self.temp_dir.name, "model.json")
+ model_artifact = wandb.Artifact(name="model", type="model")
+ model_artifact.add(action_records_table, name="action_records")
+ model_artifact.add(session_analysis_table, name="session_analysis")
+ try:
+ langchain_asset.save(langchain_asset_path)
+ model_artifact.add_file(str(langchain_asset_path))
+ model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+ except ValueError:
+ langchain_asset.save_agent(langchain_asset_path)
+ model_artifact.add_file(str(langchain_asset_path))
+ model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+ except NotImplementedError as e:
+ print("Could not save model.")
+ print(repr(e))
+ pass
+ self.run.log_artifact(model_artifact)
+
+ if finish or reset:
+ self.run.finish()
+ self.temp_dir.cleanup()
+ self.reset_callback_meta()
+ if reset:
+ self.__init__( # type: ignore
+ job_type=job_type if job_type else self.job_type,
+ project=project if project else self.project,
+ entity=entity if entity else self.entity,
+ tags=tags if tags else self.tags,
+ group=group if group else self.group,
+ name=name if name else self.name,
+ notes=notes if notes else self.notes,
+ visualize=visualize if visualize else self.visualize,
+ complexity_metrics=complexity_metrics
+ if complexity_metrics
+ else self.complexity_metrics,
+ )
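A hedged end-to-end sketch of the (soon-to-be-deprecated) handler above; the LLM and chain lines are illustrative, and `flush_tracker` logs the accumulated action/session tables before optionally finishing or resetting the run:

```python
from langchain_community.callbacks.wandb_callback import WandbCallbackHandler

# Constructing the handler calls wandb.init() and loads spaCy's en_core_web_sm model.
handler = WandbCallbackHandler(project="langchain_callback_demo", stream_logs=False)
# llm = OpenAI(callbacks=[handler])          # illustrative
# chain = LLMChain(llm=llm, prompt=prompt)   # illustrative
# chain.run("...")
# handler.flush_tracker(langchain_asset=chain, finish=True)
```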
diff --git a/libs/community/langchain_community/callbacks/whylabs_callback.py b/libs/community/langchain_community/callbacks/whylabs_callback.py
new file mode 100644
index 00000000000..8e8f9854912
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/whylabs_callback.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.utils import get_from_env
+
+if TYPE_CHECKING:
+ from whylogs.api.logger.logger import Logger
+
+diagnostic_logger = logging.getLogger(__name__)
+
+
+def import_langkit(
+ sentiment: bool = False,
+ toxicity: bool = False,
+ themes: bool = False,
+) -> Any:
+ """Import the langkit python package and raise an error if it is not installed.
+
+ Args:
+ sentiment: Whether to import the langkit.sentiment module. Defaults to False.
+ toxicity: Whether to import the langkit.toxicity module. Defaults to False.
+ themes: Whether to import the langkit.themes module. Defaults to False.
+
+ Returns:
+ The imported langkit module.
+ """
+ try:
+ import langkit # noqa: F401
+ import langkit.regexes # noqa: F401
+ import langkit.textstat # noqa: F401
+
+ if sentiment:
+ import langkit.sentiment # noqa: F401
+ if toxicity:
+ import langkit.toxicity # noqa: F401
+ if themes:
+ import langkit.themes # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "To use the whylabs callback manager you need to have the `langkit` python "
+ "package installed. Please install it with `pip install langkit`."
+ )
+ return langkit
+
+
+class WhyLabsCallbackHandler(BaseCallbackHandler):
+ """
+ Callback Handler for logging to WhyLabs. This callback handler utilizes
+ `langkit` to extract features from the prompts & responses when interacting with
+ an LLM. These features can be used to guardrail, evaluate, and observe interactions
+ over time to detect issues relating to hallucinations, prompt engineering,
+ or output validation. LangKit is an LLM monitoring toolkit developed by WhyLabs.
+
+ Here are some examples of what can be monitored with LangKit:
+ * Text Quality
+ - readability score
+ - complexity and grade scores
+ * Text Relevance
+ - Similarity scores between prompt/responses
+ - Similarity scores against user-defined themes
+ - Topic classification
+ * Security and Privacy
+ - patterns - count of strings matching a user-defined regex pattern group
+ - jailbreaks - similarity scores with respect to known jailbreak attempts
+ - prompt injection - similarity scores with respect to known prompt attacks
+ - refusals - similarity scores with respect to known LLM refusal responses
+ * Sentiment and Toxicity
+ - sentiment analysis
+ - toxicity analysis
+
+ For more information, see https://docs.whylabs.ai/docs/language-model-monitoring
+ or check out the LangKit repo here: https://github.com/whylabs/langkit
+
+ ---
+ Args:
+ api_key (Optional[str]): WhyLabs API key. Optional because the preferred
+ way to specify the API key is with environment variable
+ WHYLABS_API_KEY.
+ org_id (Optional[str]): WhyLabs organization id to write profiles to.
+ Optional because the preferred way to specify the organization id is
+ with environment variable WHYLABS_DEFAULT_ORG_ID.
+ dataset_id (Optional[str]): WhyLabs dataset id to write profiles to.
+ Optional because the preferred way to specify the dataset id is
+ with environment variable WHYLABS_DEFAULT_DATASET_ID.
+ sentiment (bool): Whether to enable sentiment analysis. Defaults to False.
+ toxicity (bool): Whether to enable toxicity analysis. Defaults to False.
+ themes (bool): Whether to enable theme analysis. Defaults to False.
+ """
+
+ def __init__(self, logger: Logger, handler: Any):
+ """Initiate the rolling logger."""
+ super().__init__()
+ if hasattr(handler, "init"):
+ handler.init(self)
+ if hasattr(handler, "_get_callbacks"):
+ self._callbacks = handler._get_callbacks()
+ else:
+ self._callbacks = dict()
+ diagnostic_logger.warning("initialized handler without callbacks.")
+ self._logger = logger
+
+ def flush(self) -> None:
+ """Explicitly write current profile if using a rolling logger."""
+ if self._logger and hasattr(self._logger, "_do_rollover"):
+ self._logger._do_rollover()
+ diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")
+
+ def close(self) -> None:
+ """Close any loggers to allow writing out of any profiles before exiting."""
+ if self._logger and hasattr(self._logger, "close"):
+ self._logger.close()
+ diagnostic_logger.info("Closing WhyLabs logger, see you next time!")
+
+ def __enter__(self) -> WhyLabsCallbackHandler:
+ return self
+
+ def __exit__(
+ self, exception_type: Any, exception_value: Any, traceback: Any
+ ) -> None:
+ self.close()
+
+ @classmethod
+ def from_params(
+ cls,
+ *,
+ api_key: Optional[str] = None,
+ org_id: Optional[str] = None,
+ dataset_id: Optional[str] = None,
+ sentiment: bool = False,
+ toxicity: bool = False,
+ themes: bool = False,
+ logger: Optional[Logger] = None,
+ ) -> WhyLabsCallbackHandler:
+ """Instantiate whylogs Logger from params.
+
+ Args:
+ api_key (Optional[str]): WhyLabs API key. Optional because the preferred
+ way to specify the API key is with environment variable
+ WHYLABS_API_KEY.
+ org_id (Optional[str]): WhyLabs organization id to write profiles to.
+ If not set must be specified in environment variable
+ WHYLABS_DEFAULT_ORG_ID.
+ dataset_id (Optional[str]): The model or dataset this callback is gathering
+ telemetry for. If not set must be specified in environment variable
+ WHYLABS_DEFAULT_DATASET_ID.
+ sentiment (bool): If True will initialize a model to perform
+ sentiment analysis compound score. Defaults to False and will not gather
+ this metric.
+ toxicity (bool): If True will initialize a model to score
+ toxicity. Defaults to False and will not gather this metric.
+            themes (bool): If True will initialize a model to calculate
+                distance to configured themes. Defaults to False and will not
+                gather this metric.
+ logger (Optional[Logger]): If specified will bind the configured logger as
+ the telemetry gathering agent. Defaults to LangKit schema with periodic
+ WhyLabs writer.
+ """
+ # langkit library will import necessary whylogs libraries
+ import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
+
+ import whylogs as why
+ from langkit.callback_handler import get_callback_instance
+ from whylogs.api.writer.whylabs import WhyLabsWriter
+ from whylogs.experimental.core.udf_schema import udf_schema
+
+ if logger is None:
+ api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
+ org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
+ dataset_id = dataset_id or get_from_env(
+ "dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
+ )
+ whylabs_writer = WhyLabsWriter(
+ api_key=api_key, org_id=org_id, dataset_id=dataset_id
+ )
+
+ whylabs_logger = why.logger(
+ mode="rolling", interval=5, when="M", schema=udf_schema()
+ )
+
+ whylabs_logger.append_writer(writer=whylabs_writer)
+ else:
+            diagnostic_logger.info(f"Using passed in whylogs logger {logger}")
+ whylabs_logger = logger
+
+ callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
+ diagnostic_logger.info(
+ "Started whylogs Logger with WhyLabsWriter and initialized LangKit. π"
+ )
+ return callback_handler_cls
diff --git a/libs/community/langchain_community/chat_loaders/__init__.py b/libs/community/langchain_community/chat_loaders/__init__.py
new file mode 100644
index 00000000000..7547ddcecc8
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/__init__.py
@@ -0,0 +1,19 @@
+"""**Chat Loaders** load chat messages from common communications platforms.
+
+Load chat messages from various
+communications platforms such as Facebook Messenger, Telegram, and
+WhatsApp. The loaded chat messages can be used for fine-tuning models.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseChatLoader --> ChatLoader # Examples: WhatsAppChatLoader, IMessageChatLoader
+
+**Main helpers:**
+
+.. code-block::
+
+ ChatSession
+
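+**Example:**
+
+A minimal sketch, assuming a downloaded Facebook Messenger export in
+``./messages``:
+
+.. code-block:: python
+
+    from langchain_community.chat_loaders.facebook_messenger import (
+        FolderFacebookMessengerChatLoader,
+    )
+
+    loader = FolderFacebookMessengerChatLoader(path="./messages")
+    chat_sessions = loader.load()
+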
+""" # noqa: E501
diff --git a/libs/community/langchain_community/chat_loaders/base.py b/libs/community/langchain_community/chat_loaders/base.py
new file mode 100644
index 00000000000..7bbdd8894d4
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/base.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import Iterator, List
+
+from langchain_core.chat_sessions import ChatSession
+
+
+class BaseChatLoader(ABC):
+ """Base class for chat loaders."""
+
+ @abstractmethod
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """Lazy load the chat sessions."""
+
+ def load(self) -> List[ChatSession]:
+ """Eagerly load the chat sessions into memory."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/chat_loaders/facebook_messenger.py b/libs/community/langchain_community/chat_loaders/facebook_messenger.py
new file mode 100644
index 00000000000..f0aad601ecd
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/facebook_messenger.py
@@ -0,0 +1,79 @@
+import json
+import logging
+from pathlib import Path
+from typing import Iterator, Union
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+logger = logging.getLogger(__file__)
+
+
+class SingleFileFacebookMessengerChatLoader(BaseChatLoader):
+ """Load `Facebook Messenger` chat data from a single file.
+
+ Args:
+ path (Union[Path, str]): The path to the chat file.
+
+ Attributes:
+        file_path (Path): The path to the chat file.
+
+ """
+
+ def __init__(self, path: Union[Path, str]) -> None:
+ super().__init__()
+ self.file_path = path if isinstance(path, Path) else Path(path)
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """Lazy loads the chat data from the file.
+
+ Yields:
+ ChatSession: A chat session containing the loaded messages.
+
+ """
+ with open(self.file_path) as f:
+ data = json.load(f)
+ sorted_data = sorted(data["messages"], key=lambda x: x["timestamp_ms"])
+ messages = []
+ for m in sorted_data:
+ messages.append(
+ HumanMessage(
+ content=m["content"], additional_kwargs={"sender": m["sender_name"]}
+ )
+ )
+ yield ChatSession(messages=messages)
+
+
+class FolderFacebookMessengerChatLoader(BaseChatLoader):
+ """Load `Facebook Messenger` chat data from a folder.
+
+ Args:
+ path (Union[str, Path]): The path to the directory
+ containing the chat files.
+
+ Attributes:
+        directory_path (Path): The path to the directory containing the chat files.
+
+ """
+
+ def __init__(self, path: Union[str, Path]) -> None:
+ super().__init__()
+ self.directory_path = Path(path) if isinstance(path, str) else path
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """Lazy loads the chat data from the folder.
+
+ Yields:
+ ChatSession: A chat session containing the loaded messages.
+
+ """
+ inbox_path = self.directory_path / "inbox"
+ for _dir in inbox_path.iterdir():
+ if _dir.is_dir():
+ for _file in _dir.iterdir():
+ if _file.suffix.lower() == ".json":
+ file_loader = SingleFileFacebookMessengerChatLoader(path=_file)
+ for result in file_loader.lazy_load():
+ yield result
diff --git a/libs/community/langchain_community/chat_loaders/gmail.py b/libs/community/langchain_community/chat_loaders/gmail.py
new file mode 100644
index 00000000000..7f120347f5b
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/gmail.py
@@ -0,0 +1,112 @@
+import base64
+import re
+from typing import Any, Iterator
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+
+def _extract_email_content(msg: Any) -> HumanMessage:
+ from_email = None
+ for values in msg["payload"]["headers"]:
+ name = values["name"]
+ if name == "From":
+ from_email = values["value"]
+ if from_email is None:
+        raise ValueError("Could not find the 'From' header in the email message.")
+ for part in msg["payload"]["parts"]:
+ if part["mimeType"] == "text/plain":
+ data = part["body"]["data"]
+ data = base64.urlsafe_b64decode(data).decode("utf-8")
+ # Regular expression to split the email body at the first
+ # occurrence of a line that starts with "On ... wrote:"
+ pattern = re.compile(r"\r\nOn .+(\r\n)*wrote:\r\n")
+ # Split the email body and extract the first part
+ newest_response = re.split(pattern, data)[0]
+ message = HumanMessage(
+ content=newest_response, additional_kwargs={"sender": from_email}
+ )
+ return message
+    raise ValueError("Could not find a 'text/plain' part in the email message.")
+
+
+def _get_message_data(service: Any, message: Any) -> ChatSession:
+ msg = service.users().messages().get(userId="me", id=message["id"]).execute()
+ message_content = _extract_email_content(msg)
+ in_reply_to = None
+ email_data = msg["payload"]["headers"]
+ for values in email_data:
+ name = values["name"]
+ if name == "In-Reply-To":
+ in_reply_to = values["value"]
+ if in_reply_to is None:
+        raise ValueError("Email is not a reply (no 'In-Reply-To' header found).")
+
+ thread_id = msg["threadId"]
+
+ thread = service.users().threads().get(userId="me", id=thread_id).execute()
+ messages = thread["messages"]
+
+ response_email = None
+ for message in messages:
+ email_data = message["payload"]["headers"]
+ for values in email_data:
+            if values["name"] == "Message-ID":
+                message_id = values["value"]
+                if message_id == in_reply_to:
+                    response_email = message
+ if response_email is None:
+        raise ValueError("Could not find the email that this message replies to.")
+ starter_content = _extract_email_content(response_email)
+ return ChatSession(messages=[starter_content, message_content])
+
+
+class GMailLoader(BaseChatLoader):
+ """Load data from `GMail`.
+
+ There are many ways you could want to load data from GMail.
+ This loader is currently fairly opinionated in how to do so.
+    It first looks for all messages that you have sent, then finds the ones
+    in which you are responding to a previous email. It fetches that previous
+    email and creates a training example consisting of that email followed by
+    your reply.
+
+    Note that there are clear limitations here. For example, each training
+    example only uses the immediately preceding email for context.
+
+ To use:
+
+ - Set up a Google Developer Account:
+ Go to the Google Developer Console, create a project,
+ and enable the Gmail API for that project.
+ This will give you a credentials.json file that you'll need later.
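+
+    Example (a minimal sketch; obtaining ``creds`` from your credentials.json
+    via Google's OAuth flow is outside the scope of this snippet):
+
+    .. code-block:: python
+
+        loader = GMailLoader(creds=creds, n=100, raise_error=False)
+        chat_sessions = loader.load()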
+ """
+
+ def __init__(self, creds: Any, n: int = 100, raise_error: bool = False) -> None:
+ super().__init__()
+ self.creds = creds
+ self.n = n
+ self.raise_error = raise_error
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ from googleapiclient.discovery import build
+
+ service = build("gmail", "v1", credentials=self.creds)
+ results = (
+ service.users()
+ .messages()
+ .list(userId="me", labelIds=["SENT"], maxResults=self.n)
+ .execute()
+ )
+ messages = results.get("messages", [])
+ for message in messages:
+ try:
+ yield _get_message_data(service, message)
+            except Exception:
+                # TODO: handle errors better
+                if self.raise_error:
+                    raise
diff --git a/libs/community/langchain_community/chat_loaders/imessage.py b/libs/community/langchain_community/chat_loaders/imessage.py
new file mode 100644
index 00000000000..2a2ac827e77
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/imessage.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+if TYPE_CHECKING:
+ import sqlite3
+
+
+class IMessageChatLoader(BaseChatLoader):
+ """Load chat sessions from the `iMessage` chat.db SQLite file.
+
+ It only works on macOS when you have iMessage enabled and have the chat.db file.
+
+ The chat.db file is likely located at ~/Library/Messages/chat.db. However, your
+ terminal may not have permission to access this file. To resolve this, you can
+ copy the file to a different location, change the permissions of the file, or
+ grant full disk access for your terminal emulator
+ in System Settings > Security and Privacy > Full Disk Access.
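+
+    Example (a minimal sketch, assuming a readable copy of chat.db in the
+    current working directory):
+
+    .. code-block:: python
+
+        loader = IMessageChatLoader(path="chat.db")
+        chat_sessions = loader.load()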
+ """
+
+ def __init__(self, path: Optional[Union[str, Path]] = None):
+ """
+ Initialize the IMessageChatLoader.
+
+ Args:
+ path (str or Path, optional): Path to the chat.db SQLite file.
+ Defaults to None, in which case the default path
+ ~/Library/Messages/chat.db will be used.
+ """
+ if path is None:
+ path = Path.home() / "Library" / "Messages" / "chat.db"
+ self.db_path = path if isinstance(path, Path) else Path(path)
+ if not self.db_path.exists():
+ raise FileNotFoundError(f"File {self.db_path} not found")
+ try:
+ import sqlite3 # noqa: F401
+ except ImportError as e:
+ raise ImportError(
+ "The sqlite3 module is required to load iMessage chats.\n"
+ "Please install it with `pip install pysqlite3`"
+ ) from e
+
+ def _parse_attributedBody(self, attributedBody: bytes) -> str:
+ """
+ Parse the attributedBody field of the message table
+ for the text content of the message.
+
+ The attributedBody field is a binary blob that contains
+ the message content after the byte string b"NSString":
+
+ 5 bytes 1-3 bytes `len` bytes
+ ... | b"NSString" | preamble | `len` | contents | ...
+
+ The 5 preamble bytes are always b"\x01\x94\x84\x01+"
+
+ The size of `len` is either 1 byte or 3 bytes:
+ - If the first byte in `len` is b"\x81" then `len` is 3 bytes long.
+ So the message length is the 2 bytes after, in little Endian.
+ - Otherwise, the size of `len` is 1 byte, and the message length is
+ that byte.
+
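+        For example (a hypothetical blob following the layout above, where
+        ``loader`` is an ``IMessageChatLoader`` instance):
+
+        .. code-block:: python
+
+            blob = b"...NSString" + b"\x01\x94\x84\x01+" + b"\x05" + b"Hello"
+            # split on b"NSString", drop the 5 preamble bytes -> b"\x05Hello";
+            # the first byte (5) is not 0x81, so the next 5 bytes are the text
+            assert loader._parse_attributedBody(blob) == "Hello"
+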
+ Args:
+ attributedBody (bytes): attributedBody field of the message table.
+ Return:
+ str: Text content of the message.
+ """
+ content = attributedBody.split(b"NSString")[1][5:]
+ length, start = content[0], 1
+ if content[0] == 129:
+ length, start = int.from_bytes(content[1:3], "little"), 3
+ return content[start : start + length].decode("utf-8", errors="ignore")
+
+ def _load_single_chat_session(
+ self, cursor: "sqlite3.Cursor", chat_id: int
+ ) -> ChatSession:
+ """
+ Load a single chat session from the iMessage chat.db.
+
+ Args:
+ cursor: SQLite cursor object.
+ chat_id (int): ID of the chat session to load.
+
+ Returns:
+ ChatSession: Loaded chat session.
+ """
+ results: List[HumanMessage] = []
+
+ query = """
+ SELECT message.date, handle.id, message.text, message.attributedBody
+ FROM message
+ JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
+ JOIN handle ON message.handle_id = handle.ROWID
+ WHERE chat_message_join.chat_id = ?
+ ORDER BY message.date ASC;
+ """
+ cursor.execute(query, (chat_id,))
+ messages = cursor.fetchall()
+
+ for date, sender, text, attributedBody in messages:
+ if text:
+ content = text
+ elif attributedBody:
+ content = self._parse_attributedBody(attributedBody)
+ else: # Skip messages with no content
+ continue
+
+ results.append(
+ HumanMessage(
+ role=sender,
+ content=content,
+ additional_kwargs={
+ "message_time": date,
+ "sender": sender,
+ },
+ )
+ )
+
+ return ChatSession(messages=results)
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """
+ Lazy load the chat sessions from the iMessage chat.db
+ and yield them in the required format.
+
+ Yields:
+ ChatSession: Loaded chat session.
+ """
+ import sqlite3
+
+ try:
+ conn = sqlite3.connect(self.db_path)
+ except sqlite3.OperationalError as e:
+ raise ValueError(
+ f"Could not open iMessage DB file {self.db_path}.\n"
+ "Make sure your terminal emulator has disk access to this file.\n"
+ " You can either copy the DB file to an accessible location"
+ " or grant full disk access for your terminal emulator."
+ " You can grant full disk access for your terminal emulator"
+ " in System Settings > Security and Privacy > Full Disk Access."
+ ) from e
+ cursor = conn.cursor()
+
+ # Fetch the list of chat IDs sorted by time (most recent first)
+ query = """SELECT chat_id
+ FROM message
+ JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
+ GROUP BY chat_id
+ ORDER BY MAX(date) DESC;"""
+ cursor.execute(query)
+ chat_ids = [row[0] for row in cursor.fetchall()]
+
+ for chat_id in chat_ids:
+ yield self._load_single_chat_session(cursor, chat_id)
+
+ conn.close()
diff --git a/libs/community/langchain_community/chat_loaders/langsmith.py b/libs/community/langchain_community/chat_loaders/langsmith.py
new file mode 100644
index 00000000000..e808df4fbfb
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/langsmith.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.load import load
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+if TYPE_CHECKING:
+ from langsmith.client import Client
+ from langsmith.schemas import Run
+
+logger = logging.getLogger(__name__)
+
+
+class LangSmithRunChatLoader(BaseChatLoader):
+ """
+ Load chat sessions from a list of LangSmith "llm" runs.
+
+ Attributes:
+ runs (Iterable[Union[str, Run]]): The list of LLM run IDs or run objects.
+ client (Client): Instance of LangSmith client for fetching data.
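+
+    Example (a minimal sketch; the run id is a placeholder and LangSmith
+    credentials are assumed to be configured in the environment):
+
+    .. code-block:: python
+
+        from langsmith.client import Client
+
+        loader = LangSmithRunChatLoader(runs=["run-id-1"], client=Client())
+        chat_sessions = list(loader.lazy_load())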
+ """
+
+ def __init__(
+ self, runs: Iterable[Union[str, Run]], client: Optional["Client"] = None
+ ):
+ """
+ Initialize a new LangSmithRunChatLoader instance.
+
+ :param runs: List of LLM run IDs or run objects.
+ :param client: An instance of LangSmith client, if not provided,
+ a new client instance will be created.
+ """
+ from langsmith.client import Client
+
+ self.runs = runs
+ self.client = client or Client()
+
+ def _load_single_chat_session(self, llm_run: "Run") -> ChatSession:
+ """
+ Convert an individual LangSmith LLM run to a ChatSession.
+
+ :param llm_run: The LLM run object.
+ :return: A chat session representing the run's data.
+ """
+ chat_session = LangSmithRunChatLoader._get_messages_from_llm_run(llm_run)
+ functions = LangSmithRunChatLoader._get_functions_from_llm_run(llm_run)
+ if functions:
+ chat_session["functions"] = functions
+ return chat_session
+
+ @staticmethod
+ def _get_messages_from_llm_run(llm_run: "Run") -> ChatSession:
+ """
+ Extract messages from a LangSmith LLM run.
+
+ :param llm_run: The LLM run object.
+ :return: ChatSession with the extracted messages.
+ """
+ if llm_run.run_type != "llm":
+ raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
+ if "messages" not in llm_run.inputs:
+ raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}")
+ if not llm_run.outputs:
+ raise ValueError("Cannot convert pending run")
+ messages = load(llm_run.inputs)["messages"]
+ message_chunk = load(llm_run.outputs)["generations"][0]["message"]
+ return ChatSession(messages=messages + [message_chunk])
+
+ @staticmethod
+ def _get_functions_from_llm_run(llm_run: "Run") -> Optional[List[Dict]]:
+ """
+ Extract functions from a LangSmith LLM run if they exist.
+
+ :param llm_run: The LLM run object.
+ :return: Functions from the run or None.
+ """
+ if llm_run.run_type != "llm":
+ raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
+ return (llm_run.extra or {}).get("invocation_params", {}).get("functions")
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """
+ Lazy load the chat sessions from the iterable of run IDs.
+
+ This method fetches the runs and converts them to chat sessions on-the-fly,
+ yielding one session at a time.
+
+ :return: Iterator of chat sessions containing messages.
+ """
+ from langsmith.schemas import Run
+
+ for run_obj in self.runs:
+ try:
+ if hasattr(run_obj, "id"):
+ run = run_obj
+ else:
+ run = self.client.read_run(run_obj)
+ session = self._load_single_chat_session(cast(Run, run))
+ yield session
+ except ValueError as e:
+ logger.warning(f"Could not load run {run_obj}: {repr(e)}")
+ continue
+
+
+class LangSmithDatasetChatLoader(BaseChatLoader):
+ """
+ Load chat sessions from a LangSmith dataset with the "chat" data type.
+
+ Attributes:
+ dataset_name (str): The name of the LangSmith dataset.
+ client (Client): Instance of LangSmith client for fetching data.
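+
+    Example (a minimal sketch; the dataset name is a placeholder and LangSmith
+    credentials are assumed to be configured in the environment):
+
+    .. code-block:: python
+
+        loader = LangSmithDatasetChatLoader(dataset_name="my-chat-dataset")
+        chat_sessions = list(loader.lazy_load())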
+ """
+
+ def __init__(self, *, dataset_name: str, client: Optional["Client"] = None):
+ """
+        Initialize a new LangSmithDatasetChatLoader instance.
+
+ :param dataset_name: The name of the LangSmith dataset.
+ :param client: An instance of LangSmith client; if not provided,
+ a new client instance will be created.
+ """
+ try:
+ from langsmith.client import Client
+ except ImportError as e:
+ raise ImportError(
+ "The LangSmith client is required to load LangSmith datasets.\n"
+ "Please install it with `pip install langsmith`"
+ ) from e
+
+ self.dataset_name = dataset_name
+ self.client = client or Client()
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """
+ Lazy load the chat sessions from the specified LangSmith dataset.
+
+ This method fetches the chat data from the dataset and
+ converts each data point to chat sessions on-the-fly,
+ yielding one session at a time.
+
+ :return: Iterator of chat sessions containing messages.
+ """
+ from langchain_community.adapters import openai as oai_adapter # noqa: E402
+
+ data = self.client.read_dataset_openai_finetuning(
+ dataset_name=self.dataset_name
+ )
+ for data_point in data:
+ yield ChatSession(
+ messages=[
+ oai_adapter.convert_dict_to_message(m)
+ for m in data_point.get("messages", [])
+ ],
+ functions=data_point.get("functions"),
+ )
diff --git a/libs/community/langchain_community/chat_loaders/slack.py b/libs/community/langchain_community/chat_loaders/slack.py
new file mode 100644
index 00000000000..8b2603829f1
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/slack.py
@@ -0,0 +1,86 @@
+import json
+import logging
+import re
+import zipfile
+from pathlib import Path
+from typing import Dict, Iterator, List, Union
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+logger = logging.getLogger(__name__)
+
+
+class SlackChatLoader(BaseChatLoader):
+ """Load `Slack` conversations from a dump zip file."""
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ ):
+ """
+ Initialize the chat loader with the path to the exported Slack dump zip file.
+
+ :param path: Path to the exported Slack dump zip file.
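+
+        Example (a minimal sketch, assuming an export named ``slack_dump.zip``):
+
+        .. code-block:: python
+
+            loader = SlackChatLoader(path="slack_dump.zip")
+            chat_sessions = loader.load()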
+ """
+ self.zip_path = path if isinstance(path, Path) else Path(path)
+ if not self.zip_path.exists():
+ raise FileNotFoundError(f"File {self.zip_path} not found")
+
+ def _load_single_chat_session(self, messages: List[Dict]) -> ChatSession:
+ results: List[Union[AIMessage, HumanMessage]] = []
+ previous_sender = None
+ for message in messages:
+ if not isinstance(message, dict):
+ continue
+ text = message.get("text", "")
+ timestamp = message.get("ts", "")
+ sender = message.get("user", "")
+ if not sender:
+ continue
+ skip_pattern = re.compile(
+ r"<@U\d+> has joined the channel", flags=re.IGNORECASE
+ )
+ if skip_pattern.match(text):
+ continue
+ if sender == previous_sender:
+ results[-1].content += "\n\n" + text
+ results[-1].additional_kwargs["events"].append(
+ {"message_time": timestamp}
+ )
+ else:
+ results.append(
+ HumanMessage(
+ role=sender,
+ content=text,
+ additional_kwargs={
+ "sender": sender,
+ "events": [{"message_time": timestamp}],
+ },
+ )
+ )
+ previous_sender = sender
+ return ChatSession(messages=results)
+
+ def _read_json(self, zip_file: zipfile.ZipFile, file_path: str) -> List[dict]:
+ """Read JSON data from a zip subfile."""
+ with zip_file.open(file_path, "r") as f:
+ data = json.load(f)
+ if not isinstance(data, list):
+ raise ValueError(f"Expected list of dictionaries, got {type(data)}")
+ return data
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """
+ Lazy load the chat sessions from the Slack dump file and yield them
+ in the required format.
+
+ :return: Iterator of chat sessions containing messages.
+ """
+ with zipfile.ZipFile(str(self.zip_path), "r") as zip_file:
+ for file_path in zip_file.namelist():
+ if file_path.endswith(".json"):
+ messages = self._read_json(zip_file, file_path)
+ yield self._load_single_chat_session(messages)
diff --git a/libs/community/langchain_community/chat_loaders/telegram.py b/libs/community/langchain_community/chat_loaders/telegram.py
new file mode 100644
index 00000000000..ff2056162b5
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/telegram.py
@@ -0,0 +1,151 @@
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from pathlib import Path
+from typing import Iterator, List, Union
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+logger = logging.getLogger(__name__)
+
+
+class TelegramChatLoader(BaseChatLoader):
+    """Load `Telegram` conversations to LangChain chat messages.
+
+ To export, use the Telegram Desktop app from
+ https://desktop.telegram.org/, select a conversation, click the three dots
+ in the top right corner, and select "Export chat history". Then select
+ "Machine-readable JSON" (preferred) to export. Note: the 'lite' versions of
+ the desktop app (like "Telegram for MacOS") do not support exporting chat
+ history.
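+
+    Example (a minimal sketch, assuming an export named ``telegram_chat.json``):
+
+    .. code-block:: python
+
+        loader = TelegramChatLoader(path="telegram_chat.json")
+        chat_sessions = loader.load()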
+ """
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ ):
+ """Initialize the TelegramChatLoader.
+
+ Args:
+ path (Union[str, Path]): Path to the exported Telegram chat zip,
+ directory, json, or HTML file.
+ """
+ self.path = path if isinstance(path, str) else str(path)
+
+ def _load_single_chat_session_html(self, file_path: str) -> ChatSession:
+ """Load a single chat session from an HTML file.
+
+ Args:
+ file_path (str): Path to the HTML file.
+
+ Returns:
+ ChatSession: The loaded chat session.
+ """
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ raise ImportError(
+ "Please install the 'beautifulsoup4' package to load"
+ " Telegram HTML files. You can do this by running"
+                " 'pip install beautifulsoup4' in your terminal."
+ )
+ with open(file_path, "r", encoding="utf-8") as file:
+ soup = BeautifulSoup(file, "html.parser")
+
+ results: List[Union[HumanMessage, AIMessage]] = []
+ previous_sender = None
+ for message in soup.select(".message.default"):
+ timestamp = message.select_one(".pull_right.date.details")["title"]
+ from_name_element = message.select_one(".from_name")
+ if from_name_element is None and previous_sender is None:
+ logger.debug("from_name not found in message")
+ continue
+ elif from_name_element is None:
+ from_name = previous_sender
+ else:
+ from_name = from_name_element.text.strip()
+ text = message.select_one(".text").text.strip()
+ results.append(
+ HumanMessage(
+ content=text,
+ additional_kwargs={
+ "sender": from_name,
+ "events": [{"message_time": timestamp}],
+ },
+ )
+ )
+ previous_sender = from_name
+
+ return ChatSession(messages=results)
+
+ def _load_single_chat_session_json(self, file_path: str) -> ChatSession:
+ """Load a single chat session from a JSON file.
+
+ Args:
+ file_path (str): Path to the JSON file.
+
+ Returns:
+ ChatSession: The loaded chat session.
+ """
+ with open(file_path, "r", encoding="utf-8") as file:
+ data = json.load(file)
+
+ messages = data.get("messages", [])
+ results: List[BaseMessage] = []
+ for message in messages:
+ text = message.get("text", "")
+ timestamp = message.get("date", "")
+ from_name = message.get("from", "")
+
+ results.append(
+ HumanMessage(
+ content=text,
+ additional_kwargs={
+ "sender": from_name,
+ "events": [{"message_time": timestamp}],
+ },
+ )
+ )
+
+ return ChatSession(messages=results)
+
+ def _iterate_files(self, path: str) -> Iterator[str]:
+ """Iterate over files in a directory or zip file.
+
+ Args:
+ path (str): Path to the directory or zip file.
+
+ Yields:
+ str: Path to each file.
+ """
+ if os.path.isfile(path) and path.endswith((".html", ".json")):
+ yield path
+ elif os.path.isdir(path):
+ for root, _, files in os.walk(path):
+ for file in files:
+ if file.endswith((".html", ".json")):
+ yield os.path.join(root, file)
+ elif zipfile.is_zipfile(path):
+ with zipfile.ZipFile(path) as zip_file:
+ for file in zip_file.namelist():
+ if file.endswith((".html", ".json")):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ yield zip_file.extract(file, path=temp_dir)
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """Lazy load the messages from the chat file and yield them
+ in as chat sessions.
+        as chat sessions.
+ Yields:
+ ChatSession: The loaded chat session.
+ """
+ for file_path in self._iterate_files(self.path):
+ if file_path.endswith(".html"):
+ yield self._load_single_chat_session_html(file_path)
+ elif file_path.endswith(".json"):
+ yield self._load_single_chat_session_json(file_path)
diff --git a/libs/community/langchain_community/chat_loaders/utils.py b/libs/community/langchain_community/chat_loaders/utils.py
new file mode 100644
index 00000000000..3fe9384e3d2
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/utils.py
@@ -0,0 +1,95 @@
+"""Utilities for chat loaders."""
+from copy import deepcopy
+from typing import Iterable, Iterator, List
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import AIMessage, BaseMessage
+
+
+def merge_chat_runs_in_session(
+ chat_session: ChatSession, delimiter: str = "\n\n"
+) -> ChatSession:
+ """Merge chat runs together in a chat session.
+
+ A chat run is a sequence of messages from the same sender.
+
+ Args:
+        chat_session: A chat session.
+        delimiter: The delimiter to insert between messages merged from the
+            same sender. Defaults to a blank line.
+
+ Returns:
+ A chat session with merged chat runs.
+ """
+ messages: List[BaseMessage] = []
+ for message in chat_session["messages"]:
+ if not isinstance(message.content, str):
+ raise ValueError(
+ "Chat Loaders only support messages with content type string, "
+ f"got {message.content}"
+ )
+ if not messages:
+ messages.append(deepcopy(message))
+ elif (
+ isinstance(message, type(messages[-1]))
+ and messages[-1].additional_kwargs.get("sender") is not None
+ and messages[-1].additional_kwargs["sender"]
+ == message.additional_kwargs.get("sender")
+ ):
+ if not isinstance(messages[-1].content, str):
+ raise ValueError(
+ "Chat Loaders only support messages with content type string, "
+ f"got {messages[-1].content}"
+ )
+ messages[-1].content = (
+ messages[-1].content + delimiter + message.content
+ ).strip()
+ messages[-1].additional_kwargs.get("events", []).extend(
+ message.additional_kwargs.get("events") or []
+ )
+ else:
+ messages.append(deepcopy(message))
+ return ChatSession(messages=messages)
+
+
+def merge_chat_runs(chat_sessions: Iterable[ChatSession]) -> Iterator[ChatSession]:
+ """Merge chat runs together.
+
+ A chat run is a sequence of messages from the same sender.
+
+ Args:
+ chat_sessions: A list of chat sessions.
+
+ Returns:
+ A list of chat sessions with merged chat runs.
+ """
+ for chat_session in chat_sessions:
+ yield merge_chat_runs_in_session(chat_session)
+
+
+def map_ai_messages_in_session(chat_sessions: ChatSession, sender: str) -> ChatSession:
+ """Convert messages from the specified 'sender' to AI messages.
+
+ This is useful for fine-tuning the AI to adapt to your voice.
+ """
+ messages = []
+ num_converted = 0
+ for message in chat_sessions["messages"]:
+ if message.additional_kwargs.get("sender") == sender:
+ message = AIMessage(
+ content=message.content,
+ additional_kwargs=message.additional_kwargs.copy(),
+ example=getattr(message, "example", None),
+ )
+ num_converted += 1
+ messages.append(message)
+ return ChatSession(messages=messages)
+
+
+def map_ai_messages(
+ chat_sessions: Iterable[ChatSession], sender: str
+) -> Iterator[ChatSession]:
+ """Convert messages from the specified 'sender' to AI messages.
+
+ This is useful for fine-tuning the AI to adapt to your voice.
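+
+    Example (a minimal sketch; the chat export path and sender name are
+    placeholders):
+
+    .. code-block:: python
+
+        from langchain_community.chat_loaders.whatsapp import WhatsAppChatLoader
+
+        loader = WhatsAppChatLoader(path="whatsapp_chat.txt")
+        merged = merge_chat_runs(loader.lazy_load())
+        sessions = list(map_ai_messages(merged, sender="Jane Doe"))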
+ """
+ for chat_session in chat_sessions:
+ yield map_ai_messages_in_session(chat_session, sender)
diff --git a/libs/community/langchain_community/chat_loaders/whatsapp.py b/libs/community/langchain_community/chat_loaders/whatsapp.py
new file mode 100644
index 00000000000..25c155c1aee
--- /dev/null
+++ b/libs/community/langchain_community/chat_loaders/whatsapp.py
@@ -0,0 +1,119 @@
+import logging
+import os
+import re
+import zipfile
+from typing import Iterator, List, Union
+
+from langchain_core.chat_sessions import ChatSession
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_community.chat_loaders.base import BaseChatLoader
+
+logger = logging.getLogger(__name__)
+
+
+class WhatsAppChatLoader(BaseChatLoader):
+ """Load `WhatsApp` conversations from a dump zip file or directory."""
+
+ def __init__(self, path: str):
+ """Initialize the WhatsAppChatLoader.
+
+ Args:
+            path (str): Path to the exported WhatsApp chat
+                zip file, directory, or individual .txt file.
+
+ To generate the dump, open the chat, click the three dots in the top
+ right corner, and select "More". Then select "Export chat" and
+ choose "Without media".
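+
+        Example (a minimal sketch, assuming an export named ``whatsapp_chat.txt``):
+
+        .. code-block:: python
+
+            loader = WhatsAppChatLoader(path="whatsapp_chat.txt")
+            chat_sessions = loader.load()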
+ """
+ self.path = path
+ ignore_lines = [
+ "This message was deleted",
+ "",
+ "image omitted",
+ "Messages and calls are end-to-end encrypted. No one outside of this chat,"
+ " not even WhatsApp, can read or listen to them.",
+ ]
+ self._ignore_lines = re.compile(
+ r"(" + "|".join([r"\u200E*" + line for line in ignore_lines]) + r")",
+ flags=re.IGNORECASE,
+ )
+ self._message_line_regex = re.compile(
+ r"\u200E*\[?(\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{2}:\d{2} (?:AM|PM))\]?[ \u200E]*([^:]+): (.+)", # noqa
+ flags=re.IGNORECASE,
+ )
+
+ def _load_single_chat_session(self, file_path: str) -> ChatSession:
+ """Load a single chat session from a file.
+
+ Args:
+ file_path (str): Path to the chat file.
+
+ Returns:
+ ChatSession: The loaded chat session.
+ """
+ with open(file_path, "r", encoding="utf-8") as file:
+ txt = file.read()
+
+ # Split messages by newlines, but keep multi-line messages grouped
+ chat_lines: List[str] = []
+ current_message = ""
+ for line in txt.split("\n"):
+ if self._message_line_regex.match(line):
+ if current_message:
+ chat_lines.append(current_message)
+ current_message = line
+ else:
+ current_message += " " + line.strip()
+ if current_message:
+ chat_lines.append(current_message)
+ results: List[Union[HumanMessage, AIMessage]] = []
+ for line in chat_lines:
+ result = self._message_line_regex.match(line.strip())
+ if result:
+ timestamp, sender, text = result.groups()
+ if not self._ignore_lines.match(text.strip()):
+ results.append(
+ HumanMessage(
+ role=sender,
+ content=text,
+ additional_kwargs={
+ "sender": sender,
+ "events": [{"message_time": timestamp}],
+ },
+ )
+ )
+ else:
+ logger.debug(f"Could not parse line: {line}")
+ return ChatSession(messages=results)
+
+ def _iterate_files(self, path: str) -> Iterator[str]:
+ """Iterate over the files in a directory or zip file.
+
+ Args:
+ path (str): Path to the directory or zip file.
+
+ Yields:
+ str: The path to each file.
+ """
+ if os.path.isfile(path):
+ yield path
+ elif os.path.isdir(path):
+ for root, _, files in os.walk(path):
+ for file in files:
+ if file.endswith(".txt"):
+ yield os.path.join(root, file)
+ elif zipfile.is_zipfile(path):
+ with zipfile.ZipFile(path) as zip_file:
+ for file in zip_file.namelist():
+ if file.endswith(".txt"):
+ yield zip_file.extract(file)
+
+ def lazy_load(self) -> Iterator[ChatSession]:
+ """Lazy load the messages from the chat file and yield
+ them as chat sessions.
+
+ Yields:
+ Iterator[ChatSession]: The loaded chat sessions.
+ """
+        for file_path in self._iterate_files(self.path):
+            yield self._load_single_chat_session(file_path)
diff --git a/libs/community/langchain_community/chat_message_histories/__init__.py b/libs/community/langchain_community/chat_message_histories/__init__.py
new file mode 100644
index 00000000000..a45ecb7ead6
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/__init__.py
@@ -0,0 +1,65 @@
+from langchain_community.chat_message_histories.astradb import (
+ AstraDBChatMessageHistory,
+)
+from langchain_community.chat_message_histories.cassandra import (
+ CassandraChatMessageHistory,
+)
+from langchain_community.chat_message_histories.cosmos_db import (
+ CosmosDBChatMessageHistory,
+)
+from langchain_community.chat_message_histories.dynamodb import (
+ DynamoDBChatMessageHistory,
+)
+from langchain_community.chat_message_histories.elasticsearch import (
+ ElasticsearchChatMessageHistory,
+)
+from langchain_community.chat_message_histories.file import FileChatMessageHistory
+from langchain_community.chat_message_histories.firestore import (
+ FirestoreChatMessageHistory,
+)
+from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
+from langchain_community.chat_message_histories.momento import MomentoChatMessageHistory
+from langchain_community.chat_message_histories.mongodb import MongoDBChatMessageHistory
+from langchain_community.chat_message_histories.neo4j import Neo4jChatMessageHistory
+from langchain_community.chat_message_histories.postgres import (
+ PostgresChatMessageHistory,
+)
+from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
+from langchain_community.chat_message_histories.rocksetdb import (
+ RocksetChatMessageHistory,
+)
+from langchain_community.chat_message_histories.singlestoredb import (
+ SingleStoreDBChatMessageHistory,
+)
+from langchain_community.chat_message_histories.sql import SQLChatMessageHistory
+from langchain_community.chat_message_histories.streamlit import (
+ StreamlitChatMessageHistory,
+)
+from langchain_community.chat_message_histories.upstash_redis import (
+ UpstashRedisChatMessageHistory,
+)
+from langchain_community.chat_message_histories.xata import XataChatMessageHistory
+from langchain_community.chat_message_histories.zep import ZepChatMessageHistory
+
+__all__ = [
+ "AstraDBChatMessageHistory",
+ "ChatMessageHistory",
+ "CassandraChatMessageHistory",
+ "CosmosDBChatMessageHistory",
+ "DynamoDBChatMessageHistory",
+ "ElasticsearchChatMessageHistory",
+ "FileChatMessageHistory",
+ "FirestoreChatMessageHistory",
+ "MomentoChatMessageHistory",
+ "MongoDBChatMessageHistory",
+ "PostgresChatMessageHistory",
+ "RedisChatMessageHistory",
+ "RocksetChatMessageHistory",
+ "SQLChatMessageHistory",
+ "StreamlitChatMessageHistory",
+ "SingleStoreDBChatMessageHistory",
+ "XataChatMessageHistory",
+ "ZepChatMessageHistory",
+ "UpstashRedisChatMessageHistory",
+ "Neo4jChatMessageHistory",
+]
diff --git a/libs/community/langchain_community/chat_message_histories/astradb.py b/libs/community/langchain_community/chat_message_histories/astradb.py
new file mode 100644
index 00000000000..27e4dc5c936
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/astradb.py
@@ -0,0 +1,114 @@
+"""Astra DB - based chat message history, based on astrapy."""
+from __future__ import annotations
+
+import json
+import time
+import typing
+from typing import List, Optional
+
+if typing.TYPE_CHECKING:
+ from astrapy.db import AstraDB as LibAstraDB
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+DEFAULT_COLLECTION_NAME = "langchain_message_store"
+
+
+class AstraDBChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that stores history in Astra DB.
+
+ Args (only keyword-arguments accepted):
+ session_id: arbitrary key that is used to store the messages
+ of a single chat session.
+ collection_name (str): name of the Astra DB collection to create/use.
+ token (Optional[str]): API token for Astra DB usage.
+ api_endpoint (Optional[str]): full URL to the API endpoint,
+            such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
+ you can pass an already-created 'astrapy.db.AstraDB' instance.
+ namespace (Optional[str]): namespace (aka keyspace) where the
+ collection is created. Defaults to the database's "default namespace".
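+
+    Example (a minimal sketch; the token and endpoint values are placeholders):
+
+    .. code-block:: python
+
+        history = AstraDBChatMessageHistory(
+            session_id="session-1",
+            token="AstraCS:...",  # placeholder application token
+            api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
+        )
+        history.add_message(message)  # any BaseMessage
+        history.messages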
+ """
+
+ def __init__(
+ self,
+ *,
+ session_id: str,
+ collection_name: str = DEFAULT_COLLECTION_NAME,
+ token: Optional[str] = None,
+ api_endpoint: Optional[str] = None,
+ astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
+ namespace: Optional[str] = None,
+ ) -> None:
+ """Create an Astra DB chat message history."""
+ try:
+ from astrapy.db import AstraDB as LibAstraDB
+ except (ImportError, ModuleNotFoundError):
+ raise ImportError(
+ "Could not import a recent astrapy python package. "
+ "Please install it with `pip install --upgrade astrapy`."
+ )
+
+ # Conflicting-arg checks:
+ if astra_db_client is not None:
+ if token is not None or api_endpoint is not None:
+ raise ValueError(
+                    "You cannot pass 'astra_db_client' together with "
+                    "'token' or 'api_endpoint'."
+ )
+
+ self.session_id = session_id
+ self.collection_name = collection_name
+ self.token = token
+ self.api_endpoint = api_endpoint
+ self.namespace = namespace
+ if astra_db_client is not None:
+ self.astra_db = astra_db_client
+ else:
+ self.astra_db = LibAstraDB(
+ token=self.token,
+ api_endpoint=self.api_endpoint,
+ namespace=self.namespace,
+ )
+ self.collection = self.astra_db.create_collection(self.collection_name)
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve all session messages from DB"""
+ message_blobs = [
+ doc["body_blob"]
+ for doc in sorted(
+ self.collection.paginated_find(
+ filter={
+ "session_id": self.session_id,
+ },
+ projection={
+ "timestamp": 1,
+ "body_blob": 1,
+ },
+ ),
+ key=lambda _doc: _doc["timestamp"],
+ )
+ ]
+ items = [json.loads(message_blob) for message_blob in message_blobs]
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Write a message to the table"""
+ self.collection.insert_one(
+ {
+ "timestamp": time.time(),
+ "session_id": self.session_id,
+ "body_blob": json.dumps(message_to_dict(message)),
+ }
+ )
+
+ def clear(self) -> None:
+ """Clear session memory from DB"""
+ self.collection.delete_many(filter={"session_id": self.session_id})
diff --git a/libs/community/langchain_community/chat_message_histories/cassandra.py b/libs/community/langchain_community/chat_message_histories/cassandra.py
new file mode 100644
index 00000000000..bc3e0794652
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/cassandra.py
@@ -0,0 +1,72 @@
+"""Cassandra-based chat message history, based on cassIO."""
+from __future__ import annotations
+
+import json
+import typing
+from typing import List
+
+if typing.TYPE_CHECKING:
+ from cassandra.cluster import Session
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+DEFAULT_TABLE_NAME = "message_store"
+DEFAULT_TTL_SECONDS = None
+
+
+class CassandraChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that stores history in Cassandra.
+
+ Args:
+ session_id: arbitrary key that is used to store the messages
+ of a single chat session.
+ session: a Cassandra `Session` object (an open DB connection)
+ keyspace: name of the keyspace to use.
+ table_name: name of the table to use.
+ ttl_seconds: time-to-live (seconds) for automatic expiration
+ of stored entries. None (default) for no expiration.
+ """
+
+ def __init__(
+ self,
+ session_id: str,
+ session: Session,
+ keyspace: str,
+ table_name: str = DEFAULT_TABLE_NAME,
+ ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS,
+ ) -> None:
+ try:
+ from cassio.history import StoredBlobHistory
+ except (ImportError, ModuleNotFoundError):
+ raise ImportError(
+ "Could not import cassio python package. "
+ "Please install it with `pip install cassio`."
+ )
+ self.session_id = session_id
+ self.ttl_seconds = ttl_seconds
+ self.blob_history = StoredBlobHistory(session, keyspace, table_name)
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve all session messages from DB"""
+ message_blobs = self.blob_history.retrieve(
+ self.session_id,
+ )
+ items = [json.loads(message_blob) for message_blob in message_blobs]
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Write a message to the table"""
+ self.blob_history.store(
+ self.session_id, json.dumps(message_to_dict(message)), self.ttl_seconds
+ )
+
+ def clear(self) -> None:
+ """Clear session memory from DB"""
+ self.blob_history.clear_session_id(self.session_id)
diff --git a/libs/community/langchain_community/chat_message_histories/cosmos_db.py b/libs/community/langchain_community/chat_message_histories/cosmos_db.py
new file mode 100644
index 00000000000..4210d7a7076
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/cosmos_db.py
@@ -0,0 +1,172 @@
+"""Azure CosmosDB Memory History."""
+from __future__ import annotations
+
+import logging
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, List, Optional, Type
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ messages_from_dict,
+ messages_to_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from azure.cosmos import ContainerProxy
+
+
+class CosmosDBChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history backed by Azure CosmosDB."""
+
+ def __init__(
+ self,
+ cosmos_endpoint: str,
+ cosmos_database: str,
+ cosmos_container: str,
+ session_id: str,
+ user_id: str,
+ credential: Any = None,
+ connection_string: Optional[str] = None,
+ ttl: Optional[int] = None,
+ cosmos_client_kwargs: Optional[dict] = None,
+ ):
+ """
+ Initializes a new instance of the CosmosDBChatMessageHistory class.
+
+ Make sure to call prepare_cosmos or use the context manager to make
+ sure your database is ready.
+
+ Either a credential or a connection string must be provided.
+
+ :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
+ :param cosmos_database: The name of the database to use.
+ :param cosmos_container: The name of the container to use.
+ :param session_id: The session ID to use, can be overwritten while loading.
+ :param user_id: The user ID to use, can be overwritten while loading.
+ :param credential: The credential to use to authenticate to Azure Cosmos DB.
+ :param connection_string: The connection string to use to authenticate.
+ :param ttl: The time to live (in seconds) to use for documents in the container.
+ :param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient.
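+
+        Example (a minimal sketch; the endpoint, database, and container names
+        are placeholders, and ``azure-identity`` is assumed for the credential):
+
+        .. code-block:: python
+
+            from azure.identity import DefaultAzureCredential
+
+            history = CosmosDBChatMessageHistory(
+                cosmos_endpoint="https://<account>.documents.azure.com:443/",
+                cosmos_database="chat_db",
+                cosmos_container="chat_container",
+                session_id="session-1",
+                user_id="user-1",
+                credential=DefaultAzureCredential(),
+            )
+            with history:  # prepares the database and container
+                history.add_message(message)  # any BaseMessage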
+ """
+ self.cosmos_endpoint = cosmos_endpoint
+ self.cosmos_database = cosmos_database
+ self.cosmos_container = cosmos_container
+ self.credential = credential
+ self.conn_string = connection_string
+ self.session_id = session_id
+ self.user_id = user_id
+ self.ttl = ttl
+
+ self.messages: List[BaseMessage] = []
+ try:
+ from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
+ CosmosClient,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
+                " Please install it with `pip install azure-cosmos`."
+ ) from exc
+ if self.credential:
+ self._client = CosmosClient(
+ url=self.cosmos_endpoint,
+ credential=self.credential,
+ **cosmos_client_kwargs or {},
+ )
+ elif self.conn_string:
+ self._client = CosmosClient.from_connection_string(
+ conn_str=self.conn_string,
+ **cosmos_client_kwargs or {},
+ )
+ else:
+ raise ValueError("Either a connection string or a credential must be set.")
+ self._container: Optional[ContainerProxy] = None
+
+ def prepare_cosmos(self) -> None:
+ """Prepare the CosmosDB client.
+
+ Use this function or the context manager to make sure your database is ready.
+ """
+ try:
+ from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
+ PartitionKey,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
+                " Please install it with `pip install azure-cosmos`."
+ ) from exc
+ database = self._client.create_database_if_not_exists(self.cosmos_database)
+ self._container = database.create_container_if_not_exists(
+ self.cosmos_container,
+ partition_key=PartitionKey("/user_id"),
+ default_ttl=self.ttl,
+ )
+ self.load_messages()
+
+ def __enter__(self) -> "CosmosDBChatMessageHistory":
+ """Context manager entry point."""
+ self._client.__enter__()
+ self.prepare_cosmos()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
+ """Context manager exit"""
+ self.upsert_messages()
+ self._client.__exit__(exc_type, exc_val, traceback)
+
+ def load_messages(self) -> None:
+ """Retrieve the messages from Cosmos"""
+ if not self._container:
+ raise ValueError("Container not initialized")
+ try:
+ from azure.cosmos.exceptions import ( # pylint: disable=import-outside-toplevel # noqa: E501
+ CosmosHttpResponseError,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
+                " Please install it with `pip install azure-cosmos`."
+ ) from exc
+ try:
+ item = self._container.read_item(
+ item=self.session_id, partition_key=self.user_id
+ )
+ except CosmosHttpResponseError:
+ logger.info("no session found")
+ return
+ if "messages" in item and len(item["messages"]) > 0:
+ self.messages = messages_from_dict(item["messages"])
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Add a self-created message to the store"""
+ self.messages.append(message)
+ self.upsert_messages()
+
+ def upsert_messages(self) -> None:
+ """Update the cosmosdb item."""
+ if not self._container:
+ raise ValueError("Container not initialized")
+ self._container.upsert_item(
+ body={
+ "id": self.session_id,
+ "user_id": self.user_id,
+ "messages": messages_to_dict(self.messages),
+ }
+ )
+
+ def clear(self) -> None:
+ """Clear session memory from this memory and cosmos."""
+ self.messages = []
+ if self._container:
+ self._container.delete_item(
+ item=self.session_id, partition_key=self.user_id
+ )
diff --git a/libs/community/langchain_community/chat_message_histories/dynamodb.py b/libs/community/langchain_community/chat_message_histories/dynamodb.py
new file mode 100644
index 00000000000..a804e75018b
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/dynamodb.py
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+ messages_to_dict,
+)
+
+if TYPE_CHECKING:
+ from boto3.session import Session
+
+logger = logging.getLogger(__name__)
+
+
+class DynamoDBChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that stores history in AWS DynamoDB.
+
+ This class expects that a DynamoDB table exists with name `table_name`
+
+ Args:
+ table_name: name of the DynamoDB table
+ session_id: arbitrary key that is used to store the messages
+ of a single chat session.
+ endpoint_url: URL of the AWS endpoint to connect to. This argument
+ is optional and useful for test purposes, like using Localstack.
+ If you plan to use AWS cloud service, you normally don't have to
+ worry about setting the endpoint_url.
+ primary_key_name: name of the primary key of the DynamoDB table. This argument
+ is optional, defaulting to "SessionId".
+ key: an optional dictionary with a custom primary and secondary key.
+ This argument is optional, but useful when using composite dynamodb keys, or
+            isolating records based on application details such as a user id.
+ This may also contain global and local secondary index keys.
+ kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias for
+ client-side encryption
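+
+    Example (a minimal sketch, assuming an existing table named ``SessionTable``
+    with a ``SessionId`` partition key):
+
+    .. code-block:: python
+
+        from langchain_core.messages import HumanMessage
+
+        history = DynamoDBChatMessageHistory(
+            table_name="SessionTable", session_id="session-1"
+        )
+        history.add_message(HumanMessage(content="hello"))
+        history.messages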
+ """
+
+ def __init__(
+ self,
+ table_name: str,
+ session_id: str,
+ endpoint_url: Optional[str] = None,
+ primary_key_name: str = "SessionId",
+ key: Optional[Dict[str, str]] = None,
+ boto3_session: Optional[Session] = None,
+ kms_key_id: Optional[str] = None,
+ ):
+ if boto3_session:
+ client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
+ else:
+ try:
+ import boto3
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import boto3, please install with `pip install boto3`."
+ ) from e
+ if endpoint_url:
+ client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
+ else:
+ client = boto3.resource("dynamodb")
+ self.table = client.Table(table_name)
+ self.session_id = session_id
+ self.key: Dict = key or {primary_key_name: session_id}
+
+ if kms_key_id:
+ try:
+ from dynamodb_encryption_sdk.encrypted.table import EncryptedTable
+ from dynamodb_encryption_sdk.identifiers import CryptoAction
+ from dynamodb_encryption_sdk.material_providers.aws_kms import (
+ AwsKmsCryptographicMaterialsProvider,
+ )
+ from dynamodb_encryption_sdk.structures import AttributeActions
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import dynamodb_encryption_sdk, please install with "
+ "`pip install dynamodb-encryption-sdk`."
+ ) from e
+
+ actions = AttributeActions(
+ default_action=CryptoAction.DO_NOTHING,
+ attribute_actions={"History": CryptoAction.ENCRYPT_AND_SIGN},
+ )
+ aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id)
+ self.table = EncryptedTable(
+ table=self.table,
+ materials_provider=aws_kms_cmp,
+ attribute_actions=actions,
+ auto_refresh_table_indexes=False,
+ )
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from DynamoDB"""
+ try:
+ from botocore.exceptions import ClientError
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import botocore, please install with `pip install botocore`."
+ ) from e
+
+ response = None
+ try:
+ response = self.table.get_item(Key=self.key)
+ except ClientError as error:
+ if error.response["Error"]["Code"] == "ResourceNotFoundException":
+ logger.warning("No record found with session id: %s", self.session_id)
+ else:
+ logger.error(error)
+
+ if response and "Item" in response:
+ items = response["Item"]["History"]
+ else:
+ items = []
+
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in DynamoDB"""
+ try:
+ from botocore.exceptions import ClientError
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import botocore, please install with `pip install botocore`."
+ ) from e
+
+ messages = messages_to_dict(self.messages)
+ _message = message_to_dict(message)
+ messages.append(_message)
+
+ try:
+ self.table.put_item(Item={**self.key, "History": messages})
+ except ClientError as err:
+ logger.error(err)
+
+ def clear(self) -> None:
+ """Clear session memory from DynamoDB"""
+ try:
+ from botocore.exceptions import ClientError
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import botocore, please install with `pip install botocore`."
+ ) from e
+
+ try:
+ self.table.delete_item(Key=self.key)
+ except ClientError as err:
+ logger.error(err)
diff --git a/libs/community/langchain_community/chat_message_histories/elasticsearch.py b/libs/community/langchain_community/chat_message_histories/elasticsearch.py
new file mode 100644
index 00000000000..39a18460224
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/elasticsearch.py
@@ -0,0 +1,195 @@
+import json
+import logging
+from time import time
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+if TYPE_CHECKING:
+ from elasticsearch import Elasticsearch
+
+logger = logging.getLogger(__name__)
+
+
+class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that stores history in Elasticsearch.
+
+ Args:
+ es_url: URL of the Elasticsearch instance to connect to.
+ es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
+ es_user: Username to use when connecting to Elasticsearch.
+ es_password: Password to use when connecting to Elasticsearch.
+ es_api_key: API key to use when connecting to Elasticsearch.
+ es_connection: Optional pre-existing Elasticsearch connection.
+ index: Name of the index to use.
+ session_id: Arbitrary key that is used to store the messages
+ of a single chat session.
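+
+    Example (a minimal sketch; the URL and index name are placeholders):
+
+    .. code-block:: python
+
+        history = ElasticsearchChatMessageHistory(
+            index="chat-history",
+            session_id="session-1",
+            es_url="http://localhost:9200",
+        )
+        history.add_message(message)  # any BaseMessage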
+ """
+
+ def __init__(
+ self,
+ index: str,
+ session_id: str,
+ *,
+ es_connection: Optional["Elasticsearch"] = None,
+ es_url: Optional[str] = None,
+ es_cloud_id: Optional[str] = None,
+ es_user: Optional[str] = None,
+ es_api_key: Optional[str] = None,
+ es_password: Optional[str] = None,
+ ):
+ self.index: str = index
+ self.session_id: str = session_id
+
+ # Initialize Elasticsearch client from passed client arg or connection info
+ if es_connection is not None:
+ self.client = es_connection.options(
+ headers={"user-agent": self.get_user_agent()}
+ )
+ elif es_url is not None or es_cloud_id is not None:
+ self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
+ es_url=es_url,
+ username=es_user,
+ password=es_password,
+ cloud_id=es_cloud_id,
+ api_key=es_api_key,
+ )
+ else:
+ raise ValueError(
+                "Either provide a pre-existing Elasticsearch connection, "
+                "or valid credentials for creating a new connection."
+ )
+
+ if self.client.indices.exists(index=index):
+ logger.debug(
+ f"Chat history index {index} already exists, skipping creation."
+ )
+ else:
+ logger.debug(f"Creating index {index} for storing chat history.")
+
+ self.client.indices.create(
+ index=index,
+ mappings={
+ "properties": {
+ "session_id": {"type": "keyword"},
+ "created_at": {"type": "date"},
+ "history": {"type": "text"},
+ }
+ },
+ )
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain-py-ms/{__version__}"
+
+ @staticmethod
+ def connect_to_elasticsearch(
+ *,
+ es_url: Optional[str] = None,
+ cloud_id: Optional[str] = None,
+ api_key: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ ) -> "Elasticsearch":
+ try:
+ import elasticsearch
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ if es_url and cloud_id:
+ raise ValueError(
+ "Both es_url and cloud_id are defined. Please provide only one."
+ )
+
+ connection_params: Dict[str, Any] = {}
+
+ if es_url:
+ connection_params["hosts"] = [es_url]
+ elif cloud_id:
+ connection_params["cloud_id"] = cloud_id
+ else:
+            raise ValueError("Please provide either es_url or cloud_id.")
+
+ if api_key:
+ connection_params["api_key"] = api_key
+ elif username and password:
+ connection_params["basic_auth"] = (username, password)
+
+ es_client = elasticsearch.Elasticsearch(
+ **connection_params,
+ headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
+ )
+ try:
+ es_client.info()
+ except Exception as err:
+ logger.error(f"Error connecting to Elasticsearch: {err}")
+ raise err
+
+ return es_client
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore[override]
+ """Retrieve the messages from Elasticsearch"""
+ try:
+ from elasticsearch import ApiError
+
+ result = self.client.search(
+ index=self.index,
+ query={"term": {"session_id": self.session_id}},
+ sort="created_at:asc",
+ )
+ except ApiError as err:
+ logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
+ raise err
+
+ if result and len(result["hits"]["hits"]) > 0:
+ items = [
+ json.loads(document["_source"]["history"])
+ for document in result["hits"]["hits"]
+ ]
+ else:
+ items = []
+
+ return messages_from_dict(items)
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Add a message to the chat session in Elasticsearch"""
+ try:
+ from elasticsearch import ApiError
+
+ self.client.index(
+ index=self.index,
+ document={
+ "session_id": self.session_id,
+ "created_at": round(time() * 1000),
+ "history": json.dumps(message_to_dict(message)),
+ },
+ refresh=True,
+ )
+ except ApiError as err:
+ logger.error(f"Could not add message to Elasticsearch: {err}")
+ raise err
+
+ def clear(self) -> None:
+ """Clear session memory in Elasticsearch"""
+ try:
+ from elasticsearch import ApiError
+
+ self.client.delete_by_query(
+ index=self.index,
+ query={"term": {"session_id": self.session_id}},
+ refresh=True,
+ )
+ except ApiError as err:
+ logger.error(f"Could not clear session memory in Elasticsearch: {err}")
+ raise err
diff --git a/libs/community/langchain_community/chat_message_histories/file.py b/libs/community/langchain_community/chat_message_histories/file.py
new file mode 100644
index 00000000000..d6f2f43c3d6
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/file.py
@@ -0,0 +1,45 @@
+import json
+import logging
+from pathlib import Path
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ messages_from_dict,
+ messages_to_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class FileChatMessageHistory(BaseChatMessageHistory):
+ """
+ Chat message history that stores history in a local file.
+
+ Args:
+ file_path: path of the local file to store the messages.
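+
+    Example (an illustrative sketch; the file path is a placeholder):
+
+        .. code-block:: python
+
+            from langchain_community.chat_message_histories import (
+                FileChatMessageHistory
+            )
+
+            history = FileChatMessageHistory("./chat_history.json")
+            history.add_user_message("hi!")
+            print(history.messages)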
+ """
+
+ def __init__(self, file_path: str):
+ self.file_path = Path(file_path)
+ if not self.file_path.exists():
+ self.file_path.touch()
+ self.file_path.write_text(json.dumps([]))
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from the local file"""
+ items = json.loads(self.file_path.read_text())
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in the local file"""
+ messages = messages_to_dict(self.messages)
+ messages.append(messages_to_dict([message])[0])
+ self.file_path.write_text(json.dumps(messages))
+
+ def clear(self) -> None:
+ """Clear session memory from the local file"""
+ self.file_path.write_text(json.dumps([]))
diff --git a/libs/community/langchain_community/chat_message_histories/firestore.py b/libs/community/langchain_community/chat_message_histories/firestore.py
new file mode 100644
index 00000000000..941bbb72f85
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/firestore.py
@@ -0,0 +1,105 @@
+"""Firestore Chat Message History."""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ messages_from_dict,
+ messages_to_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from google.cloud.firestore import Client, DocumentReference
+
+
+def _get_firestore_client() -> Client:
+ try:
+ import firebase_admin
+ from firebase_admin import firestore
+ except ImportError:
+ raise ImportError(
+ "Could not import firebase-admin python package. "
+ "Please install it with `pip install firebase-admin`."
+ )
+
+ # For multiple instances, only initialize the app once.
+ try:
+ firebase_admin.get_app()
+ except ValueError as e:
+ logger.debug("Initializing Firebase app: %s", e)
+ firebase_admin.initialize_app()
+
+ return firestore.client()
+
+
+class FirestoreChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history backed by Google Firestore."""
+
+ def __init__(
+ self,
+ collection_name: str,
+ session_id: str,
+ user_id: str,
+ firestore_client: Optional[Client] = None,
+ ):
+ """
+ Initialize a new instance of the FirestoreChatMessageHistory class.
+
+ :param collection_name: The name of the collection to use.
+        :param session_id: The session ID for the chat.
+        :param user_id: The user ID for the chat.
+        :param firestore_client: An optional, pre-configured Firestore client.
+            If not provided, a client is created with ``firebase_admin`` using
+            application default credentials.
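+
+        Example (for illustration only; the names are placeholders and a
+        default Firebase app with application credentials is assumed):
+
+            .. code-block:: python
+
+                from langchain_community.chat_message_histories import (
+                    FirestoreChatMessageHistory
+                )
+
+                history = FirestoreChatMessageHistory(
+                    collection_name="chat_history",
+                    session_id="session-1",
+                    user_id="user-1",
+                )
+                history.add_user_message("hi!")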
+ """
+ self.collection_name = collection_name
+ self.session_id = session_id
+ self.user_id = user_id
+ self._document: Optional[DocumentReference] = None
+ self.messages: List[BaseMessage] = []
+ self.firestore_client = firestore_client or _get_firestore_client()
+ self.prepare_firestore()
+
+ def prepare_firestore(self) -> None:
+ """Prepare the Firestore client.
+
+ Use this function to make sure your database is ready.
+ """
+ self._document = self.firestore_client.collection(
+ self.collection_name
+ ).document(self.session_id)
+ self.load_messages()
+
+ def load_messages(self) -> None:
+ """Retrieve the messages from Firestore"""
+ if not self._document:
+ raise ValueError("Document not initialized")
+ doc = self._document.get()
+ if doc.exists:
+ data = doc.to_dict()
+ if "messages" in data and len(data["messages"]) > 0:
+ self.messages = messages_from_dict(data["messages"])
+
+ def add_message(self, message: BaseMessage) -> None:
+ self.messages.append(message)
+ self.upsert_messages()
+
+ def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
+ """Update the Firestore document."""
+ if not self._document:
+ raise ValueError("Document not initialized")
+ self._document.set(
+ {
+ "id": self.session_id,
+ "user_id": self.user_id,
+ "messages": messages_to_dict(self.messages),
+ }
+ )
+
+ def clear(self) -> None:
+ """Clear session memory from this memory and Firestore."""
+ self.messages = []
+ if self._document:
+ self._document.delete()
diff --git a/libs/community/langchain_community/chat_message_histories/in_memory.py b/libs/community/langchain_community/chat_message_histories/in_memory.py
new file mode 100644
index 00000000000..8c76e850dd9
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/in_memory.py
@@ -0,0 +1,21 @@
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
+    """In-memory implementation of chat message history.
+
+    Stores messages in an in-memory list.
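+
+    Example (an illustrative sketch):
+
+        .. code-block:: python
+
+            from langchain_community.chat_message_histories import (
+                ChatMessageHistory
+            )
+
+            history = ChatMessageHistory()
+            history.add_user_message("hi!")
+            history.add_ai_message("whats up?")
+            print(history.messages)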
+ """
+
+ messages: List[BaseMessage] = Field(default_factory=list)
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Add a self-created message to the store"""
+ self.messages.append(message)
+
+ def clear(self) -> None:
+ self.messages = []
diff --git a/libs/community/langchain_community/chat_message_histories/momento.py b/libs/community/langchain_community/chat_message_histories/momento.py
new file mode 100644
index 00000000000..51073d789e5
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/momento.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import json
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+from langchain_core.utils import get_from_env
+
+if TYPE_CHECKING:
+ import momento
+
+
+def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
+ """Create cache if it doesn't exist.
+
+ Raises:
+ SdkException: Momento service or network error
+ Exception: Unexpected response
+ """
+ from momento.responses import CreateCache
+
+ create_cache_response = cache_client.create_cache(cache_name)
+    if isinstance(
+        create_cache_response, (CreateCache.Success, CreateCache.CacheAlreadyExists)
+    ):
+        return None
+    elif isinstance(create_cache_response, CreateCache.Error):
+        raise create_cache_response.inner_exception
+    else:
+        raise Exception(
+            f"Unexpected response when creating cache: {create_cache_response}"
+        )
+
+
+class MomentoChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history cache that uses Momento as a backend.
+
+ See https://gomomento.com/"""
+
+ def __init__(
+ self,
+ session_id: str,
+ cache_client: momento.CacheClient,
+ cache_name: str,
+ *,
+ key_prefix: str = "message_store:",
+ ttl: Optional[timedelta] = None,
+ ensure_cache_exists: bool = True,
+ ):
+ """Instantiate a chat message history cache that uses Momento as a backend.
+
+ Note: to instantiate the cache client passed to MomentoChatMessageHistory,
+ you must have a Momento account at https://gomomento.com/.
+
+ Args:
+ session_id (str): The session ID to use for this chat session.
+ cache_client (CacheClient): The Momento cache client.
+ cache_name (str): The name of the cache to use to store the messages.
+ key_prefix (str, optional): The prefix to apply to the cache key.
+ Defaults to "message_store:".
+ ttl (Optional[timedelta], optional): The TTL to use for the messages.
+                Defaults to None, i.e. the default TTL of the cache will be used.
+ ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
+ Defaults to True.
+
+ Raises:
+ ImportError: Momento python package is not installed.
+            TypeError: cache_client is not an instance of momento.CacheClient.
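+
+        Example (for illustration only; the cache name is a placeholder and a
+        ``MOMENTO_API_KEY`` or ``MOMENTO_AUTH_TOKEN`` environment variable is
+        assumed to be set):
+
+            .. code-block:: python
+
+                from datetime import timedelta
+
+                from langchain_community.chat_message_histories import (
+                    MomentoChatMessageHistory
+                )
+
+                history = MomentoChatMessageHistory.from_client_params(
+                    "session-1",
+                    "langchain-demo-cache",
+                    timedelta(days=1),
+                )
+                history.add_user_message("hi!")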
+ """
+ try:
+ from momento import CacheClient
+ from momento.requests import CollectionTtl
+ except ImportError:
+ raise ImportError(
+ "Could not import momento python package. "
+ "Please install it with `pip install momento`."
+ )
+ if not isinstance(cache_client, CacheClient):
+ raise TypeError("cache_client must be a momento.CacheClient object.")
+ if ensure_cache_exists:
+ _ensure_cache_exists(cache_client, cache_name)
+ self.key = key_prefix + session_id
+ self.cache_client = cache_client
+ self.cache_name = cache_name
+ if ttl is not None:
+ self.ttl = CollectionTtl.of(ttl)
+ else:
+ self.ttl = CollectionTtl.from_cache_ttl()
+
+ @classmethod
+ def from_client_params(
+ cls,
+ session_id: str,
+ cache_name: str,
+ ttl: timedelta,
+ *,
+ configuration: Optional[momento.config.Configuration] = None,
+ api_key: Optional[str] = None,
+ auth_token: Optional[str] = None, # for backwards compatibility
+ **kwargs: Any,
+ ) -> MomentoChatMessageHistory:
+ """Construct cache from CacheClient parameters."""
+ try:
+ from momento import CacheClient, Configurations, CredentialProvider
+ except ImportError:
+ raise ImportError(
+ "Could not import momento python package. "
+ "Please install it with `pip install momento`."
+ )
+ if configuration is None:
+ configuration = Configurations.Laptop.v1()
+
+ # Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
+ try:
+ api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
+ except ValueError:
+ api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
+ credentials = CredentialProvider.from_string(api_key)
+ cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
+ return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
+
+ @property
+ def messages(self) -> list[BaseMessage]: # type: ignore[override]
+ """Retrieve the messages from Momento.
+
+ Raises:
+ SdkException: Momento service or network error
+ Exception: Unexpected response
+
+ Returns:
+ list[BaseMessage]: List of cached messages
+ """
+ from momento.responses import CacheListFetch
+
+ fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
+
+ if isinstance(fetch_response, CacheListFetch.Hit):
+ items = [json.loads(m) for m in fetch_response.value_list_string]
+ return messages_from_dict(items)
+ elif isinstance(fetch_response, CacheListFetch.Miss):
+ return []
+ elif isinstance(fetch_response, CacheListFetch.Error):
+ raise fetch_response.inner_exception
+ else:
+ raise Exception(f"Unexpected response: {fetch_response}")
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Store a message in the cache.
+
+ Args:
+ message (BaseMessage): The message object to store.
+
+ Raises:
+ SdkException: Momento service or network error.
+ Exception: Unexpected response.
+ """
+ from momento.responses import CacheListPushBack
+
+ item = json.dumps(message_to_dict(message))
+ push_response = self.cache_client.list_push_back(
+ self.cache_name, self.key, item, ttl=self.ttl
+ )
+ if isinstance(push_response, CacheListPushBack.Success):
+ return None
+ elif isinstance(push_response, CacheListPushBack.Error):
+ raise push_response.inner_exception
+ else:
+ raise Exception(f"Unexpected response: {push_response}")
+
+ def clear(self) -> None:
+ """Remove the session's messages from the cache.
+
+ Raises:
+ SdkException: Momento service or network error.
+ Exception: Unexpected response.
+ """
+ from momento.responses import CacheDelete
+
+ delete_response = self.cache_client.delete(self.cache_name, self.key)
+ if isinstance(delete_response, CacheDelete.Success):
+ return None
+ elif isinstance(delete_response, CacheDelete.Error):
+ raise delete_response.inner_exception
+ else:
+ raise Exception(f"Unexpected response: {delete_response}")
diff --git a/libs/community/langchain_community/chat_message_histories/mongodb.py b/libs/community/langchain_community/chat_message_histories/mongodb.py
new file mode 100644
index 00000000000..5865f86b13c
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/mongodb.py
@@ -0,0 +1,91 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_DBNAME = "chat_history"
+DEFAULT_COLLECTION_NAME = "message_store"
+
+
+class MongoDBChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that stores history in MongoDB.
+
+ Args:
+ connection_string: connection string to connect to MongoDB
+ session_id: arbitrary key that is used to store the messages
+ of a single chat session.
+ database_name: name of the database to use
+ collection_name: name of the collection to use
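+
+    Example (an illustrative sketch; the connection string is a placeholder
+    for a reachable MongoDB instance):
+
+        .. code-block:: python
+
+            from langchain_community.chat_message_histories import (
+                MongoDBChatMessageHistory
+            )
+
+            history = MongoDBChatMessageHistory(
+                connection_string="mongodb://localhost:27017/",
+                session_id="session-1",
+            )
+            history.add_user_message("hi!")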
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ session_id: str,
+ database_name: str = DEFAULT_DBNAME,
+ collection_name: str = DEFAULT_COLLECTION_NAME,
+ ):
+ from pymongo import MongoClient, errors
+
+ self.connection_string = connection_string
+ self.session_id = session_id
+ self.database_name = database_name
+ self.collection_name = collection_name
+
+ try:
+ self.client: MongoClient = MongoClient(connection_string)
+ except errors.ConnectionFailure as error:
+ logger.error(error)
+
+ self.db = self.client[database_name]
+ self.collection = self.db[collection_name]
+ self.collection.create_index("SessionId")
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from MongoDB"""
+ from pymongo import errors
+
+ try:
+ cursor = self.collection.find({"SessionId": self.session_id})
+ except errors.OperationFailure as error:
+ logger.error(error)
+
+ if cursor:
+ items = [json.loads(document["History"]) for document in cursor]
+ else:
+ items = []
+
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in MongoDB"""
+ from pymongo import errors
+
+ try:
+ self.collection.insert_one(
+ {
+ "SessionId": self.session_id,
+ "History": json.dumps(message_to_dict(message)),
+ }
+ )
+ except errors.WriteError as err:
+ logger.error(err)
+
+ def clear(self) -> None:
+ """Clear session memory from MongoDB"""
+ from pymongo import errors
+
+ try:
+ self.collection.delete_many({"SessionId": self.session_id})
+ except errors.WriteError as err:
+ logger.error(err)
diff --git a/libs/community/langchain_community/chat_message_histories/neo4j.py b/libs/community/langchain_community/chat_message_histories/neo4j.py
new file mode 100644
index 00000000000..d64b1e5ed6e
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/neo4j.py
@@ -0,0 +1,112 @@
+from typing import List, Optional, Union
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import BaseMessage, messages_from_dict
+from langchain_core.utils import get_from_env
+
+
+class Neo4jChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in a Neo4j database."""
+
+ def __init__(
+ self,
+ session_id: Union[str, int],
+ url: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ database: str = "neo4j",
+ node_label: str = "Session",
+ window: int = 3,
+ ):
+ try:
+ import neo4j
+ except ImportError:
+ raise ValueError(
+ "Could not import neo4j python package. "
+ "Please install it with `pip install neo4j`."
+ )
+
+ # Make sure session id is not null
+ if not session_id:
+ raise ValueError("Please ensure that the session_id parameter is provided")
+
+ url = get_from_env("url", "NEO4J_URI", url)
+ username = get_from_env("username", "NEO4J_USERNAME", username)
+ password = get_from_env("password", "NEO4J_PASSWORD", password)
+ database = get_from_env("database", "NEO4J_DATABASE", database)
+
+ self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
+ self._database = database
+ self._session_id = session_id
+ self._node_label = node_label
+ self._window = window
+
+ # Verify connection
+ try:
+ self._driver.verify_connectivity()
+ except neo4j.exceptions.ServiceUnavailable:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the url is correct"
+ )
+ except neo4j.exceptions.AuthError:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the username and password are correct"
+ )
+ # Create session node
+ self._driver.execute_query(
+ f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
+ {"session_id": self._session_id},
+ ).summary
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from Neo4j"""
+ query = (
+ f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
+ "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
+ f"{self._window*2}]-() WITH p, length(p) AS length "
+ "ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
+ "RETURN {data:{content: node.content}, type:node.type} AS result"
+ )
+ records, _, _ = self._driver.execute_query(
+ query, {"session_id": self._session_id}
+ )
+
+ messages = messages_from_dict([el["result"] for el in records])
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in Neo4j"""
+ query = (
+ f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
+ "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
+ "CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
+ "SET new += {type:$type, content:$content} "
+ "WITH new, lm, last_message WHERE last_message IS NOT NULL "
+ "CREATE (last_message)-[:NEXT]->(new) "
+ "DELETE lm"
+ )
+ self._driver.execute_query(
+ query,
+ {
+ "type": message.type,
+ "content": message.content,
+ "session_id": self._session_id,
+ },
+ ).summary
+
+ def clear(self) -> None:
+ """Clear session memory from Neo4j"""
+ query = (
+ f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
+ "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
+ "WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
+ "UNWIND nodes(p) as node DETACH DELETE node;"
+ )
+ self._driver.execute_query(query, {"session_id": self._session_id}).summary
+
+ def __del__(self) -> None:
+ if self._driver:
+ self._driver.close()
diff --git a/libs/community/langchain_community/chat_message_histories/postgres.py b/libs/community/langchain_community/chat_message_histories/postgres.py
new file mode 100644
index 00000000000..63794197e8f
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/postgres.py
@@ -0,0 +1,82 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_history"
+
+
+class PostgresChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in a Postgres database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ connection_string: str = DEFAULT_CONNECTION_STRING,
+ table_name: str = "message_store",
+ ):
+ import psycopg
+ from psycopg.rows import dict_row
+
+ try:
+ self.connection = psycopg.connect(connection_string)
+ self.cursor = self.connection.cursor(row_factory=dict_row)
+ except psycopg.OperationalError as error:
+ logger.error(error)
+
+ self.session_id = session_id
+ self.table_name = table_name
+
+ self._create_table_if_not_exists()
+
+ def _create_table_if_not_exists(self) -> None:
+ create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
+ id SERIAL PRIMARY KEY,
+ session_id TEXT NOT NULL,
+ message JSONB NOT NULL
+ );"""
+ self.cursor.execute(create_table_query)
+ self.connection.commit()
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from PostgreSQL"""
+ query = (
+ f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"
+ )
+ self.cursor.execute(query, (self.session_id,))
+ items = [record["message"] for record in self.cursor.fetchall()]
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in PostgreSQL"""
+ from psycopg import sql
+
+ query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
+ sql.Identifier(self.table_name)
+ )
+ self.cursor.execute(
+ query, (self.session_id, json.dumps(message_to_dict(message)))
+ )
+ self.connection.commit()
+
+ def clear(self) -> None:
+ """Clear session memory from PostgreSQL"""
+ query = f"DELETE FROM {self.table_name} WHERE session_id = %s;"
+ self.cursor.execute(query, (self.session_id,))
+ self.connection.commit()
+
+ def __del__(self) -> None:
+ if self.cursor:
+ self.cursor.close()
+ if self.connection:
+ self.connection.close()
diff --git a/libs/community/langchain_community/chat_message_histories/redis.py b/libs/community/langchain_community/chat_message_histories/redis.py
new file mode 100644
index 00000000000..37226e2ab22
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/redis.py
@@ -0,0 +1,65 @@
+import json
+import logging
+from typing import List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+from langchain_community.utilities.redis import get_client
+
+logger = logging.getLogger(__name__)
+
+
+class RedisChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in a Redis database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ url: str = "redis://localhost:6379/0",
+ key_prefix: str = "message_store:",
+ ttl: Optional[int] = None,
+ ):
+ try:
+ import redis
+ except ImportError:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+
+ try:
+ self.redis_client = get_client(redis_url=url)
+ except redis.exceptions.ConnectionError as error:
+ logger.error(error)
+
+ self.session_id = session_id
+ self.key_prefix = key_prefix
+ self.ttl = ttl
+
+ @property
+ def key(self) -> str:
+ """Construct the record key to use"""
+ return self.key_prefix + self.session_id
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from Redis"""
+ _items = self.redis_client.lrange(self.key, 0, -1)
+ items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in Redis"""
+ self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
+ if self.ttl:
+ self.redis_client.expire(self.key, self.ttl)
+
+ def clear(self) -> None:
+ """Clear session memory from Redis"""
+ self.redis_client.delete(self.key)
diff --git a/libs/community/langchain_community/chat_message_histories/rocksetdb.py b/libs/community/langchain_community/chat_message_histories/rocksetdb.py
new file mode 100644
index 00000000000..0391726c4c6
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/rocksetdb.py
@@ -0,0 +1,268 @@
+from datetime import datetime
+from time import sleep
+from typing import Any, Callable, List, Union
+from uuid import uuid4
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+
+class RocksetChatMessageHistory(BaseChatMessageHistory):
+ """Uses Rockset to store chat messages.
+
+    To use, ensure that the `rockset` python package is installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_message_histories import (
+ RocksetChatMessageHistory
+ )
+ from rockset import RocksetClient
+
+ history = RocksetChatMessageHistory(
+ session_id="MySession",
+ client=RocksetClient(),
+ collection="langchain_demo",
+ sync=True
+ )
+
+ history.add_user_message("hi!")
+ history.add_ai_message("whats up?")
+
+ print(history.messages)
+ """
+
+ # You should set these values based on your VI.
+ # These values are configured for the typical
+ # free VI. Read more about VIs here:
+ # https://rockset.com/docs/instances
+ SLEEP_INTERVAL_MS: int = 5
+ ADD_TIMEOUT_MS: int = 5000
+ CREATE_TIMEOUT_MS: int = 20000
+
+    def _wait_until(
+        self, method: Callable, timeout: int, **method_params: Any
+    ) -> None:
+        """Sleeps until method() evaluates to True. Passes method_params
+        as keyword arguments into method.
+ """
+ start = datetime.now()
+ while not method(**method_params):
+ curr = datetime.now()
+ if (curr - start).total_seconds() * 1000 > timeout:
+ raise TimeoutError(f"{method} timed out at {timeout} ms")
+ sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
+
+ def _query(self, query: str, **query_params: Any) -> List[Any]:
+        """Executes an SQL statement and returns the result.
+
+        Args:
+ - query: The SQL string
+ - **query_params: Parameters to pass into the query
+ """
+ return self.client.sql(query, params=query_params).results
+
+ def _create_collection(self) -> None:
+ """Creates a collection for this message history"""
+ self.client.Collections.create_s3_collection(
+ name=self.collection, workspace=self.workspace
+ )
+
+ def _collection_exists(self) -> bool:
+ """Checks whether a collection exists for this message history"""
+ try:
+ self.client.Collections.get(collection=self.collection)
+ except self.rockset.exceptions.NotFoundException:
+ return False
+ return True
+
+ def _collection_is_ready(self) -> bool:
+ """Checks whether the collection for this message history is ready
+ to be queried
+ """
+ return (
+ self.client.Collections.get(collection=self.collection).data.status
+ == "READY"
+ )
+
+ def _document_exists(self) -> bool:
+ return (
+ len(
+ self._query(
+ f"""
+ SELECT 1
+ FROM {self.location}
+ WHERE _id=:session_id
+ LIMIT 1
+ """,
+ session_id=self.session_id,
+ )
+ )
+ != 0
+ )
+
+ def _wait_until_collection_created(self) -> None:
+ """Sleeps until the collection for this message history is ready
+ to be queried
+ """
+ self._wait_until(
+ lambda: self._collection_is_ready(),
+ RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
+ )
+
+ def _wait_until_message_added(self, message_id: str) -> None:
+ """Sleeps until a message is added to the messages list"""
+ self._wait_until(
+ lambda message_id: len(
+ self._query(
+ f"""
+ SELECT *
+ FROM UNNEST((
+ SELECT {self.messages_key}
+ FROM {self.location}
+ WHERE _id = :session_id
+ )) AS message
+ WHERE message.data.additional_kwargs.id = :message_id
+ LIMIT 1
+ """,
+ session_id=self.session_id,
+ message_id=message_id,
+ ),
+ )
+ != 0,
+ RocksetChatMessageHistory.ADD_TIMEOUT_MS,
+ message_id=message_id,
+ )
+
+ def _create_empty_doc(self) -> None:
+ """Creates or replaces a document for this message history with no
+ messages"""
+ self.client.Documents.add_documents(
+ collection=self.collection,
+ workspace=self.workspace,
+ data=[{"_id": self.session_id, self.messages_key: []}],
+ )
+
+ def __init__(
+ self,
+ session_id: str,
+ client: Any,
+ collection: str,
+ workspace: str = "commons",
+ messages_key: str = "messages",
+ sync: bool = False,
+ message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
+ ) -> None:
+ """Constructs a new RocksetChatMessageHistory.
+
+ Args:
+ - session_id: The ID of the chat session
+ - client: The RocksetClient object to use to query
+ - collection: The name of the collection to use to store chat
+ messages. If a collection with the given name
+ does not exist in the workspace, it is created.
+ - workspace: The workspace containing `collection`. Defaults
+ to `"commons"`
+ - messages_key: The DB column containing message history.
+ Defaults to `"messages"`
+ - sync: Whether to wait for messages to be added. Defaults
+ to `False`. NOTE: setting this to `True` will slow
+ down performance.
+ - message_uuid_method: The method that generates message IDs.
+ If set, all messages will have an `id` field within the
+ `additional_kwargs` property. If this param is not set
+ and `sync` is `False`, message IDs will not be created.
+ If this param is not set and `sync` is `True`, the
+ `uuid.uuid4` method will be used to create message IDs.
+ """
+ try:
+ import rockset
+ except ImportError:
+ raise ImportError(
+ "Could not import rockset client python package. "
+ "Please install it with `pip install rockset`."
+ )
+
+ if not isinstance(client, rockset.RocksetClient):
+ raise ValueError(
+ f"client should be an instance of rockset.RocksetClient, "
+ f"got {type(client)}"
+ )
+
+ self.session_id = session_id
+ self.client = client
+ self.collection = collection
+ self.workspace = workspace
+ self.location = f'"{self.workspace}"."{self.collection}"'
+ self.rockset = rockset
+ self.messages_key = messages_key
+ self.message_uuid_method = message_uuid_method
+ self.sync = sync
+
+ try:
+ self.client.set_application("langchain")
+ except AttributeError:
+ # ignore
+ pass
+
+ if not self._collection_exists():
+ self._create_collection()
+ self._wait_until_collection_created()
+ self._create_empty_doc()
+ elif not self._document_exists():
+ self._create_empty_doc()
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Messages in this chat history."""
+ return messages_from_dict(
+ self._query(
+ f"""
+ SELECT *
+ FROM UNNEST ((
+ SELECT "{self.messages_key}"
+ FROM {self.location}
+ WHERE _id = :session_id
+ ))
+ """,
+ session_id=self.session_id,
+ )
+ )
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Add a Message object to the history.
+
+ Args:
+ message: A BaseMessage object to store.
+ """
+ if self.sync and "id" not in message.additional_kwargs:
+ message.additional_kwargs["id"] = self.message_uuid_method()
+ self.client.Documents.patch_documents(
+ collection=self.collection,
+ workspace=self.workspace,
+ data=[
+ self.rockset.model.patch_document.PatchDocument(
+ id=self.session_id,
+ patch=[
+ self.rockset.model.patch_operation.PatchOperation(
+ op="ADD",
+ path=f"/{self.messages_key}/-",
+ value=message_to_dict(message),
+ )
+ ],
+ )
+ ],
+ )
+ if self.sync:
+ self._wait_until_message_added(message.additional_kwargs["id"])
+
+ def clear(self) -> None:
+ """Removes all messages from the chat history"""
+ self._create_empty_doc()
+ if self.sync:
+ self._wait_until(
+ lambda: not self.messages,
+ RocksetChatMessageHistory.ADD_TIMEOUT_MS,
+ )
diff --git a/libs/community/langchain_community/chat_message_histories/singlestoredb.py b/libs/community/langchain_community/chat_message_histories/singlestoredb.py
new file mode 100644
index 00000000000..a2fc1af1372
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/singlestoredb.py
@@ -0,0 +1,277 @@
+import json
+import logging
+import re
+from typing import (
+ Any,
+ List,
+)
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in a SingleStoreDB database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ *,
+ table_name: str = "message_store",
+ id_field: str = "id",
+ session_id_field: str = "session_id",
+ message_field: str = "message",
+ pool_size: int = 5,
+ max_overflow: int = 10,
+ timeout: float = 30,
+ **kwargs: Any,
+ ):
+ """Initialize with necessary components.
+
+ Args:
+            session_id (str): Arbitrary key that is used to store the messages
+                of a single chat session.
+ table_name (str, optional): Specifies the name of the table in use.
+ Defaults to "message_store".
+ id_field (str, optional): Specifies the name of the id field in the table.
+ Defaults to "id".
+ session_id_field (str, optional): Specifies the name of the session_id
+ field in the table. Defaults to "session_id".
+ message_field (str, optional): Specifies the name of the message field
+ in the table. Defaults to "message".
+
+ Following arguments pertain to the connection pool:
+
+ pool_size (int, optional): Determines the number of active connections in
+ the pool. Defaults to 5.
+ max_overflow (int, optional): Determines the maximum number of connections
+ allowed beyond the pool_size. Defaults to 10.
+ timeout (float, optional): Specifies the maximum wait time in seconds for
+ establishing a connection. Defaults to 30.
+
+ Following arguments pertain to the database connection:
+
+ host (str, optional): Specifies the hostname, IP address, or URL for the
+ database connection. The default scheme is "mysql".
+ user (str, optional): Database username.
+ password (str, optional): Database password.
+ port (int, optional): Database port. Defaults to 3306 for non-HTTP
+ connections, 80 for HTTP connections, and 443 for HTTPS connections.
+ database (str, optional): Database name.
+
+ Additional optional arguments provide further customization over the
+ database connection:
+
+ pure_python (bool, optional): Toggles the connector mode. If True,
+ operates in pure Python mode.
+ local_infile (bool, optional): Allows local file uploads.
+ charset (str, optional): Specifies the character set for string values.
+ ssl_key (str, optional): Specifies the path of the file containing the SSL
+ key.
+ ssl_cert (str, optional): Specifies the path of the file containing the SSL
+ certificate.
+ ssl_ca (str, optional): Specifies the path of the file containing the SSL
+ certificate authority.
+ ssl_cipher (str, optional): Sets the SSL cipher list.
+ ssl_disabled (bool, optional): Disables SSL usage.
+ ssl_verify_cert (bool, optional): Verifies the server's certificate.
+ Automatically enabled if ``ssl_ca`` is specified.
+ ssl_verify_identity (bool, optional): Verifies the server's identity.
+ conv (dict[int, Callable], optional): A dictionary of data conversion
+ functions.
+ credential_type (str, optional): Specifies the type of authentication to
+ use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
+ autocommit (bool, optional): Enables autocommits.
+ results_type (str, optional): Determines the structure of the query results:
+ tuples, namedtuples, dicts.
+ results_format (str, optional): Deprecated. This option has been renamed to
+ results_type.
+
+ Examples:
+ Basic Usage:
+
+ .. code-block:: python
+
+ from langchain_community.chat_message_histories import (
+ SingleStoreDBChatMessageHistory
+ )
+
+ message_history = SingleStoreDBChatMessageHistory(
+ session_id="my-session",
+ host="https://user:password@127.0.0.1:3306/database"
+ )
+
+ Advanced Usage:
+
+ .. code-block:: python
+
+ from langchain_community.chat_message_histories import (
+ SingleStoreDBChatMessageHistory
+ )
+
+ message_history = SingleStoreDBChatMessageHistory(
+ session_id="my-session",
+ host="127.0.0.1",
+ port=3306,
+ user="user",
+ password="password",
+ database="db",
+ table_name="my_custom_table",
+ pool_size=10,
+ timeout=60,
+ )
+
+ Using environment variables:
+
+ .. code-block:: python
+
+ from langchain_community.chat_message_histories import (
+ SingleStoreDBChatMessageHistory
+ )
+
+ os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
+ message_history = SingleStoreDBChatMessageHistory("my-session")
+ """
+
+ self.table_name = self._sanitize_input(table_name)
+ self.session_id = self._sanitize_input(session_id)
+ self.id_field = self._sanitize_input(id_field)
+ self.session_id_field = self._sanitize_input(session_id_field)
+ self.message_field = self._sanitize_input(message_field)
+
+ # Pass the rest of the kwargs to the connection.
+ self.connection_kwargs = kwargs
+
+ # Add connection attributes to the connection kwargs.
+ if "conn_attrs" not in self.connection_kwargs:
+ self.connection_kwargs["conn_attrs"] = dict()
+
+ self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
+ self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
+
+ # Create a connection pool.
+ try:
+ from sqlalchemy.pool import QueuePool
+ except ImportError:
+ raise ImportError(
+ "Could not import sqlalchemy.pool python package. "
+ "Please install it with `pip install singlestoredb`."
+ )
+
+ self.connection_pool = QueuePool(
+ self._get_connection,
+ max_overflow=max_overflow,
+ pool_size=pool_size,
+ timeout=timeout,
+ )
+ self.table_created = False
+
+ def _sanitize_input(self, input_str: str) -> str:
+ # Remove characters that are not alphanumeric or underscores
+ return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
+
+ def _get_connection(self) -> Any:
+ try:
+ import singlestoredb as s2
+ except ImportError:
+ raise ImportError(
+ "Could not import singlestoredb python package. "
+ "Please install it with `pip install singlestoredb`."
+ )
+ return s2.connect(**self.connection_kwargs)
+
+ def _create_table_if_not_exists(self) -> None:
+ """Create table if it doesn't exist."""
+ if self.table_created:
+ return
+ conn = self.connection_pool.connect()
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """CREATE TABLE IF NOT EXISTS {}
+ ({} BIGINT PRIMARY KEY AUTO_INCREMENT,
+ {} TEXT NOT NULL,
+ {} JSON NOT NULL);""".format(
+ self.table_name,
+ self.id_field,
+ self.session_id_field,
+ self.message_field,
+ ),
+ )
+ self.table_created = True
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from SingleStoreDB"""
+ self._create_table_if_not_exists()
+ conn = self.connection_pool.connect()
+ items = []
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """SELECT {} FROM {} WHERE {} = %s""".format(
+ self.message_field,
+ self.table_name,
+ self.session_id_field,
+ ),
+                    (self.session_id,),
+ )
+ for row in cur.fetchall():
+ items.append(row[0])
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in SingleStoreDB"""
+ self._create_table_if_not_exists()
+ conn = self.connection_pool.connect()
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
+ self.table_name,
+ self.session_id_field,
+ self.message_field,
+ ),
+ (self.session_id, json.dumps(message_to_dict(message))),
+ )
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+
+ def clear(self) -> None:
+ """Clear session memory from SingleStoreDB"""
+ self._create_table_if_not_exists()
+ conn = self.connection_pool.connect()
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """DELETE FROM {} WHERE {} = %s""".format(
+ self.table_name,
+ self.session_id_field,
+ ),
+                    (self.session_id,),
+ )
+ finally:
+ cur.close()
+ finally:
+ conn.close()
diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py
new file mode 100644
index 00000000000..fcc3ac71ab1
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/sql.py
@@ -0,0 +1,140 @@
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+
+from sqlalchemy import Column, Integer, Text, create_engine
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+from sqlalchemy.orm import sessionmaker
+
+logger = logging.getLogger(__name__)
+
+
+class BaseMessageConverter(ABC):
+ """The class responsible for converting BaseMessage to your SQLAlchemy model."""
+
+ @abstractmethod
+ def from_sql_model(self, sql_message: Any) -> BaseMessage:
+ """Convert a SQLAlchemy model to a BaseMessage instance."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
+ """Convert a BaseMessage instance to a SQLAlchemy model."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_sql_model_class(self) -> Any:
+ """Get the SQLAlchemy model class."""
+ raise NotImplementedError
+
+
+def create_message_model(table_name, DynamicBase): # type: ignore
+ """
+ Create a message model for a given table name.
+
+ Args:
+ table_name: The name of the table to use.
+ DynamicBase: The base class to use for the model.
+
+ Returns:
+ The model class.
+
+ """
+
+    # Model declared inside a function to have a dynamic table name.
+ class Message(DynamicBase):
+ __tablename__ = table_name
+ id = Column(Integer, primary_key=True)
+ session_id = Column(Text)
+ message = Column(Text)
+
+ return Message
+
+
+class DefaultMessageConverter(BaseMessageConverter):
+ """The default message converter for SQLChatMessageHistory."""
+
+ def __init__(self, table_name: str):
+ self.model_class = create_message_model(table_name, declarative_base())
+
+ def from_sql_model(self, sql_message: Any) -> BaseMessage:
+ return messages_from_dict([json.loads(sql_message.message)])[0]
+
+ def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
+ return self.model_class(
+ session_id=session_id, message=json.dumps(message_to_dict(message))
+ )
+
+ def get_sql_model_class(self) -> Any:
+ return self.model_class
+
+
+class SQLChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in an SQL database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ connection_string: str,
+ table_name: str = "message_store",
+ session_id_field_name: str = "session_id",
+ custom_message_converter: Optional[BaseMessageConverter] = None,
+ ):
+ self.connection_string = connection_string
+ self.engine = create_engine(connection_string, echo=False)
+ self.session_id_field_name = session_id_field_name
+ self.converter = custom_message_converter or DefaultMessageConverter(table_name)
+ self.sql_model_class = self.converter.get_sql_model_class()
+ if not hasattr(self.sql_model_class, session_id_field_name):
+ raise ValueError("SQL model class must have session_id column")
+ self._create_table_if_not_exists()
+
+ self.session_id = session_id
+ self.Session = sessionmaker(self.engine)
+
+ def _create_table_if_not_exists(self) -> None:
+ self.sql_model_class.metadata.create_all(self.engine)
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve all messages from db"""
+ with self.Session() as session:
+ result = (
+ session.query(self.sql_model_class)
+ .where(
+ getattr(self.sql_model_class, self.session_id_field_name)
+ == self.session_id
+ )
+ .order_by(self.sql_model_class.id.asc())
+ )
+ messages = []
+ for record in result:
+ messages.append(self.converter.from_sql_model(record))
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in db"""
+ with self.Session() as session:
+ session.add(self.converter.to_sql_model(message, self.session_id))
+ session.commit()
+
+ def clear(self) -> None:
+ """Clear session memory from db"""
+
+ with self.Session() as session:
+ session.query(self.sql_model_class).filter(
+ getattr(self.sql_model_class, self.session_id_field_name)
+ == self.session_id
+ ).delete()
+ session.commit()
diff --git a/libs/community/langchain_community/chat_message_histories/streamlit.py b/libs/community/langchain_community/chat_message_histories/streamlit.py
new file mode 100644
index 00000000000..51350d36024
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/streamlit.py
@@ -0,0 +1,38 @@
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import BaseMessage
+
+
+class StreamlitChatMessageHistory(BaseChatMessageHistory):
+ """
+ Chat message history that stores messages in Streamlit session state.
+
+ Args:
+ key: The key to use in Streamlit session state for storing messages.
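+
+    Example (an illustrative sketch; meant to run inside a Streamlit app):
+
+        .. code-block:: python
+
+            from langchain_community.chat_message_histories import (
+                StreamlitChatMessageHistory
+            )
+
+            history = StreamlitChatMessageHistory(key="chat_messages")
+            history.add_user_message("hi!")
+            history.messages  # backed by st.session_state["chat_messages"]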
+ """
+
+ def __init__(self, key: str = "langchain_messages"):
+ try:
+ import streamlit as st
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import streamlit, please run `pip install streamlit`."
+ ) from e
+
+ if key not in st.session_state:
+ st.session_state[key] = []
+ self._messages = st.session_state[key]
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the current list of messages"""
+ return self._messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Add a message to the session memory"""
+ self._messages.append(message)
+
+ def clear(self) -> None:
+ """Clear session memory"""
+ self._messages.clear()
diff --git a/libs/community/langchain_community/chat_message_histories/upstash_redis.py b/libs/community/langchain_community/chat_message_histories/upstash_redis.py
new file mode 100644
index 00000000000..de1e7c37822
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/upstash_redis.py
@@ -0,0 +1,69 @@
+import json
+import logging
+from typing import List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in an Upstash Redis database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ url: str = "",
+ token: str = "",
+ key_prefix: str = "message_store:",
+ ttl: Optional[int] = None,
+ ):
+ try:
+ from upstash_redis import Redis
+ except ImportError:
+ raise ImportError(
+ "Could not import upstash redis python package. "
+ "Please install it with `pip install upstash_redis`."
+ )
+
+ if url == "" or token == "":
+ raise ValueError(
+ "UPSTASH_REDIS_REST_URL and UPSTASH_REDIS_REST_TOKEN are needed."
+ )
+
+ try:
+ self.redis_client = Redis(url=url, token=token)
+ except Exception:
+ logger.error("Upstash Redis instance could not be initiated.")
+
+ self.session_id = session_id
+ self.key_prefix = key_prefix
+ self.ttl = ttl
+
+ @property
+ def key(self) -> str:
+ """Construct the record key to use"""
+ return self.key_prefix + self.session_id
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve the messages from Upstash Redis"""
+ _items = self.redis_client.lrange(self.key, 0, -1)
+ items = [json.loads(m) for m in _items[::-1]]
+ messages = messages_from_dict(items)
+ return messages
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the record in Upstash Redis"""
+ self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
+ if self.ttl:
+ self.redis_client.expire(self.key, self.ttl)
+
+ def clear(self) -> None:
+ """Clear session memory from Upstash Redis"""
+ self.redis_client.delete(self.key)
diff --git a/libs/community/langchain_community/chat_message_histories/xata.py b/libs/community/langchain_community/chat_message_histories/xata.py
new file mode 100644
index 00000000000..56bcf1d98f3
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/xata.py
@@ -0,0 +1,134 @@
+import json
+from typing import List
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ BaseMessage,
+ message_to_dict,
+ messages_from_dict,
+)
+
+
+class XataChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history stored in a Xata database."""
+
+ def __init__(
+ self,
+ session_id: str,
+ db_url: str,
+ api_key: str,
+ branch_name: str = "main",
+ table_name: str = "messages",
+ create_table: bool = True,
+ ) -> None:
+ """Initialize with Xata client."""
+ try:
+ from xata.client import XataClient # noqa: F401
+ except ImportError:
+ raise ValueError(
+ "Could not import xata python package. "
+ "Please install it with `pip install xata`."
+ )
+ self._client = XataClient(
+ api_key=api_key, db_url=db_url, branch_name=branch_name
+ )
+ self._table_name = table_name
+ self._session_id = session_id
+
+ if create_table:
+ self._create_table_if_not_exists()
+
+ def _create_table_if_not_exists(self) -> None:
+ r = self._client.table().get_schema(self._table_name)
+ if r.status_code <= 299:
+ return
+ if r.status_code != 404:
+ raise Exception(
+ f"Error checking if table exists in Xata: {r.status_code} {r}"
+ )
+ r = self._client.table().create(self._table_name)
+ if r.status_code > 299:
+ raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
+ r = self._client.table().set_schema(
+ self._table_name,
+ payload={
+ "columns": [
+ {"name": "sessionId", "type": "string"},
+ {"name": "type", "type": "string"},
+ {"name": "role", "type": "string"},
+ {"name": "content", "type": "text"},
+ {"name": "name", "type": "string"},
+ {"name": "additionalKwargs", "type": "json"},
+ ]
+ },
+ )
+ if r.status_code > 299:
+ raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")
+
+ def add_message(self, message: BaseMessage) -> None:
+ """Append the message to the Xata table"""
+ msg = message_to_dict(message)
+ r = self._client.records().insert(
+ self._table_name,
+ {
+ "sessionId": self._session_id,
+ "type": msg["type"],
+ "content": message.content,
+ "additionalKwargs": json.dumps(message.additional_kwargs),
+ "role": msg["data"].get("role"),
+ "name": msg["data"].get("name"),
+ },
+ )
+ if r.status_code > 299:
+ raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ r = self._client.data().query(
+ self._table_name,
+ payload={
+ "filter": {
+ "sessionId": self._session_id,
+ },
+ "sort": {"xata.createdAt": "asc"},
+ },
+ )
+ if r.status_code != 200:
+ raise Exception(f"Error running query: {r.status_code} {r}")
+ msgs = messages_from_dict(
+ [
+ {
+ "type": m["type"],
+ "data": {
+ "content": m["content"],
+ "role": m.get("role"),
+ "name": m.get("name"),
+ "additional_kwargs": json.loads(m["additionalKwargs"]),
+ },
+ }
+ for m in r["records"]
+ ]
+ )
+ return msgs
+
+ def clear(self) -> None:
+ """Delete session from Xata table."""
+ while True:
+ r = self._client.data().query(
+ self._table_name,
+ payload={
+ "columns": ["id"],
+ "filter": {
+ "sessionId": self._session_id,
+ },
+ },
+ )
+ if r.status_code != 200:
+ raise Exception(f"Error running query: {r.status_code} {r}")
+ ids = [rec["id"] for rec in r["records"]]
+ if len(ids) == 0:
+ break
+ operations = [
+ {"delete": {"table": self._table_name, "id": id}} for id in ids
+ ]
+ self._client.records().transaction(payload={"operations": operations})
diff --git a/libs/community/langchain_community/chat_message_histories/zep.py b/libs/community/langchain_community/chat_message_histories/zep.py
new file mode 100644
index 00000000000..3899f89ba6e
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/zep.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ HumanMessage,
+ SystemMessage,
+)
+
+if TYPE_CHECKING:
+ from zep_python import Memory, MemorySearchResult, Message, NotFoundError
+
+logger = logging.getLogger(__name__)
+
+
+class ZepChatMessageHistory(BaseChatMessageHistory):
+ """Chat message history that uses Zep as a backend.
+
+ Recommended usage::
+
+ # Set up Zep Chat History
+ zep_chat_history = ZepChatMessageHistory(
+ session_id=session_id,
+ url=ZEP_API_URL,
+            api_key=<your_api_key>,
+ )
+
+ # Use a standard ConversationBufferMemory to encapsulate the Zep chat history
+ memory = ConversationBufferMemory(
+ memory_key="chat_history", chat_memory=zep_chat_history
+ )
+
+
+ Zep provides long-term conversation storage for LLM apps. The server stores,
+ summarizes, embeds, indexes, and enriches conversational AI chat
+ histories, and exposes them via simple, low-latency APIs.
+
+ For server installation instructions and more, see:
+ https://docs.getzep.com/deployment/quickstart/
+
+ This class is a thin wrapper around the zep-python package. Additional
+ Zep functionality is exposed via the `zep_summary` and `zep_messages`
+ properties.
+
+ For more information on the zep-python package, see:
+ https://github.com/getzep/zep-python
+ """
+
+ def __init__(
+ self,
+ session_id: str,
+ url: str = "http://localhost:8000",
+ api_key: Optional[str] = None,
+ ) -> None:
+ try:
+ from zep_python import ZepClient
+ except ImportError:
+ raise ImportError(
+ "Could not import zep-python package. "
+ "Please install it with `pip install zep-python`."
+ )
+
+ self.zep_client = ZepClient(base_url=url, api_key=api_key)
+ self.session_id = session_id
+
+ @property
+ def messages(self) -> List[BaseMessage]: # type: ignore
+ """Retrieve messages from Zep memory"""
+ zep_memory: Optional[Memory] = self._get_memory()
+ if not zep_memory:
+ return []
+
+ messages: List[BaseMessage] = []
+ # Extract summary, if present, and messages
+ if zep_memory.summary:
+ if len(zep_memory.summary.content) > 0:
+ messages.append(SystemMessage(content=zep_memory.summary.content))
+ if zep_memory.messages:
+ msg: Message
+ for msg in zep_memory.messages:
+ metadata: Dict = {
+ "uuid": msg.uuid,
+ "created_at": msg.created_at,
+ "token_count": msg.token_count,
+ "metadata": msg.metadata,
+ }
+ if msg.role == "ai":
+ messages.append(
+ AIMessage(content=msg.content, additional_kwargs=metadata)
+ )
+ else:
+ messages.append(
+ HumanMessage(content=msg.content, additional_kwargs=metadata)
+ )
+
+ return messages
+
+ @property
+ def zep_messages(self) -> List[Message]:
+        """Retrieve the messages from Zep memory"""
+ zep_memory: Optional[Memory] = self._get_memory()
+ if not zep_memory:
+ return []
+
+ return zep_memory.messages
+
+ @property
+ def zep_summary(self) -> Optional[str]:
+ """Retrieve summary from Zep memory"""
+ zep_memory: Optional[Memory] = self._get_memory()
+ if not zep_memory or not zep_memory.summary:
+ return None
+
+ return zep_memory.summary.content
+
+ def _get_memory(self) -> Optional[Memory]:
+ """Retrieve memory from Zep"""
+ from zep_python import NotFoundError
+
+ try:
+ zep_memory: Memory = self.zep_client.memory.get_memory(self.session_id)
+ except NotFoundError:
+ logger.warning(
+ f"Session {self.session_id} not found in Zep. Returning None"
+ )
+ return None
+ return zep_memory
+
+ def add_user_message(
+ self, message: str, metadata: Optional[Dict[str, Any]] = None
+ ) -> None:
+ """Convenience method for adding a human message string to the store.
+
+ Args:
+ message: The string contents of a human message.
+ metadata: Optional metadata to attach to the message.
+ """
+ self.add_message(HumanMessage(content=message), metadata=metadata)
+
+ def add_ai_message(
+ self, message: str, metadata: Optional[Dict[str, Any]] = None
+ ) -> None:
+ """Convenience method for adding an AI message string to the store.
+
+ Args:
+ message: The string contents of an AI message.
+ metadata: Optional metadata to attach to the message.
+ """
+ self.add_message(AIMessage(content=message), metadata=metadata)
+
+ def add_message(
+ self, message: BaseMessage, metadata: Optional[Dict[str, Any]] = None
+ ) -> None:
+ """Append the message to the Zep memory history"""
+ from zep_python import Memory, Message
+
+ zep_message = Message(
+ content=message.content, role=message.type, metadata=metadata
+ )
+ zep_memory = Memory(messages=[zep_message])
+
+ self.zep_client.memory.add_memory(self.session_id, zep_memory)
+
+ def search(
+ self, query: str, metadata: Optional[Dict] = None, limit: Optional[int] = None
+ ) -> List[MemorySearchResult]:
+ """Search Zep memory for messages matching the query"""
+ from zep_python import MemorySearchPayload
+
+ payload: MemorySearchPayload = MemorySearchPayload(
+ text=query, metadata=metadata
+ )
+
+ return self.zep_client.memory.search_memory(
+ self.session_id, payload, limit=limit
+ )
+
+ def clear(self) -> None:
+ """Clear session memory from Zep. Note that Zep is long-term storage for memory
+ and this is not advised unless you have specific data retention requirements.
+ """
+        from zep_python import NotFoundError
+
+        try:
+ self.zep_client.memory.delete_memory(self.session_id)
+ except NotFoundError:
+ logger.warning(
+ f"Session {self.session_id} not found in Zep. Skipping delete."
+ )
diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py
new file mode 100644
index 00000000000..7064f7b7585
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/__init__.py
@@ -0,0 +1,82 @@
+"""**Chat Models** are a variation on language models.
+
+While Chat Models use language models under the hood, the interface they expose
+is a bit different. Rather than expose a "text in, text out" API, they expose
+an interface where "chat messages" are the inputs and outputs.
+
+**Class hierarchy:**
+
+.. code-block::
+
+    BaseLanguageModel --> BaseChatModel --> <name>  # Examples: ChatOpenAI, ChatGooglePalm
+
+**Main helpers:**
+
+.. code-block::
+
+ AIMessage, BaseMessage, HumanMessage
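+
+**Usage sketch:**
+
+(Illustrative only: any chat model exported below can be substituted, and
+credentials such as ``OPENAI_API_KEY`` are assumed to be configured.)
+
+.. code-block:: python
+
+    from langchain_core.messages import HumanMessage
+    from langchain_community.chat_models import ChatOpenAI
+
+    chat = ChatOpenAI()
+    chat([HumanMessage(content="Hello!")])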
+""" # noqa: E501
+
+from langchain_community.chat_models.anthropic import ChatAnthropic
+from langchain_community.chat_models.anyscale import ChatAnyscale
+from langchain_community.chat_models.azure_openai import AzureChatOpenAI
+from langchain_community.chat_models.baichuan import ChatBaichuan
+from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
+from langchain_community.chat_models.bedrock import BedrockChat
+from langchain_community.chat_models.cohere import ChatCohere
+from langchain_community.chat_models.databricks import ChatDatabricks
+from langchain_community.chat_models.ernie import ErnieBotChat
+from langchain_community.chat_models.everlyai import ChatEverlyAI
+from langchain_community.chat_models.fake import FakeListChatModel
+from langchain_community.chat_models.fireworks import ChatFireworks
+from langchain_community.chat_models.gigachat import GigaChat
+from langchain_community.chat_models.google_palm import ChatGooglePalm
+from langchain_community.chat_models.human import HumanInputChatModel
+from langchain_community.chat_models.hunyuan import ChatHunyuan
+from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
+from langchain_community.chat_models.jinachat import JinaChat
+from langchain_community.chat_models.konko import ChatKonko
+from langchain_community.chat_models.litellm import ChatLiteLLM
+from langchain_community.chat_models.minimax import MiniMaxChat
+from langchain_community.chat_models.mlflow import ChatMlflow
+from langchain_community.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
+from langchain_community.chat_models.ollama import ChatOllama
+from langchain_community.chat_models.openai import ChatOpenAI
+from langchain_community.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
+from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
+from langchain_community.chat_models.vertexai import ChatVertexAI
+from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
+from langchain_community.chat_models.yandex import ChatYandexGPT
+
+__all__ = [
+ "ChatOpenAI",
+ "BedrockChat",
+ "AzureChatOpenAI",
+ "FakeListChatModel",
+ "PromptLayerChatOpenAI",
+ "ChatDatabricks",
+ "ChatEverlyAI",
+ "ChatAnthropic",
+ "ChatCohere",
+ "ChatGooglePalm",
+ "ChatMlflow",
+ "ChatMLflowAIGateway",
+ "ChatOllama",
+ "ChatVertexAI",
+ "JinaChat",
+ "HumanInputChatModel",
+ "MiniMaxChat",
+ "ChatAnyscale",
+ "ChatLiteLLM",
+ "ErnieBotChat",
+ "ChatJavelinAIGateway",
+ "ChatKonko",
+ "PaiEasChatEndpoint",
+ "QianfanChatEndpoint",
+ "ChatFireworks",
+ "ChatYandexGPT",
+ "ChatBaichuan",
+ "ChatHunyuan",
+ "GigaChat",
+ "VolcEngineMaasChat",
+]
diff --git a/libs/community/langchain_community/chat_models/anthropic.py b/libs/community/langchain_community/chat_models/anthropic.py
new file mode 100644
index 00000000000..57d7dc77079
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/anthropic.py
@@ -0,0 +1,226 @@
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.prompt_values import PromptValue
+
+from langchain_community.llms.anthropic import _AnthropicCommon
+
+
+def _convert_one_message_to_text(
+ message: BaseMessage,
+ human_prompt: str,
+ ai_prompt: str,
+) -> str:
+ content = cast(str, message.content)
+ if isinstance(message, ChatMessage):
+ message_text = f"\n\n{message.role.capitalize()}: {content}"
+ elif isinstance(message, HumanMessage):
+ message_text = f"{human_prompt} {content}"
+ elif isinstance(message, AIMessage):
+ message_text = f"{ai_prompt} {content}"
+ elif isinstance(message, SystemMessage):
+ message_text = content
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_text
+
+
+def convert_messages_to_prompt_anthropic(
+ messages: List[BaseMessage],
+ *,
+ human_prompt: str = "\n\nHuman:",
+ ai_prompt: str = "\n\nAssistant:",
+) -> str:
+ """Format a list of messages into a full prompt for the Anthropic model
+ Args:
+ messages (List[BaseMessage]): List of BaseMessage to combine.
+ human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
+ ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
+ Returns:
+ str: Combined string with necessary human_prompt and ai_prompt tags.
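+
+    Example (illustrative; a trailing AI prompt tag is appended when the last
+    message is not an ``AIMessage``)::
+
+        convert_messages_to_prompt_anthropic([HumanMessage(content="Hi")])
+        # -> "\n\nHuman: Hi\n\nAssistant:"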
+ """
+
+ messages = messages.copy() # don't mutate the original list
+ if not isinstance(messages[-1], AIMessage):
+ messages.append(AIMessage(content=""))
+
+ text = "".join(
+ _convert_one_message_to_text(message, human_prompt, ai_prompt)
+ for message in messages
+ )
+
+    # Trim off the trailing space that can follow the final "Assistant:" tag
+ return text.rstrip()
+
+
+class ChatAnthropic(BaseChatModel, _AnthropicCommon):
+ """`Anthropic` chat large language models.
+
+ To use, you should have the ``anthropic`` python package installed, and the
+ environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ import anthropic
+ from langchain_community.chat_models import ChatAnthropic
+            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
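+
+    A hedged invocation sketch (the model name above is a placeholder; messages
+    are passed as ``BaseMessage`` objects):
+
+    .. code-block:: python
+
+        from langchain_core.messages import HumanMessage
+
+        model([HumanMessage(content="Say hello")])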
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+ arbitrary_types_allowed = True
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "anthropic-chat"
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "anthropic"]
+
+ def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
+ """Format a list of messages into a full prompt for the Anthropic model
+ Args:
+ messages (List[BaseMessage]): List of BaseMessage to combine.
+ Returns:
+ str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
+ """
+ prompt_params = {}
+ if self.HUMAN_PROMPT:
+ prompt_params["human_prompt"] = self.HUMAN_PROMPT
+ if self.AI_PROMPT:
+ prompt_params["ai_prompt"] = self.AI_PROMPT
+ return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
+
+ def convert_prompt(self, prompt: PromptValue) -> str:
+ return self._convert_messages_to_prompt(prompt.to_messages())
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ prompt = self._convert_messages_to_prompt(messages)
+ params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
+ if stop:
+ params["stop_sequences"] = stop
+
+ stream_resp = self.client.completions.create(**params, stream=True)
+ for data in stream_resp:
+ delta = data.completion
+ yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+ if run_manager:
+ run_manager.on_llm_new_token(delta)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ prompt = self._convert_messages_to_prompt(messages)
+ params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
+ if stop:
+ params["stop_sequences"] = stop
+
+ stream_resp = await self.async_client.completions.create(**params, stream=True)
+ async for data in stream_resp:
+ delta = data.completion
+ yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+ if run_manager:
+ await run_manager.on_llm_new_token(delta)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+ prompt = self._convert_messages_to_prompt(
+ messages,
+ )
+ params: Dict[str, Any] = {
+ "prompt": prompt,
+ **self._default_params,
+ **kwargs,
+ }
+ if stop:
+ params["stop_sequences"] = stop
+ response = self.client.completions.create(**params)
+ completion = response.completion
+ message = AIMessage(content=completion)
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._astream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+ prompt = self._convert_messages_to_prompt(
+ messages,
+ )
+ params: Dict[str, Any] = {
+ "prompt": prompt,
+ **self._default_params,
+ **kwargs,
+ }
+ if stop:
+ params["stop_sequences"] = stop
+ response = await self.async_client.completions.create(**params)
+ completion = response.completion
+ message = AIMessage(content=completion)
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ def get_num_tokens(self, text: str) -> int:
+ """Calculate number of tokens."""
+ if not self.count_tokens:
+ raise NameError("Please ensure the anthropic package is loaded")
+ return self.count_tokens(text)
diff --git a/libs/community/langchain_community/chat_models/anyscale.py b/libs/community/langchain_community/chat_models/anyscale.py
new file mode 100644
index 00000000000..40d36aa187d
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/anyscale.py
@@ -0,0 +1,220 @@
+"""Anyscale Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
+from __future__ import annotations
+
+import logging
+import os
+import sys
+from typing import TYPE_CHECKING, Dict, Optional, Set
+
+import requests
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.adapters.openai import convert_message_to_dict
+from langchain_community.chat_models.openai import (
+ ChatOpenAI,
+ _import_tiktoken,
+)
+from langchain_community.utils.openai import is_openai_v1
+
+if TYPE_CHECKING:
+ import tiktoken
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
+DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
+
+
+class ChatAnyscale(ChatOpenAI):
+ """`Anyscale` Chat large language models.
+
+ See https://www.anyscale.com/ for information about Anyscale.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``ANYSCALE_API_KEY`` set with your API key.
+ Alternatively, you can use the anyscale_api_key keyword argument.
+
+ Any parameters that are valid to be passed to the `openai.create` call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatAnyscale
+ chat = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf")
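+
+    A hedged sketch of checking which models your key can access (assumes
+    ``ANYSCALE_API_KEY`` is set in the environment):
+
+    .. code-block:: python
+
+        available_models = ChatAnyscale.get_available_models()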
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "anyscale-chat"
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"anyscale_api_key": "ANYSCALE_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ anyscale_api_key: SecretStr
+ """AnyScale Endpoints API keys."""
+ model_name: str = Field(default=DEFAULT_MODEL, alias="model")
+ """Model name to use."""
+ anyscale_api_base: str = Field(default=DEFAULT_API_BASE)
+ """Base URL path for API requests,
+ leave blank if not using a proxy or service emulator."""
+ anyscale_proxy: Optional[str] = None
+ """To support explicit proxy for Anyscale."""
+ available_models: Optional[Set[str]] = None
+ """Available models from Anyscale API."""
+
+ @staticmethod
+ def get_available_models(
+ anyscale_api_key: Optional[str] = None,
+ anyscale_api_base: str = DEFAULT_API_BASE,
+ ) -> Set[str]:
+ """Get available models from Anyscale API."""
+ try:
+ anyscale_api_key = anyscale_api_key or os.environ["ANYSCALE_API_KEY"]
+ except KeyError as e:
+ raise ValueError(
+ "Anyscale API key must be passed as keyword argument or "
+ "set in environment variable ANYSCALE_API_KEY.",
+ ) from e
+
+ models_url = f"{anyscale_api_base}/models"
+ models_response = requests.get(
+ models_url,
+ headers={
+ "Authorization": f"Bearer {anyscale_api_key}",
+ },
+ )
+
+ if models_response.status_code != 200:
+ raise ValueError(
+ f"Error getting models from {models_url}: "
+ f"{models_response.status_code}",
+ )
+
+ return {model["id"] for model in models_response.json()["data"]}
+
+ @root_validator(pre=True)
+ def validate_environment_override(cls, values: dict) -> dict:
+ """Validate that api key and python package exists in environment."""
+ values["openai_api_key"] = get_from_dict_or_env(
+ values,
+ "anyscale_api_key",
+ "ANYSCALE_API_KEY",
+ )
+ values["anyscale_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "anyscale_api_key",
+ "ANYSCALE_API_KEY",
+ )
+ )
+ values["openai_api_base"] = get_from_dict_or_env(
+ values,
+ "anyscale_api_base",
+ "ANYSCALE_API_BASE",
+ default=DEFAULT_API_BASE,
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "anyscale_proxy",
+ "ANYSCALE_PROXY",
+ default="",
+ )
+ try:
+ import openai
+
+ except ImportError as e:
+ raise ValueError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`.",
+ ) from e
+ try:
+ if is_openai_v1():
+ client_params = {
+ "api_key": values["openai_api_key"],
+ "base_url": values["openai_api_base"],
+ # To do: future support
+ # "organization": values["openai_organization"],
+ # "timeout": values["request_timeout"],
+ # "max_retries": values["max_retries"],
+ # "default_headers": values["default_headers"],
+ # "default_query": values["default_query"],
+ # "http_client": values["http_client"],
+ }
+ values["client"] = openai.OpenAI(**client_params).chat.completions
+ else:
+ values["client"] = openai.ChatCompletion
+ except AttributeError as exc:
+ raise ValueError(
+ "`openai` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the openai package. Try upgrading it "
+ "with `pip install --upgrade openai`.",
+ ) from exc
+
+ if "model_name" not in values.keys():
+ values["model_name"] = DEFAULT_MODEL
+
+ model_name = values["model_name"]
+
+ available_models = cls.get_available_models(
+ values["openai_api_key"],
+ values["openai_api_base"],
+ )
+
+ if model_name not in available_models:
+ raise ValueError(
+ f"Model name {model_name} not found in available models: "
+ f"{available_models}.",
+ )
+
+ values["available_models"] = available_models
+
+ return values
+
+ def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
+ tiktoken_ = _import_tiktoken()
+ if self.tiktoken_model_name is not None:
+ model = self.tiktoken_model_name
+ else:
+ model = self.model_name
+        # Anyscale-hosted models are generally not in tiktoken's model registry,
+        # so approximate token counts with the gpt-3.5-turbo-0301 encoding.
+        try:
+            encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken_.get_encoding(model)
+ return model, encoding
+
+ def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+ """Calculate num tokens with tiktoken package.
+
+ Official documentation: https://github.com/openai/openai-cookbook/blob/
+ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+ if sys.version_info[1] <= 7:
+ return super().get_num_tokens_from_messages(messages)
+ model, encoding = self._get_encoding_model()
+ tokens_per_message = 3
+ tokens_per_name = 1
+ num_tokens = 0
+ messages_dict = [convert_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+ num_tokens += len(encoding.encode(str(value)))
+ if key == "name":
+ num_tokens += tokens_per_name
+ # every reply is primed with assistant
+ num_tokens += 3
+ return num_tokens
diff --git a/libs/community/langchain_community/chat_models/azure_openai.py b/libs/community/langchain_community/chat_models/azure_openai.py
new file mode 100644
index 00000000000..b82c379e1d3
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/azure_openai.py
@@ -0,0 +1,271 @@
+"""Azure OpenAI chat wrapper."""
+from __future__ import annotations
+
+import logging
+import os
+import warnings
+from typing import Any, Callable, Dict, List, Union
+
+from langchain_core.outputs import ChatResult
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.chat_models.openai import ChatOpenAI
+from langchain_community.utils.openai import is_openai_v1
+
+logger = logging.getLogger(__name__)
+
+
+class AzureChatOpenAI(ChatOpenAI):
+ """`Azure OpenAI` Chat Completion API.
+
+ To use this class you
+ must have a deployed model on Azure OpenAI. Use `deployment_name` in the
+ constructor to refer to the "Model deployment name" in the Azure portal.
+
+ In addition, you should have the ``openai`` python package installed, and the
+    following environment variables set or passed in the constructor in lower case:
+    - ``AZURE_OPENAI_API_KEY``
+    - ``AZURE_OPENAI_ENDPOINT``
+ - ``AZURE_OPENAI_AD_TOKEN``
+ - ``OPENAI_API_VERSION``
+ - ``OPENAI_PROXY``
+
+ For example, if you have `gpt-35-turbo` deployed, with the deployment name
+ `35-turbo-dev`, the constructor should look like:
+
+ .. code-block:: python
+
+ AzureChatOpenAI(
+ azure_deployment="35-turbo-dev",
+ openai_api_version="2023-05-15",
+ )
+
+ Be aware the API version may change.
+
+    You can also specify the version of the model using the ``model_version``
+    constructor parameter, as Azure OpenAI doesn't return the model version with
+    the response.
+
+    Default is empty. When you specify a version, it is appended to the model
+    name in the response, which helps to calculate the cost correctly. The model
+    version is not validated, so make sure you set it correctly to get the
+    correct cost.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
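+
+    A hedged invocation sketch (continuing the example above; assumes
+    ``AZURE_OPENAI_API_KEY`` and ``AZURE_OPENAI_ENDPOINT`` are set):
+
+    .. code-block:: python
+
+        from langchain_core.messages import HumanMessage
+
+        chat = AzureChatOpenAI(
+            azure_deployment="35-turbo-dev",
+            openai_api_version="2023-05-15",
+        )
+        chat([HumanMessage(content="Hello")])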
+ """
+
+ azure_endpoint: Union[str, None] = None
+ """Your Azure endpoint, including the resource.
+
+ Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
+
+ Example: `https://example-resource.azure.openai.com/`
+ """
+ deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
+ """A model deployment.
+
+ If given sets the base client URL to include `/deployments/{azure_deployment}`.
+ Note: this means you won't be able to use non-deployment endpoints.
+ """
+ openai_api_version: str = Field(default="", alias="api_version")
+ """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
+ openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
+ azure_ad_token: Union[str, None] = None
+ """Your Azure Active Directory token.
+
+ Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
+
+ For more:
+ https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
+ """ # noqa: E501
+ azure_ad_token_provider: Union[Callable[[], str], None] = None
+ """A function that returns an Azure Active Directory token.
+
+ Will be invoked on every request.
+ """
+ model_version: str = ""
+ """Legacy, for openai<1.0.0 support."""
+ openai_api_type: str = ""
+ """Legacy, for openai<1.0.0 support."""
+ validate_base_url: bool = True
+ """For backwards compatibility. If legacy val openai_api_base is passed in, try to
+ infer if it is a base_url or azure_endpoint and update accordingly.
+ """
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "azure_openai"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ if values["n"] < 1:
+ raise ValueError("n must be at least 1.")
+ if values["n"] > 1 and values["streaming"]:
+ raise ValueError("n must be 1 when streaming.")
+
+        # Check OPENAI_API_KEY for backwards compatibility.
+ # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
+ # other forms of azure credentials.
+ values["openai_api_key"] = (
+ values["openai_api_key"]
+ or os.getenv("AZURE_OPENAI_API_KEY")
+ or os.getenv("OPENAI_API_KEY")
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_api_version"] = values["openai_api_version"] or os.getenv(
+ "OPENAI_API_VERSION"
+ )
+ # Check OPENAI_ORGANIZATION for backwards compatibility.
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
+ "AZURE_OPENAI_ENDPOINT"
+ )
+ values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
+ "AZURE_OPENAI_AD_TOKEN"
+ )
+
+ values["openai_api_type"] = get_from_dict_or_env(
+ values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values, "openai_proxy", "OPENAI_PROXY", default=""
+ )
+
+ try:
+ import openai
+
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ if is_openai_v1():
+ # For backwards compatibility. Before openai v1, no distinction was made
+ # between azure_endpoint and base_url (openai_api_base).
+ openai_api_base = values["openai_api_base"]
+ if openai_api_base and values["validate_base_url"]:
+ if "/openai" not in openai_api_base:
+ values["openai_api_base"] = (
+ values["openai_api_base"].rstrip("/") + "/openai"
+ )
+ warnings.warn(
+ "As of openai>=1.0.0, Azure endpoints should be specified via "
+ f"the `azure_endpoint` param not `openai_api_base` "
+ f"(or alias `base_url`). Updating `openai_api_base` from "
+ f"{openai_api_base} to {values['openai_api_base']}."
+ )
+ if values["deployment_name"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `deployment_name` (or alias "
+ "`azure_deployment`) is specified then "
+ "`openai_api_base` (or alias `base_url`) should not be. "
+ "Instead use `deployment_name` (or alias `azure_deployment`) "
+ "and `azure_endpoint`."
+ )
+ if values["deployment_name"] not in values["openai_api_base"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `openai_api_base` "
+ "(or alias `base_url`) is specified it is expected to be "
+ "of the form "
+ "https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
+ f"Updating {openai_api_base} to "
+ f"{values['openai_api_base']}."
+ )
+ values["openai_api_base"] += (
+ "/deployments/" + values["deployment_name"]
+ )
+ values["deployment_name"] = None
+ client_params = {
+ "api_version": values["openai_api_version"],
+ "azure_endpoint": values["azure_endpoint"],
+ "azure_deployment": values["deployment_name"],
+ "api_key": values["openai_api_key"],
+ "azure_ad_token": values["azure_ad_token"],
+ "azure_ad_token_provider": values["azure_ad_token_provider"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+ values["client"] = openai.AzureOpenAI(**client_params).chat.completions
+ values["async_client"] = openai.AsyncAzureOpenAI(
+ **client_params
+ ).chat.completions
+ else:
+ values["client"] = openai.ChatCompletion
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ if is_openai_v1():
+ return super()._default_params
+ else:
+ return {
+ **super()._default_params,
+ "engine": self.deployment_name,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**self._default_params}
+
+ @property
+ def _client_params(self) -> Dict[str, Any]:
+ """Get the config params used for the openai client."""
+ if is_openai_v1():
+ return super()._client_params
+ else:
+ return {
+ **super()._client_params,
+ "api_type": self.openai_api_type,
+ "api_version": self.openai_api_version,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "azure-openai-chat"
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ return {
+ "openai_api_type": self.openai_api_type,
+ "openai_api_version": self.openai_api_version,
+ }
+
+ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+ if not isinstance(response, dict):
+ response = response.dict()
+ for res in response["choices"]:
+ if res.get("finish_reason", None) == "content_filter":
+ raise ValueError(
+ "Azure has not provided the response due to a content filter "
+ "being triggered"
+ )
+ chat_result = super()._create_chat_result(response)
+
+ if "model" in response:
+ model = response["model"]
+ if self.model_version:
+ model = f"{model}-{self.model_version}"
+
+ if chat_result.llm_output is not None and isinstance(
+ chat_result.llm_output, dict
+ ):
+ chat_result.llm_output["model_name"] = model
+
+ return chat_result
diff --git a/libs/community/langchain_community/chat_models/azureml_endpoint.py b/libs/community/langchain_community/chat_models/azureml_endpoint.py
new file mode 100644
index 00000000000..111f0502b1f
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/azureml_endpoint.py
@@ -0,0 +1,169 @@
+import json
+from typing import Any, Dict, List, Optional, cast
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.pydantic_v1 import SecretStr, validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.azureml_endpoint import (
+ AzureMLEndpointClient,
+ ContentFormatterBase,
+)
+
+
+class LlamaContentFormatter(ContentFormatterBase):
+ """Content formatter for `LLaMA`."""
+
+ SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
+
+ @staticmethod
+ def _convert_message_to_dict(message: BaseMessage) -> Dict:
+ """Converts message to a dict according to role"""
+ content = cast(str, message.content)
+ if isinstance(message, HumanMessage):
+ return {
+ "role": "user",
+ "content": ContentFormatterBase.escape_special_characters(content),
+ }
+ elif isinstance(message, AIMessage):
+ return {
+ "role": "assistant",
+ "content": ContentFormatterBase.escape_special_characters(content),
+ }
+ elif isinstance(message, SystemMessage):
+ return {
+ "role": "system",
+ "content": ContentFormatterBase.escape_special_characters(content),
+ }
+ elif (
+ isinstance(message, ChatMessage)
+ and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+ ):
+ return {
+ "role": message.role,
+ "content": ContentFormatterBase.escape_special_characters(content),
+ }
+ else:
+ supported = ",".join(
+ [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+ )
+ raise ValueError(
+ f"""Received unsupported role.
+ Supported roles for the LLaMa Foundation Model: {supported}"""
+ )
+
+ def _format_request_payload(
+ self, messages: List[BaseMessage], model_kwargs: Dict
+ ) -> bytes:
+ chat_messages = [
+ LlamaContentFormatter._convert_message_to_dict(message)
+ for message in messages
+ ]
+ prompt = json.dumps(
+ {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
+ )
+ return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
+
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+ """Formats the request according to the chosen api"""
+ return str.encode(prompt)
+
+ def format_response_payload(self, output: bytes) -> str:
+ """Formats response"""
+ return json.loads(output)["output"]
+
+
+class AzureMLChatOnlineEndpoint(SimpleChatModel):
+ """`AzureML` Chat models API.
+
+ Example:
+ .. code-block:: python
+
+            azure_chat = AzureMLChatOnlineEndpoint(
+                endpoint_url="https://<your-endpoint>.<region>.inference.ml.azure.com/score",
+                endpoint_api_key="my-api-key",
+                content_formatter=LlamaContentFormatter(),
+            )
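+
+            # A hedged invocation sketch (the HumanMessage content is
+            # illustrative; any list of BaseMessage objects can be passed).
+            response = azure_chat([HumanMessage(content="Tell me a joke.")])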
+ """
+
+ endpoint_url: str = ""
+ """URL of pre-existing Endpoint. Should be passed to constructor or specified as
+ env var `AZUREML_ENDPOINT_URL`."""
+
+ endpoint_api_key: SecretStr = convert_to_secret_str("")
+ """Authentication Key for Endpoint. Should be passed to constructor or specified as
+ env var `AZUREML_ENDPOINT_API_KEY`."""
+
+ http_client: Any = None #: :meta private:
+
+ content_formatter: Any = None
+ """The content formatter that provides an input and output
+ transform function to handle formats between the LLM and
+ the endpoint"""
+
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ @validator("http_client", always=True, allow_reuse=True)
+ @classmethod
+ def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+ """Validate that api key and python package exist in environment."""
+ values["endpoint_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
+ )
+ endpoint_url = get_from_dict_or_env(
+ values, "endpoint_url", "AZUREML_ENDPOINT_URL"
+ )
+ http_client = AzureMLEndpointClient(
+ endpoint_url, values["endpoint_api_key"].get_secret_value()
+ )
+ return http_client
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "azureml_chat_endpoint"
+
+ def _call(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to an AzureML Managed Online endpoint.
+ Args:
+ messages: The messages in the conversation with the chat model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+ Example:
+ .. code-block:: python
+                response = azure_chat([HumanMessage(content="Tell me a joke.")])
+ """
+ _model_kwargs = self.model_kwargs or {}
+
+ request_payload = self.content_formatter._format_request_payload(
+ messages, _model_kwargs
+ )
+ response_payload = self.http_client.call(request_payload, **kwargs)
+ generated_text = self.content_formatter.format_response_payload(
+ response_payload
+ )
+ return generated_text
diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py
new file mode 100644
index 00000000000..14cf4a57e2e
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/baichuan.py
@@ -0,0 +1,298 @@
+import hashlib
+import json
+import logging
+import time
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import (
+ convert_to_secret_str,
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1"
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ else:
+ raise TypeError(f"Got unknown type {message}")
+
+ return message_dict
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ return AIMessage(content=_dict.get("content", "") or "")
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ content = _dict.get("content") or ""
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content)
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+# signature generation
+def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
+ input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp)
+ md5 = hashlib.md5()
+ md5.update(input_str.encode("utf-8"))
+ return md5.hexdigest()
+
+
+class ChatBaichuan(BaseChatModel):
+ """Baichuan chat models API by Baichuan Intelligent Technology.
+
+ For more information, see https://platform.baichuan-ai.com/docs/api
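+
+    A minimal usage sketch (hedged: the key values are placeholders; the secret
+    key is used to sign each request):
+
+    .. code-block:: python
+
+        from langchain_community.chat_models import ChatBaichuan
+
+        chat = ChatBaichuan(
+            baichuan_api_key="YOUR_API_KEY",
+            baichuan_secret_key="YOUR_SECRET_KEY",
+        )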
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {
+ "baichuan_api_key": "BAICHUAN_API_KEY",
+ "baichuan_secret_key": "BAICHUAN_SECRET_KEY",
+ }
+
+ @property
+ def lc_serializable(self) -> bool:
+ return True
+
+ baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
+ """Baichuan custom endpoints"""
+ baichuan_api_key: Optional[SecretStr] = None
+ """Baichuan API Key"""
+ baichuan_secret_key: Optional[SecretStr] = None
+ """Baichuan Secret Key"""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ request_timeout: int = 60
+ """request timeout for chat http requests"""
+
+    model: str = "Baichuan2-53B"
+ """model name of Baichuan, default is `Baichuan2-53B`."""
+ temperature: float = 0.3
+ """What sampling temperature to use."""
+    top_k: int = 5
+    """Top-k sampling: consider only the k most likely tokens at each step."""
+ top_p: float = 0.85
+ """What probability mass to use."""
+ with_search_enhance: bool = False
+ """Whether to use search enhance, default is False."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for API call not explicitly specified."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+                logger.warning(
+                    f"""WARNING! {field_name} is not a default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["baichuan_api_base"] = get_from_dict_or_env(
+ values,
+ "baichuan_api_base",
+ "BAICHUAN_API_BASE",
+ DEFAULT_API_BASE,
+ )
+ values["baichuan_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "baichuan_api_key",
+ "BAICHUAN_API_KEY",
+ )
+ )
+ values["baichuan_secret_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "baichuan_secret_key",
+ "BAICHUAN_SECRET_KEY",
+ )
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Baichuan API."""
+ normal_params = {
+ "model": self.model,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "with_search_enhance": self.with_search_enhance,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ res = self._chat(messages, **kwargs)
+
+ response = res.json()
+
+ if response.get("code") != 0:
+ raise ValueError(f"Error from Baichuan api response: {response}")
+
+ return self._create_chat_result(response)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ res = self._chat(messages, **kwargs)
+
+ default_chunk_class = AIMessageChunk
+ for chunk in res.iter_lines():
+ response = json.loads(chunk)
+ if response.get("code") != 0:
+ raise ValueError(f"Error from Baichuan api response: {response}")
+
+ data = response.get("data")
+ for m in data.get("messages"):
+ chunk = _convert_delta_to_message_chunk(m, default_chunk_class)
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.content)
+
+ def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
+ if self.baichuan_secret_key is None:
+ raise ValueError("Baichuan secret key is not set.")
+
+ parameters = {**self._default_params, **kwargs}
+
+ model = parameters.pop("model")
+ headers = parameters.pop("headers", {})
+
+ payload = {
+ "model": model,
+ "messages": [_convert_message_to_dict(m) for m in messages],
+ "parameters": parameters,
+ }
+
+ timestamp = int(time.time())
+
+ url = self.baichuan_api_base
+ if self.streaming:
+ url = f"{url}/stream"
+ url = f"{url}/chat"
+
+ api_key = ""
+ if self.baichuan_api_key:
+ api_key = self.baichuan_api_key.get_secret_value()
+
+ res = requests.post(
+ url=url,
+ timeout=self.request_timeout,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ "X-BC-Timestamp": str(timestamp),
+ "X-BC-Signature": _signature(
+ secret_key=self.baichuan_secret_key,
+ payload=payload,
+ timestamp=timestamp,
+ ),
+ "X-BC-Sign-Algo": "MD5",
+ **headers,
+ },
+ json=payload,
+ stream=self.streaming,
+ )
+ return res
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for m in response["data"]["messages"]:
+ message = _convert_dict_to_message(m)
+ gen = ChatGeneration(message=message)
+ generations.append(gen)
+
+ token_usage = response["usage"]
+ llm_output = {"token_usage": token_usage, "model": self.model}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ @property
+ def _llm_type(self) -> str:
+ return "baichuan-chat"
diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py
new file mode 100644
index 00000000000..81e2a544a47
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py
@@ -0,0 +1,344 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+def convert_message_to_dict(message: BaseMessage) -> dict:
+ """Convert a message to a dictionary that can be passed to the API."""
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ if "function_call" in message.additional_kwargs:
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
+ # If function call only, content is None not empty string
+ if message_dict["content"] == "":
+ message_dict["content"] = None
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.name,
+ }
+ else:
+ raise TypeError(f"Got unknown type {message}")
+
+ return message_dict
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
+ content = _dict.get("result", "") or ""
+ if _dict.get("function_call"):
+ additional_kwargs = {"function_call": dict(_dict["function_call"])}
+ if "thoughts" in additional_kwargs["function_call"]:
+ # align to api sample, which affects the llm function_call output
+ additional_kwargs["function_call"].pop("thoughts")
+ else:
+ additional_kwargs = {}
+ return AIMessage(
+ content=content,
+ additional_kwargs={**_dict.get("body", {}), **additional_kwargs},
+ )
+
+
+class QianfanChatEndpoint(BaseChatModel):
+ """Baidu Qianfan chat models.
+
+ To use, you should have the ``qianfan`` python package installed, and
+    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with your
+    API key and Secret Key.
+
+    ak and sk are required parameters, which you can get from
+    https://cloud.baidu.com/product/wenxinworkshop
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import QianfanChatEndpoint
+ qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
+ endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
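+
+    A hedged invocation sketch (continuing the example above):
+
+    .. code-block:: python
+
+        from langchain_core.messages import HumanMessage
+
+        qianfan_chat([HumanMessage(content="Hello")])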
+ """
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ client: Any
+
+ qianfan_ak: Optional[SecretStr] = None
+ qianfan_sk: Optional[SecretStr] = None
+
+ streaming: Optional[bool] = False
+ """Whether to stream the results or not."""
+
+ request_timeout: Optional[int] = 60
+ """request timeout for chat http requests"""
+
+ top_p: Optional[float] = 0.8
+ temperature: Optional[float] = 0.95
+ penalty_score: Optional[float] = 1
+ """Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
+    For other models, passing these params will not affect the result.
+ """
+
+ model: str = "ERNIE-Bot-turbo"
+ """Model name.
+    You can find available models at
+    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
+
+    Preset models are mapped to an endpoint;
+    `model` will be ignored if `endpoint` is set.
+    Default is ERNIE-Bot-turbo.
+ """
+
+ endpoint: Optional[str] = None
+ """Endpoint of the Qianfan LLM, required if custom model used."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["qianfan_ak"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "qianfan_ak",
+ "QIANFAN_AK",
+ )
+ )
+ values["qianfan_sk"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "qianfan_sk",
+ "QIANFAN_SK",
+ )
+ )
+ params = {
+ "ak": values["qianfan_ak"].get_secret_value(),
+ "sk": values["qianfan_sk"].get_secret_value(),
+ "model": values["model"],
+ "stream": values["streaming"],
+ }
+ if values["endpoint"] is not None and values["endpoint"] != "":
+ params["endpoint"] = values["endpoint"]
+ try:
+ import qianfan
+
+ values["client"] = qianfan.ChatCompletion(**params)
+ except ImportError:
+ raise ValueError(
+ "qianfan package not found, please install it with "
+ "`pip install qianfan`"
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ **{"endpoint": self.endpoint, "model": self.model},
+ **super()._identifying_params,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat_model."""
+ return "baidu-qianfan-chat"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Qianfan API."""
+ normal_params = {
+ "model": self.model,
+ "endpoint": self.endpoint,
+ "stream": self.streaming,
+ "request_timeout": self.request_timeout,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "penalty_score": self.penalty_score,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _convert_prompt_msg_params(
+ self,
+ messages: List[BaseMessage],
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """
+ Converts a list of messages into a dictionary containing the message content
+ and default parameters.
+
+ Args:
+ messages (List[BaseMessage]): The list of messages.
+ **kwargs (Any): Optional arguments to add additional parameters to the
+ resulting dictionary.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the message content and default
+ parameters.
+
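+        Example (illustrative): system messages are merged into a single
+        ``system`` string, all other messages are converted under ``messages``,
+        and the default parameters are merged in, e.g.::
+
+            {
+                "messages": [{"role": "user", "content": "hi"}],
+                "system": "You are a helpful assistant.\n",
+                "model": "ERNIE-Bot-turbo",
+                ...
+            }
+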
+ """
+ messages_dict: Dict[str, Any] = {
+ "messages": [
+ convert_message_to_dict(m)
+ for m in messages
+ if not isinstance(m, SystemMessage)
+ ]
+ }
+ for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
+ if "system" not in messages_dict:
+ messages_dict["system"] = ""
+ messages_dict["system"] += cast(str, messages[i].content) + "\n"
+
+ return {
+ **messages_dict,
+ **self._default_params,
+ **kwargs,
+ }
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+    ) -> ChatResult:
+        """Call out to a Qianfan chat endpoint for each generation with a prompt.
+ Args:
+ messages: The messages to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+                response = qianfan_chat([HumanMessage(content="Tell me a joke.")])
+ """
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(messages, stop, run_manager, **kwargs):
+ completion += chunk.text
+ lc_msg = AIMessage(content=completion, additional_kwargs={})
+ gen = ChatGeneration(
+ message=lc_msg,
+ generation_info=dict(finish_reason="stop"),
+ )
+ return ChatResult(
+ generations=[gen],
+ llm_output={"token_usage": {}, "model_name": self.model},
+ )
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ response_payload = self.client.do(**params)
+ lc_msg = _convert_dict_to_message(response_payload)
+ gen = ChatGeneration(
+ message=lc_msg,
+ generation_info={
+ "finish_reason": "stop",
+ **response_payload.get("body", {}),
+ },
+ )
+ token_usage = response_payload.get("usage", {})
+ llm_output = {"token_usage": token_usage, "model_name": self.model}
+ return ChatResult(generations=[gen], llm_output=llm_output)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ completion = ""
+ token_usage = {}
+ async for chunk in self._astream(messages, stop, run_manager, **kwargs):
+ completion += chunk.text
+
+ lc_msg = AIMessage(content=completion, additional_kwargs={})
+ gen = ChatGeneration(
+ message=lc_msg,
+ generation_info=dict(finish_reason="stop"),
+ )
+ return ChatResult(
+ generations=[gen],
+ llm_output={"token_usage": {}, "model_name": self.model},
+ )
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ response_payload = await self.client.ado(**params)
+ lc_msg = _convert_dict_to_message(response_payload)
+ generations = []
+ gen = ChatGeneration(
+ message=lc_msg,
+ generation_info={
+ "finish_reason": "stop",
+ **response_payload.get("body", {}),
+ },
+ )
+ generations.append(gen)
+ token_usage = response_payload.get("usage", {})
+ llm_output = {"token_usage": token_usage, "model_name": self.model}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ for res in self.client.do(**params):
+ if res:
+ msg = _convert_dict_to_message(res)
+ chunk = ChatGenerationChunk(
+ text=res["result"],
+ message=AIMessageChunk(
+ content=msg.content,
+ role="assistant",
+ additional_kwargs=msg.additional_kwargs,
+ ),
+ )
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ async for res in await self.client.ado(**params):
+ if res:
+ msg = _convert_dict_to_message(res)
+ chunk = ChatGenerationChunk(
+ text=res["result"],
+ message=AIMessageChunk(
+ content=msg.content,
+ role="assistant",
+ additional_kwargs=msg.additional_kwargs,
+ ),
+ )
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
new file mode 100644
index 00000000000..49b7acad19a
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -0,0 +1,131 @@
+from typing import Any, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.chat_models.anthropic import (
+ convert_messages_to_prompt_anthropic,
+)
+from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
+from langchain_community.llms.bedrock import BedrockBase
+from langchain_community.utilities.anthropic import (
+ get_num_tokens_anthropic,
+ get_token_ids_anthropic,
+)
+
+
+class ChatPromptAdapter:
+ """Adapter class to prepare the inputs from Langchain to prompt format
+ that Chat model expects.
+ """
+
+ @classmethod
+ def convert_messages_to_prompt(
+ cls, provider: str, messages: List[BaseMessage]
+ ) -> str:
+ if provider == "anthropic":
+ prompt = convert_messages_to_prompt_anthropic(messages=messages)
+ elif provider == "meta":
+ prompt = convert_messages_to_prompt_llama(messages=messages)
+ else:
+ raise NotImplementedError(
+ f"Provider {provider} model does not support chat."
+ )
+ return prompt
+
+
+class BedrockChat(BaseChatModel, BedrockBase):
+    """A chat model that uses the Bedrock API.
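+
+    A hedged usage sketch (the ``model_id`` and region are illustrative; use a
+    Bedrock model your AWS account has access to):
+
+    .. code-block:: python
+
+        from langchain_core.messages import HumanMessage
+
+        chat = BedrockChat(model_id="anthropic.claude-v2", region_name="us-east-1")
+        chat([HumanMessage(content="Hello")])
+    """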
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "amazon_bedrock_chat"
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "bedrock"]
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ attributes: Dict[str, Any] = {}
+
+ if self.region_name:
+ attributes["region_name"] = self.region_name
+
+ return attributes
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ provider = self._get_provider()
+ prompt = ChatPromptAdapter.convert_messages_to_prompt(
+ provider=provider, messages=messages
+ )
+
+ for chunk in self._prepare_input_and_invoke_stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ delta = chunk.text
+ yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ completion = ""
+
+ if self.streaming:
+ for chunk in self._stream(messages, stop, run_manager, **kwargs):
+ completion += chunk.text
+ else:
+ provider = self._get_provider()
+ prompt = ChatPromptAdapter.convert_messages_to_prompt(
+ provider=provider, messages=messages
+ )
+
+ params: Dict[str, Any] = {**kwargs}
+ if stop:
+ params["stop_sequences"] = stop
+
+ completion = self._prepare_input_and_invoke(
+ prompt=prompt, stop=stop, run_manager=run_manager, **params
+ )
+
+ message = AIMessage(content=completion)
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ def get_num_tokens(self, text: str) -> int:
+ if self._model_is_anthropic:
+ return get_num_tokens_anthropic(text)
+ else:
+ return super().get_num_tokens(text)
+
+ def get_token_ids(self, text: str) -> List[int]:
+ if self._model_is_anthropic:
+ return get_token_ids_anthropic(text)
+ else:
+ return super().get_token_ids(text)
diff --git a/libs/community/langchain_community/chat_models/cohere.py b/libs/community/langchain_community/chat_models/cohere.py
new file mode 100644
index 00000000000..f74abcee3ca
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/cohere.py
@@ -0,0 +1,234 @@
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from langchain_community.llms.cohere import BaseCohere
+
+
+def get_role(message: BaseMessage) -> str:
+ """Get the role of the message.
+
+ Args:
+ message: The message.
+
+ Returns:
+ The role of the message.
+
+ Raises:
+ ValueError: If the message is of an unknown type.
+ """
+    if isinstance(message, (ChatMessage, HumanMessage)):
+ return "User"
+ elif isinstance(message, AIMessage):
+ return "Chatbot"
+ elif isinstance(message, SystemMessage):
+ return "System"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+
+def get_cohere_chat_request(
+ messages: List[BaseMessage],
+ *,
+ connectors: Optional[List[Dict[str, str]]] = None,
+ **kwargs: Any,
+) -> Dict[str, Any]:
+ """Get the request for the Cohere chat API.
+
+ Args:
+ messages: The messages.
+ connectors: The connectors.
+ **kwargs: The keyword arguments.
+
+ Returns:
+ The request for the Cohere chat API.
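+
+    Example (an illustrative sketch; the message content is arbitrary):
+        .. code-block:: python
+
+            from langchain_core.messages import HumanMessage
+
+            request = get_cohere_chat_request([HumanMessage(content="Hello")])
+            # request["message"] == "Hello" and request["chat_history"] == []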
+ """
+ documents = (
+ None
+ if "source_documents" not in kwargs
+ else [
+ {
+ "snippet": doc.page_content,
+ "id": doc.metadata.get("id") or f"doc-{str(i)}",
+ }
+ for i, doc in enumerate(kwargs["source_documents"])
+ ]
+ )
+ kwargs.pop("source_documents", None)
+ maybe_connectors = connectors if documents is None else None
+
+ # by enabling automatic prompt truncation, the probability of request failure is
+ # reduced with minimal impact on response quality
+ prompt_truncation = (
+ "AUTO" if documents is not None or connectors is not None else None
+ )
+
+ return {
+ "message": messages[-1].content,
+ "chat_history": [
+ {"role": get_role(x), "message": x.content} for x in messages[:-1]
+ ],
+ "documents": documents,
+ "connectors": maybe_connectors,
+ "prompt_truncation": prompt_truncation,
+ **kwargs,
+ }
+
+
+class ChatCohere(BaseChatModel, BaseCohere):
+ """`Cohere` chat large language models.
+
+ To use, you should have the ``cohere`` python package installed, and the
+ environment variable ``COHERE_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatCohere
+ from langchain_core.messages import HumanMessage
+
+ chat = ChatCohere(model="foo")
+ result = chat([HumanMessage(content="Hello")])
+ print(result.content)
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+ arbitrary_types_allowed = True
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "cohere-chat"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Cohere API."""
+ return {
+ "temperature": self.temperature,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ stream = self.client.chat(**request, stream=True)
+
+ for data in stream:
+ if data.event_type == "text-generation":
+ delta = data.text
+ yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+ if run_manager:
+ run_manager.on_llm_new_token(delta)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ stream = await self.async_client.chat(**request, stream=True)
+
+ async for data in stream:
+ if data.event_type == "text-generation":
+ delta = data.text
+ yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+ if run_manager:
+ await run_manager.on_llm_new_token(delta)
+
+ def _get_generation_info(self, response: Any) -> Dict[str, Any]:
+ """Get the generation info from cohere API response."""
+ return {
+ "documents": response.documents,
+ "citations": response.citations,
+ "search_results": response.search_results,
+ "search_queries": response.search_queries,
+ "token_count": response.token_count,
+ }
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ response = self.client.chat(**request)
+
+ message = AIMessage(content=response.text)
+ generation_info = None
+ if hasattr(response, "documents"):
+ generation_info = self._get_generation_info(response)
+ return ChatResult(
+ generations=[
+ ChatGeneration(message=message, generation_info=generation_info)
+ ]
+ )
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._astream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+
+ request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ response = self.client.chat(**request, stream=False)
+
+ message = AIMessage(content=response.text)
+ generation_info = None
+ if hasattr(response, "documents"):
+ generation_info = self._get_generation_info(response)
+ return ChatResult(
+ generations=[
+ ChatGeneration(message=message, generation_info=generation_info)
+ ]
+ )
+
+ def get_num_tokens(self, text: str) -> int:
+ """Calculate number of tokens."""
+ return len(self.client.tokenize(text).tokens)
diff --git a/libs/community/langchain_community/chat_models/databricks.py b/libs/community/langchain_community/chat_models/databricks.py
new file mode 100644
index 00000000000..008d4d3fa59
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/databricks.py
@@ -0,0 +1,46 @@
+import logging
+from urllib.parse import urlparse
+
+from langchain_community.chat_models.mlflow import ChatMlflow
+
+logger = logging.getLogger(__name__)
+
+
+class ChatDatabricks(ChatMlflow):
+ """`Databricks` chat models API.
+
+ To use, you should have the ``mlflow`` python package installed.
+ For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatDatabricks
+
+ chat = ChatDatabricks(
+ target_uri="databricks",
+ endpoint="chat",
+            temperature=0.1,
+ )
+ """
+
+ target_uri: str = "databricks"
+ """The target URI to use. Defaults to ``databricks``."""
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "databricks-chat"
+
+ @property
+ def _mlflow_extras(self) -> str:
+ return ""
+
+ def _validate_uri(self) -> None:
+ if self.target_uri == "databricks":
+ return
+
+ if urlparse(self.target_uri).scheme != "databricks":
+ raise ValueError(
+ "Invalid target URI. The target URI must be a valid databricks URI."
+ )
diff --git a/libs/community/langchain_community/chat_models/ernie.py b/libs/community/langchain_community/chat_models/ernie.py
new file mode 100644
index 00000000000..8d69669afc2
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/ernie.py
@@ -0,0 +1,223 @@
+import logging
+import threading
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_dict
+
+
+class ErnieBotChat(BaseChatModel):
+ """`ERNIE-Bot` large language model.
+
+ ERNIE-Bot is a large language model developed by Baidu,
+    trained on a huge amount of Chinese data.
+
+ To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
+    or set the environment variables `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
+
+ Note:
+ access_token will be automatically generated based on client_id and client_secret,
+ and will be regenerated after expiration (30 days).
+
+    The default model is `ERNIE-Bot-turbo`. Currently supported models are
+    `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`, `ERNIE-Bot-4`,
+    and `ERNIE-Bot-turbo-AI`.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ErnieBotChat
+ chat = ErnieBotChat(model_name='ERNIE-Bot')
+
+
+    Deprecation note:
+ Please use `QianfanChatEndpoint` instead of this class.
+ `QianfanChatEndpoint` is a more suitable choice for production.
+
+ Always test your code after changing to `QianfanChatEndpoint`.
+
+ Example of `QianfanChatEndpoint`:
+ .. code-block:: python
+
+ from langchain_community.chat_models import QianfanChatEndpoint
+ qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
+ endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
+
+ """
+
+ ernie_api_base: Optional[str] = None
+ """Baidu application custom endpoints"""
+
+ ernie_client_id: Optional[str] = None
+ """Baidu application client id"""
+
+ ernie_client_secret: Optional[str] = None
+ """Baidu application client secret"""
+
+ access_token: Optional[str] = None
+ """access token is generated by client id and client secret,
+ setting this value directly will cause an error"""
+
+ model_name: str = "ERNIE-Bot-turbo"
+ """model name of ernie, default is `ERNIE-Bot-turbo`.
+ Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
+
+ system: Optional[str] = None
+ """system is mainly used for model character design,
+ for example, you are an AI assistant produced by xxx company.
+ The length of the system is limiting of 1024 characters."""
+
+ request_timeout: Optional[int] = 60
+ """request timeout for chat http requests"""
+
+ streaming: Optional[bool] = False
+ """streaming mode. not supported yet."""
+
+ top_p: Optional[float] = 0.8
+ temperature: Optional[float] = 0.95
+ penalty_score: Optional[float] = 1
+
+ _lock = threading.Lock()
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["ernie_api_base"] = get_from_dict_or_env(
+ values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
+ )
+ values["ernie_client_id"] = get_from_dict_or_env(
+ values,
+ "ernie_client_id",
+ "ERNIE_CLIENT_ID",
+ )
+ values["ernie_client_secret"] = get_from_dict_or_env(
+ values,
+ "ernie_client_secret",
+ "ERNIE_CLIENT_SECRET",
+ )
+ return values
+
+ def _chat(self, payload: object) -> dict:
+ base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
+ model_paths = {
+ "ERNIE-Bot-turbo": "eb-instant",
+ "ERNIE-Bot": "completions",
+ "ERNIE-Bot-8K": "ernie_bot_8k",
+ "ERNIE-Bot-4": "completions_pro",
+ "ERNIE-Bot-turbo-AI": "ai_apaas",
+ "BLOOMZ-7B": "bloomz_7b1",
+ "Llama-2-7b-chat": "llama_2_7b",
+ "Llama-2-13b-chat": "llama_2_13b",
+ "Llama-2-70b-chat": "llama_2_70b",
+ }
+ if self.model_name in model_paths:
+ url = f"{base_url}/{model_paths[self.model_name]}"
+ else:
+ raise ValueError(f"Got unknown model_name {self.model_name}")
+
+ resp = requests.post(
+ url,
+ timeout=self.request_timeout,
+ headers={
+ "Content-Type": "application/json",
+ },
+ params={"access_token": self.access_token},
+ json=payload,
+ )
+ return resp.json()
+
+ def _refresh_access_token_with_lock(self) -> None:
+ with self._lock:
+ logger.debug("Refreshing access token")
+ base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
+ resp = requests.post(
+ base_url,
+ timeout=10,
+ headers={
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ },
+ params={
+ "grant_type": "client_credentials",
+ "client_id": self.ernie_client_id,
+ "client_secret": self.ernie_client_secret,
+ },
+ )
+ self.access_token = str(resp.json().get("access_token"))
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ raise ValueError("`streaming` option currently unsupported.")
+
+ if not self.access_token:
+ self._refresh_access_token_with_lock()
+ payload = {
+ "messages": [_convert_message_to_dict(m) for m in messages],
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "penalty_score": self.penalty_score,
+ "system": self.system,
+ **kwargs,
+ }
+ logger.debug(f"Payload for ernie api is {payload}")
+ resp = self._chat(payload)
+ if resp.get("error_code"):
+ if resp.get("error_code") == 111:
+ logger.debug("access_token expired, refresh it")
+ self._refresh_access_token_with_lock()
+ resp = self._chat(payload)
+ else:
+ raise ValueError(f"Error from ErnieChat api response: {resp}")
+ return self._create_chat_result(resp)
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ if "function_call" in response:
+ additional_kwargs = {
+ "function_call": dict(response.get("function_call", {}))
+ }
+ else:
+ additional_kwargs = {}
+ generations = [
+ ChatGeneration(
+ message=AIMessage(
+ content=response.get("result"),
+ additional_kwargs={**additional_kwargs},
+ )
+ )
+ ]
+ token_usage = response.get("usage", {})
+ llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ @property
+ def _llm_type(self) -> str:
+ return "ernie-bot-chat"
diff --git a/libs/community/langchain_community/chat_models/everlyai.py b/libs/community/langchain_community/chat_models/everlyai.py
new file mode 100644
index 00000000000..dca315d1899
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/everlyai.py
@@ -0,0 +1,159 @@
+"""EverlyAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
+from __future__ import annotations
+
+import logging
+import sys
+from typing import TYPE_CHECKING, Dict, Optional, Set
+
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.adapters.openai import convert_message_to_dict
+from langchain_community.chat_models.openai import (
+ ChatOpenAI,
+ _import_tiktoken,
+)
+
+if TYPE_CHECKING:
+ import tiktoken
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_API_BASE = "https://everlyai.xyz/hosted"
+DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
+
+
+class ChatEverlyAI(ChatOpenAI):
+ """`EverlyAI` Chat large language models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``EVERLYAI_API_KEY`` set with your API key.
+ Alternatively, you can use the everlyai_api_key keyword argument.
+
+ Any parameters that are valid to be passed to the `openai.create` call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatEverlyAI
+ chat = ChatEverlyAI(model_name="meta-llama/Llama-2-7b-chat-hf")
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "everlyai-chat"
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"everlyai_api_key": "EVERLYAI_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ everlyai_api_key: Optional[str] = None
+ """EverlyAI Endpoints API keys."""
+ model_name: str = Field(default=DEFAULT_MODEL, alias="model")
+ """Model name to use."""
+ everlyai_api_base: str = DEFAULT_API_BASE
+ """Base URL path for API requests."""
+ available_models: Optional[Set[str]] = None
+ """Available models from EverlyAI API."""
+
+ @staticmethod
+ def get_available_models() -> Set[str]:
+ """Get available models from EverlyAI API."""
+        # EverlyAI doesn't yet support dynamically querying the available models.
+        return {
+            "meta-llama/Llama-2-7b-chat-hf",
+            "meta-llama/Llama-2-13b-chat-hf-quantized",
+        }
+
+ @root_validator(pre=True)
+ def validate_environment_override(cls, values: dict) -> dict:
+ """Validate that api key and python package exists in environment."""
+ values["openai_api_key"] = get_from_dict_or_env(
+ values,
+ "everlyai_api_key",
+ "EVERLYAI_API_KEY",
+ )
+ values["openai_api_base"] = DEFAULT_API_BASE
+
+ try:
+ import openai
+
+ except ImportError as e:
+ raise ValueError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`.",
+ ) from e
+ try:
+ values["client"] = openai.ChatCompletion
+ except AttributeError as exc:
+ raise ValueError(
+ "`openai` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the openai package. Try upgrading it "
+ "with `pip install --upgrade openai`.",
+ ) from exc
+
+ if "model_name" not in values.keys():
+ values["model_name"] = DEFAULT_MODEL
+
+ model_name = values["model_name"]
+
+ available_models = cls.get_available_models()
+
+ if model_name not in available_models:
+ raise ValueError(
+ f"Model name {model_name} not found in available models: "
+ f"{available_models}.",
+ )
+
+ values["available_models"] = available_models
+
+ return values
+
+ def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
+ tiktoken_ = _import_tiktoken()
+ if self.tiktoken_model_name is not None:
+ model = self.tiktoken_model_name
+ else:
+ model = self.model_name
+ # Returns the number of tokens used by a list of messages.
+ try:
+ encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken_.get_encoding(model)
+ return model, encoding
+
+ def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+ """Calculate num tokens with tiktoken package.
+
+ Official documentation: https://github.com/openai/openai-cookbook/blob/
+ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+ if sys.version_info[1] <= 7:
+ return super().get_num_tokens_from_messages(messages)
+ model, encoding = self._get_encoding_model()
+ tokens_per_message = 3
+ tokens_per_name = 1
+ num_tokens = 0
+ messages_dict = [convert_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+ num_tokens += len(encoding.encode(str(value)))
+ if key == "name":
+ num_tokens += tokens_per_name
+ # every reply is primed with assistant
+ num_tokens += 3
+ return num_tokens
diff --git a/libs/community/langchain_community/chat_models/fake.py b/libs/community/langchain_community/chat_models/fake.py
new file mode 100644
index 00000000000..2ce4b00117b
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/fake.py
@@ -0,0 +1,104 @@
+"""Fake ChatModel for testing purposes."""
+import asyncio
+import time
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+class FakeMessagesListChatModel(BaseChatModel):
+ """Fake ChatModel for testing purposes."""
+
+ responses: List[BaseMessage]
+ sleep: Optional[float] = None
+ i: int = 0
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ generation = ChatGeneration(message=response)
+ return ChatResult(generations=[generation])
+
+ @property
+ def _llm_type(self) -> str:
+ return "fake-messages-list-chat-model"
+
+
+class FakeListChatModel(SimpleChatModel):
+ """Fake ChatModel for testing purposes."""
+
+ responses: List
+ sleep: Optional[float] = None
+ i: int = 0
+
+ @property
+ def _llm_type(self) -> str:
+ return "fake-list-chat-model"
+
+ def _call(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """First try to lookup in queries, else return 'foo' or 'bar'."""
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ return response
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Union[List[str], None] = None,
+ run_manager: Union[CallbackManagerForLLMRun, None] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ for c in response:
+ if self.sleep is not None:
+ time.sleep(self.sleep)
+ yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Union[List[str], None] = None,
+ run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ for c in response:
+ if self.sleep is not None:
+ await asyncio.sleep(self.sleep)
+ yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {"responses": self.responses}
diff --git a/libs/community/langchain_community/chat_models/fireworks.py b/libs/community/langchain_community/chat_models/fireworks.py
new file mode 100644
index 00000000000..0b54c0a873f
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/fireworks.py
@@ -0,0 +1,350 @@
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Type,
+ Union,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.llms import create_base_retry_decorator
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ FunctionMessage,
+ FunctionMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+ SystemMessage,
+ SystemMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str
+from langchain_core.utils.env import get_from_dict_or_env
+
+from langchain_community.adapters.openai import convert_message_to_dict
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Any, default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ """Convert a delta response to a message chunk."""
+ role = _dict.role
+ content = _dict.content or ""
+ additional_kwargs: Dict = {}
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system" or default_class == SystemMessageChunk:
+ return SystemMessageChunk(content=content)
+ elif role == "function" or default_class == FunctionMessageChunk:
+ return FunctionMessageChunk(content=content, name=_dict.name)
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+def convert_dict_to_message(_dict: Any) -> BaseMessage:
+ """Convert a dict response to a message."""
+ role = _dict.role
+ content = _dict.content or ""
+ if role == "user":
+ return HumanMessage(content=content)
+ elif role == "assistant":
+ content = _dict.content
+ additional_kwargs: Dict = {}
+ return AIMessage(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system":
+ return SystemMessage(content=content)
+ elif role == "function":
+ return FunctionMessage(content=content, name=_dict.name)
+ else:
+ return ChatMessage(content=content, role=role)
+
+
+class ChatFireworks(BaseChatModel):
+ """Fireworks Chat models."""
+
+ model: str = "accounts/fireworks/models/llama-v2-7b-chat"
+ model_kwargs: dict = Field(
+ default_factory=lambda: {
+ "temperature": 0.7,
+ "max_tokens": 512,
+ "top_p": 1,
+ }.copy()
+ )
+ fireworks_api_key: Optional[SecretStr] = None
+ max_retries: int = 20
+ use_retry: bool = True
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"fireworks_api_key": "FIREWORKS_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "fireworks"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key in environment."""
+ try:
+ import fireworks.client
+ except ImportError as e:
+ raise ImportError(
+ "Could not import fireworks-ai python package. "
+ "Please install it with `pip install fireworks-ai`."
+ ) from e
+ fireworks_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
+ )
+ fireworks.client.api_key = fireworks_api_key.get_secret_value()
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "fireworks-chat"
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ message_dicts = self._create_message_dicts(messages)
+
+ params = {
+ "model": self.model,
+ "messages": message_dicts,
+ **self.model_kwargs,
+ **kwargs,
+ }
+ response = completion_with_retry(
+ self,
+ self.use_retry,
+ run_manager=run_manager,
+ stop=stop,
+ **params,
+ )
+ return self._create_chat_result(response)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ message_dicts = self._create_message_dicts(messages)
+ params = {
+ "model": self.model,
+ "messages": message_dicts,
+ **self.model_kwargs,
+ **kwargs,
+ }
+ response = await acompletion_with_retry(
+ self, self.use_retry, run_manager=run_manager, stop=stop, **params
+ )
+ return self._create_chat_result(response)
+
+ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+ if llm_outputs[0] is None:
+ return {}
+ return llm_outputs[0]
+
+ def _create_chat_result(self, response: Any) -> ChatResult:
+ generations = []
+ for res in response.choices:
+ message = convert_dict_to_message(res.message)
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(finish_reason=res.finish_reason),
+ )
+ generations.append(gen)
+ llm_output = {"model": self.model}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage]
+ ) -> List[Dict[str, Any]]:
+ message_dicts = [convert_message_to_dict(m) for m in messages]
+ return message_dicts
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts = self._create_message_dicts(messages)
+ default_chunk_class = AIMessageChunk
+ params = {
+ "model": self.model,
+ "messages": message_dicts,
+ "stream": True,
+ **self.model_kwargs,
+ **kwargs,
+ }
+ for chunk in completion_with_retry(
+ self, self.use_retry, run_manager=run_manager, stop=stop, **params
+ ):
+ choice = chunk.choices[0]
+ chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
+ finish_reason = choice.finish_reason
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ message_dicts = self._create_message_dicts(messages)
+ default_chunk_class = AIMessageChunk
+ params = {
+ "model": self.model,
+ "messages": message_dicts,
+ "stream": True,
+ **self.model_kwargs,
+ **kwargs,
+ }
+ async for chunk in await acompletion_with_retry_streaming(
+ self, self.use_retry, run_manager=run_manager, stop=stop, **params
+ ):
+ choice = chunk.choices[0]
+ chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
+ finish_reason = choice.finish_reason
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+
+
+def conditional_decorator(
+ condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
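+    """Apply ``decorator`` only when ``condition`` is true."""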
+ def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+ if condition:
+ return decorator(func)
+ return func
+
+ return actual_decorator
+
+
+def completion_with_retry(
+ llm: ChatFireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return fireworks.client.ChatCompletion.create(
+ **kwargs,
+ )
+
+ return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(
+ llm: ChatFireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the async completion call."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ return await fireworks.client.ChatCompletion.acreate(
+ **kwargs,
+ )
+
+ return await _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry_streaming(
+ llm: ChatFireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call for streaming."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ return fireworks.client.ChatCompletion.acreate(
+ **kwargs,
+ )
+
+ return await _completion_with_retry(**kwargs)
+
+
+def _create_retry_decorator(
+ llm: ChatFireworks,
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ """Define retry mechanism."""
+ import fireworks.client
+
+ errors = [
+ fireworks.client.error.RateLimitError,
+ fireworks.client.error.InternalServerError,
+ fireworks.client.error.BadGatewayError,
+ fireworks.client.error.ServiceUnavailableError,
+ ]
+ return create_base_retry_decorator(
+ error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+ )
diff --git a/libs/community/langchain_community/chat_models/gigachat.py b/libs/community/langchain_community/chat_models/gigachat.py
new file mode 100644
index 00000000000..1349a40c621
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/gigachat.py
@@ -0,0 +1,179 @@
+import logging
+from typing import Any, AsyncIterator, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from langchain_community.llms.gigachat import _BaseGigaChat
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_dict_to_message(message: Any) -> BaseMessage:
+ from gigachat.models import MessagesRole
+
+ if message.role == MessagesRole.SYSTEM:
+ return SystemMessage(content=message.content)
+ elif message.role == MessagesRole.USER:
+ return HumanMessage(content=message.content)
+ elif message.role == MessagesRole.ASSISTANT:
+ return AIMessage(content=message.content)
+ else:
+ raise TypeError(f"Got unknown role {message.role} {message}")
+
+
+def _convert_message_to_dict(message: BaseMessage) -> Any:
+ from gigachat.models import Messages, MessagesRole
+
+ if isinstance(message, SystemMessage):
+ return Messages(role=MessagesRole.SYSTEM, content=message.content)
+ elif isinstance(message, HumanMessage):
+ return Messages(role=MessagesRole.USER, content=message.content)
+ elif isinstance(message, AIMessage):
+ return Messages(role=MessagesRole.ASSISTANT, content=message.content)
+ elif isinstance(message, ChatMessage):
+ return Messages(role=MessagesRole(message.role), content=message.content)
+ else:
+ raise TypeError(f"Got unknown type {message}")
+
+
+class GigaChat(_BaseGigaChat, BaseChatModel):
+ """`GigaChat` large language models API.
+
+    To use, provide a login and password for the GigaChat API, or use a token.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import GigaChat
+ giga = GigaChat(credentials=..., verify_ssl_certs=False)
+ """
+
+ def _build_payload(self, messages: List[BaseMessage]) -> Any:
+ from gigachat.models import Chat
+
+ payload = Chat(
+ messages=[_convert_message_to_dict(m) for m in messages],
+ profanity_check=self.profanity,
+ )
+ if self.temperature is not None:
+ payload.temperature = self.temperature
+ if self.max_tokens is not None:
+ payload.max_tokens = self.max_tokens
+
+ if self.verbose:
+ logger.info("Giga request: %s", payload.dict())
+
+ return payload
+
+ def _create_chat_result(self, response: Any) -> ChatResult:
+ generations = []
+ for res in response.choices:
+ message = _convert_dict_to_message(res.message)
+ finish_reason = res.finish_reason
+ gen = ChatGeneration(
+ message=message,
+ generation_info={"finish_reason": finish_reason},
+ )
+ generations.append(gen)
+ if finish_reason != "stop":
+ logger.warning(
+ "Giga generation stopped with reason: %s",
+ finish_reason,
+ )
+ if self.verbose:
+ logger.info("Giga response: %s", message.content)
+ llm_output = {"token_usage": response.usage, "model_name": response.model}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ payload = self._build_payload(messages)
+ response = self._client.chat(payload)
+
+ return self._create_chat_result(response)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._astream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+
+ payload = self._build_payload(messages)
+ response = await self._client.achat(payload)
+
+ return self._create_chat_result(response)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ payload = self._build_payload(messages)
+
+ for chunk in self._client.stream(payload):
+ if chunk.choices:
+ content = chunk.choices[0].delta.content
+ yield ChatGenerationChunk(message=AIMessageChunk(content=content))
+ if run_manager:
+ run_manager.on_llm_new_token(content)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ payload = self._build_payload(messages)
+
+ async for chunk in self._client.astream(payload):
+ if chunk.choices:
+ content = chunk.choices[0].delta.content
+ yield ChatGenerationChunk(message=AIMessageChunk(content=content))
+ if run_manager:
+ await run_manager.on_llm_new_token(content)
+
+ def get_num_tokens(self, text: str) -> int:
+ """Count approximate number of tokens"""
+ return round(len(text) / 4.6)
diff --git a/libs/community/langchain_community/chat_models/google_palm.py b/libs/community/langchain_community/chat_models/google_palm.py
new file mode 100644
index 00000000000..23d86f0bf24
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/google_palm.py
@@ -0,0 +1,348 @@
+"""Wrapper around Google's PaLM Chat API."""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatResult,
+)
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+if TYPE_CHECKING:
+ import google.generativeai as genai
+
+logger = logging.getLogger(__name__)
+
+
+class ChatGooglePalmError(Exception):
+ """Error with the `Google PaLM` API."""
+
+
+def _truncate_at_stop_tokens(
+ text: str,
+ stop: Optional[List[str]],
+) -> str:
+ """Truncates text at the earliest stop token found."""
+ if stop is None:
+ return text
+
+ for stop_token in stop:
+ stop_token_idx = text.find(stop_token)
+ if stop_token_idx != -1:
+ text = text[:stop_token_idx]
+ return text
+
+
+def _response_to_result(
+ response: genai.types.ChatResponse,
+ stop: Optional[List[str]],
+) -> ChatResult:
+ """Converts a PaLM API response into a LangChain ChatResult."""
+ if not response.candidates:
+ raise ChatGooglePalmError("ChatResponse must have at least one candidate.")
+
+ generations: List[ChatGeneration] = []
+ for candidate in response.candidates:
+ author = candidate.get("author")
+ if author is None:
+ raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")
+
+ content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
+ if content is None:
+ raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
+
+ if author == "ai":
+ generations.append(
+ ChatGeneration(text=content, message=AIMessage(content=content))
+ )
+ elif author == "human":
+ generations.append(
+ ChatGeneration(
+ text=content,
+ message=HumanMessage(content=content),
+ )
+ )
+ else:
+ generations.append(
+ ChatGeneration(
+ text=content,
+ message=ChatMessage(role=author, content=content),
+ )
+ )
+
+ return ChatResult(generations=generations)
+
+
+def _messages_to_prompt_dict(
+ input_messages: List[BaseMessage],
+) -> genai.types.MessagePromptDict:
+ """Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
+ import google.generativeai as genai
+
+ context: str = ""
+ examples: List[genai.types.MessageDict] = []
+ messages: List[genai.types.MessageDict] = []
+
+ remaining = list(enumerate(input_messages))
+
+ while remaining:
+ index, input_message = remaining.pop(0)
+
+ if isinstance(input_message, SystemMessage):
+ if index != 0:
+ raise ChatGooglePalmError("System message must be first input message.")
+ context = cast(str, input_message.content)
+ elif isinstance(input_message, HumanMessage) and input_message.example:
+ if messages:
+ raise ChatGooglePalmError(
+ "Message examples must come before other messages."
+ )
+ _, next_input_message = remaining.pop(0)
+ if isinstance(next_input_message, AIMessage) and next_input_message.example:
+ examples.extend(
+ [
+ genai.types.MessageDict(
+ author="human", content=input_message.content
+ ),
+ genai.types.MessageDict(
+ author="ai", content=next_input_message.content
+ ),
+ ]
+ )
+ else:
+ raise ChatGooglePalmError(
+ "Human example message must be immediately followed by an "
+ " AI example response."
+ )
+ elif isinstance(input_message, AIMessage) and input_message.example:
+ raise ChatGooglePalmError(
+ "AI example message must be immediately preceded by a Human "
+ "example message."
+ )
+ elif isinstance(input_message, AIMessage):
+ messages.append(
+ genai.types.MessageDict(author="ai", content=input_message.content)
+ )
+ elif isinstance(input_message, HumanMessage):
+ messages.append(
+ genai.types.MessageDict(author="human", content=input_message.content)
+ )
+ elif isinstance(input_message, ChatMessage):
+ messages.append(
+ genai.types.MessageDict(
+ author=input_message.role, content=input_message.content
+ )
+ )
+ else:
+ raise ChatGooglePalmError(
+ "Messages without an explicit role not supported by PaLM API."
+ )
+
+ return genai.types.MessagePromptDict(
+ context=context,
+ examples=examples,
+ messages=messages,
+ )
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+ """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+ import google.api_core.exceptions
+
+ multiplier = 2
+ min_seconds = 1
+ max_seconds = 60
+ max_retries = 10
+
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(max_retries),
+ wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+ | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+ | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator()
+
+ @retry_decorator
+ def _chat_with_retry(**kwargs: Any) -> Any:
+ return llm.client.chat(**kwargs)
+
+ return _chat_with_retry(**kwargs)
+
+
+async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+ """Use tenacity to retry the async completion call."""
+ retry_decorator = _create_retry_decorator()
+
+ @retry_decorator
+ async def _achat_with_retry(**kwargs: Any) -> Any:
+        # Use the PaLM client's async chat API.
+ return await llm.client.chat_async(**kwargs)
+
+ return await _achat_with_retry(**kwargs)
+
+
+class ChatGooglePalm(BaseChatModel, BaseModel):
+ """`Google PaLM` Chat models API.
+
+    To use, you must have the google.generativeai Python package installed and
+ either:
+
+        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
+        2. Your API key passed using the google_api_key kwarg to the
+           ChatGooglePalm constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatGooglePalm
+ chat = ChatGooglePalm()
+
+ """
+
+ client: Any #: :meta private:
+ model_name: str = "models/chat-bison-001"
+ """Model name to use."""
+ google_api_key: Optional[str] = None
+ temperature: Optional[float] = None
+ """Run inference with this temperature. Must by in the closed
+ interval [0.0, 1.0]."""
+ top_p: Optional[float] = None
+ """Decode using nucleus sampling: consider the smallest set of tokens whose
+ probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+ top_k: Optional[int] = None
+ """Decode using top-k sampling: consider the set of top_k most probable tokens.
+ Must be positive."""
+ n: int = 1
+ """Number of chat completions to generate for each prompt. Note that the API may
+ not return the full n completions if duplicates are generated."""
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"google_api_key": "GOOGLE_API_KEY"}
+
+ @classmethod
+    def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "google_palm"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate api key, python package exists, temperature, top_p, and top_k."""
+ google_api_key = get_from_dict_or_env(
+ values, "google_api_key", "GOOGLE_API_KEY"
+ )
+ try:
+ import google.generativeai as genai
+
+ genai.configure(api_key=google_api_key)
+ except ImportError:
+ raise ChatGooglePalmError(
+ "Could not import google.generativeai python package. "
+ "Please install it with `pip install google-generativeai`"
+ )
+
+ values["client"] = genai
+
+ if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+ raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+ if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+ raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+ if values["top_k"] is not None and values["top_k"] <= 0:
+ raise ValueError("top_k must be positive")
+
+ return values
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ prompt = _messages_to_prompt_dict(messages)
+
+ response: genai.types.ChatResponse = chat_with_retry(
+ self,
+ model=self.model_name,
+ prompt=prompt,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ candidate_count=self.n,
+ **kwargs,
+ )
+
+ return _response_to_result(response, stop)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ prompt = _messages_to_prompt_dict(messages)
+
+ response: genai.types.ChatResponse = await achat_with_retry(
+ self,
+ model=self.model_name,
+ prompt=prompt,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ candidate_count=self.n,
+ )
+
+ return _response_to_result(response, stop)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "n": self.n,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "google-palm-chat"
diff --git a/libs/community/langchain_community/chat_models/human.py b/libs/community/langchain_community/chat_models/human.py
new file mode 100644
index 00000000000..0ac1a407c92
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/human.py
@@ -0,0 +1,125 @@
+"""ChatModel wrapper which returns user input as the response.."""
+import asyncio
+from functools import partial
+from io import StringIO
+from typing import Any, Callable, Dict, List, Mapping, Optional
+
+import yaml
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ BaseMessage,
+ HumanMessage,
+ _message_from_dict,
+ messages_to_dict,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+def _display_messages(messages: List[BaseMessage]) -> None:
+ dict_messages = messages_to_dict(messages)
+ for message in dict_messages:
+ yaml_string = yaml.dump(
+ message,
+ default_flow_style=False,
+ sort_keys=False,
+ allow_unicode=True,
+ width=10000,
+ line_break=None,
+ )
+ print("\n", "======= start of message =======", "\n\n")
+ print(yaml_string)
+ print("======= end of message =======", "\n\n")
+
+
+def _collect_yaml_input(
+ messages: List[BaseMessage], stop: Optional[List[str]] = None
+) -> BaseMessage:
+ """Collects and returns user input as a single string."""
+ lines = []
+ while True:
+ line = input()
+ if not line.strip():
+ break
+ if stop and any(seq in line for seq in stop):
+ break
+ lines.append(line)
+ yaml_string = "\n".join(lines)
+
+ # Try to parse the input string as YAML
+ try:
+ message = _message_from_dict(yaml.safe_load(StringIO(yaml_string)))
+ if message is None:
+ return HumanMessage(content="")
+ if stop:
+ if isinstance(message.content, str):
+ message.content = enforce_stop_tokens(message.content, stop)
+ else:
+ raise ValueError("Cannot use when output is not a string.")
+ return message
+ except yaml.YAMLError:
+ raise ValueError("Invalid YAML string entered.")
+ except ValueError:
+ raise ValueError("Invalid message entered.")
+
+
+class HumanInputChatModel(BaseChatModel):
+ """ChatModel which returns user input as the response."""
+
+ input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
+ message_func: Callable = Field(default_factory=lambda: _display_messages)
+ separator: str = "\n"
+ input_kwargs: Mapping[str, Any] = {}
+ message_kwargs: Mapping[str, Any] = {}
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ "input_func": self.input_func.__name__,
+ "message_func": self.message_func.__name__,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Returns the type of LLM."""
+ return "human-input-chat-model"
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """
+ Displays the messages to the user and returns their input as a response.
+
+ Args:
+ messages (List[BaseMessage]): The messages to be displayed to the user.
+ stop (Optional[List[str]]): A list of stop strings.
+ run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
+
+ Returns:
+ ChatResult: The user's input as a response.
+ """
+ self.message_func(messages, **self.message_kwargs)
+ user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
+ return ChatResult(generations=[ChatGeneration(message=user_input)])
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ func = partial(
+ self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
diff --git a/libs/community/langchain_community/chat_models/hunyuan.py b/libs/community/langchain_community/chat_models/hunyuan.py
new file mode 100644
index 00000000000..badbb1f2f6f
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/hunyuan.py
@@ -0,0 +1,321 @@
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import time
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
+from urllib.parse import urlparse
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import (
+ convert_to_secret_str,
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_API_BASE = "https://hunyuan.cloud.tencent.com"
+DEFAULT_PATH = "/hyllm/v1/chat/completions"
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ else:
+ raise TypeError(f"Got unknown type {message}")
+
+ return message_dict
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ return AIMessage(content=_dict.get("content", "") or "")
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ content = _dict.get("content") or ""
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content)
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+# signature generation
+# https://cloud.tencent.com/document/product/1729/97732#532252ce-e960-48a7-8821-940a9ce2ccf3
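+# For illustration (hypothetical values): signing the payload
+#   {"app_id": 1, "timestamp": 1700000000}
+# against the default endpoint builds the sign string
+#   hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?app_id=1&timestamp=1700000000
+# which is then HMAC-SHA1 signed with the secret key and base64 encoded.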
+def _signature(secret_key: SecretStr, url: str, payload: Dict[str, Any]) -> str:
+ sorted_keys = sorted(payload.keys())
+
+ url_info = urlparse(url)
+
+ sign_str = url_info.netloc + url_info.path + "?"
+
+ for key in sorted_keys:
+ value = payload[key]
+
+ if isinstance(value, list) or isinstance(value, dict):
+ value = json.dumps(value, separators=(",", ":"))
+ elif isinstance(value, float):
+ value = "%g" % value
+
+ sign_str = sign_str + key + "=" + str(value) + "&"
+
+ sign_str = sign_str[:-1]
+
+ hmacstr = hmac.new(
+ key=secret_key.get_secret_value().encode("utf-8"),
+ msg=sign_str.encode("utf-8"),
+ digestmod=hashlib.sha1,
+ ).digest()
+
+ return base64.b64encode(hmacstr).decode("utf-8")
+
+
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for choice in response["choices"]:
+ message = _convert_dict_to_message(choice["messages"])
+ generations.append(ChatGeneration(message=message))
+
+ token_usage = response["usage"]
+ llm_output = {"token_usage": token_usage}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+
+class ChatHunyuan(BaseChatModel):
+ """Tencent Hunyuan chat models API by Tencent.
+
+ For more information, see https://cloud.tencent.com/document/product/1729
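+
+    A minimal usage sketch (assuming the ``HUNYUAN_APP_ID``, ``HUNYUAN_SECRET_ID``
+    and ``HUNYUAN_SECRET_KEY`` environment variables are set):
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.chat_models import ChatHunyuan
+
+            chat = ChatHunyuan(temperature=0.5)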
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {
+ "hunyuan_app_id": "HUNYUAN_APP_ID",
+ "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
+ "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
+ }
+
+ @property
+ def lc_serializable(self) -> bool:
+ return True
+
+ hunyuan_api_base: str = Field(default=DEFAULT_API_BASE)
+ """Hunyuan custom endpoints"""
+ hunyuan_app_id: Optional[int] = None
+ """Hunyuan App ID"""
+ hunyuan_secret_id: Optional[str] = None
+ """Hunyuan Secret ID"""
+ hunyuan_secret_key: Optional[SecretStr] = None
+ """Hunyuan Secret Key"""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ request_timeout: int = 60
+ """Timeout for requests to Hunyuan API. Default is 60 seconds."""
+
+ query_id: Optional[str] = None
+ """Query id for troubleshooting"""
+ temperature: float = 1.0
+ """What sampling temperature to use."""
+ top_p: float = 1.0
+ """What probability mass to use."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for API call not explicitly specified."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["hunyuan_api_base"] = get_from_dict_or_env(
+ values,
+ "hunyuan_api_base",
+ "HUNYUAN_API_BASE",
+ DEFAULT_API_BASE,
+ )
+ values["hunyuan_app_id"] = get_from_dict_or_env(
+ values,
+ "hunyuan_app_id",
+ "HUNYUAN_APP_ID",
+ )
+ values["hunyuan_secret_id"] = get_from_dict_or_env(
+ values,
+ "hunyuan_secret_id",
+ "HUNYUAN_SECRET_ID",
+ )
+ values["hunyuan_secret_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "hunyuan_secret_key",
+ "HUNYUAN_SECRET_KEY",
+ )
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Hunyuan API."""
+ normal_params = {
+ "app_id": self.hunyuan_app_id,
+ "secret_id": self.hunyuan_secret_id,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ }
+
+ if self.query_id is not None:
+ normal_params["query_id"] = self.query_id
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ res = self._chat(messages, **kwargs)
+
+ response = res.json()
+
+ if "error" in response:
+ raise ValueError(f"Error from Hunyuan api response: {response}")
+
+ return _create_chat_result(response)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ res = self._chat(messages, **kwargs)
+
+ default_chunk_class = AIMessageChunk
+ for chunk in res.iter_lines():
+ response = json.loads(chunk)
+ if "error" in response:
+ raise ValueError(f"Error from Hunyuan api response: {response}")
+
+ for choice in response["choices"]:
+ chunk = _convert_delta_to_message_chunk(
+ choice["delta"], default_chunk_class
+ )
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.content)
+
+ def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
+ if self.hunyuan_secret_key is None:
+ raise ValueError("Hunyuan secret key is not set.")
+
+ parameters = {**self._default_params, **kwargs}
+
+ headers = parameters.pop("headers", {})
+ timestamp = parameters.pop("timestamp", int(time.time()))
+ expired = parameters.pop("expired", timestamp + 24 * 60 * 60)
+
+ payload = {
+ "timestamp": timestamp,
+ "expired": expired,
+ "messages": [_convert_message_to_dict(m) for m in messages],
+ **parameters,
+ }
+
+ if self.streaming:
+ payload["stream"] = 1
+
+ url = self.hunyuan_api_base + DEFAULT_PATH
+
+ res = requests.post(
+ url=url,
+ timeout=self.request_timeout,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": _signature(
+ secret_key=self.hunyuan_secret_key, url=url, payload=payload
+ ),
+ **headers,
+ },
+ json=payload,
+ stream=self.streaming,
+ )
+ return res
+
+ @property
+ def _llm_type(self) -> str:
+ return "hunyuan-chat"
diff --git a/libs/community/langchain_community/chat_models/javelin_ai_gateway.py b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py
new file mode 100644
index 00000000000..6b7001b6260
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py
@@ -0,0 +1,224 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatResult,
+)
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr
+
+logger = logging.getLogger(__name__)
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
+class ChatParams(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Parameters for the `Javelin AI Gateway` LLM."""
+
+ temperature: float = 0.0
+ stop: Optional[List[str]] = None
+ max_tokens: Optional[int] = None
+
+
+class ChatJavelinAIGateway(BaseChatModel):
+ """`Javelin AI Gateway` chat models API.
+
+ To use, you should have the ``javelin_sdk`` python package installed.
+ For more information, see https://docs.getjavelin.io
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatJavelinAIGateway
+
+ chat = ChatJavelinAIGateway(
+ gateway_uri="",
+ route="",
+ params={
+ "temperature": 0.1
+ }
+ )
+ """
+
+ route: str
+ """The route to use for the Javelin AI Gateway API."""
+
+ gateway_uri: Optional[str] = None
+ """The URI for the Javelin AI Gateway API."""
+
+ params: Optional[ChatParams] = None
+ """Parameters for the Javelin AI Gateway LLM."""
+
+ client: Any
+ """javelin client."""
+
+ javelin_api_key: Optional[SecretStr] = None
+ """The API key for the Javelin AI Gateway."""
+
+ def __init__(self, **kwargs: Any):
+ try:
+ from javelin_sdk import (
+ JavelinClient,
+ UnauthorizedError,
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import javelin_sdk python package. "
+ "Please install it with `pip install javelin_sdk`."
+ )
+
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ try:
+ self.client = JavelinClient(
+ base_url=self.gateway_uri,
+ api_key=cast(SecretStr, self.javelin_api_key).get_secret_value(),
+ )
+ except UnauthorizedError as e:
+ raise ValueError("Javelin: Incorrect API Key.") from e
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ params: Dict[str, Any] = {
+ "gateway_uri": self.gateway_uri,
+ "javelin_api_key": cast(SecretStr, self.javelin_api_key).get_secret_value(),
+ "route": self.route,
+ **(self.params.dict() if self.params else {}),
+ }
+ return params
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ message_dicts = [
+ ChatJavelinAIGateway._convert_message_to_dict(message)
+ for message in messages
+ ]
+ data: Dict[str, Any] = {
+ "messages": message_dicts,
+ **(self.params.dict() if self.params else {}),
+ }
+
+ resp = self.client.query_route(self.route, query_body=data)
+
+ return ChatJavelinAIGateway._create_chat_result(resp.dict())
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ message_dicts = [
+ ChatJavelinAIGateway._convert_message_to_dict(message)
+ for message in messages
+ ]
+ data: Dict[str, Any] = {
+ "messages": message_dicts,
+ **(self.params.dict() if self.params else {}),
+ }
+
+ resp = await self.client.aquery_route(self.route, query_body=data)
+
+ return ChatJavelinAIGateway._create_chat_result(resp.dict())
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return self._default_params
+
+ def _get_invocation_params(
+ self, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+ return {
+ **self._default_params,
+ **super()._get_invocation_params(stop=stop, **kwargs),
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "javelin-ai-gateway-chat"
+
+ @staticmethod
+ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ content = _dict["content"]
+ if role == "user":
+ return HumanMessage(content=content)
+ elif role == "assistant":
+ return AIMessage(content=content)
+ elif role == "system":
+ return SystemMessage(content=content)
+ else:
+ return ChatMessage(content=content, role=role)
+
+ @staticmethod
+ def _raise_functions_not_supported() -> None:
+ raise ValueError(
+ "Function messages are not supported by the Javelin AI Gateway. Please"
+ " create a feature request at https://docs.getjavelin.io"
+ )
+
+ @staticmethod
+ def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ raise ValueError(
+ "Function messages are not supported by the Javelin AI Gateway. Please"
+ " create a feature request at https://docs.getjavelin.io"
+ )
+ else:
+ raise ValueError(f"Got unknown message type: {message}")
+
+ if "function_call" in message.additional_kwargs:
+ ChatJavelinAIGateway._raise_functions_not_supported()
+ if message.additional_kwargs:
+ logger.warning(
+ "Additional message arguments are unsupported by Javelin AI Gateway "
+ " and will be ignored: %s",
+ message.additional_kwargs,
+ )
+ return message_dict
+
+ @staticmethod
+ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for candidate in response["llm_response"]["choices"]:
+ message = ChatJavelinAIGateway._convert_dict_to_message(
+ candidate["message"]
+ )
+ message_metadata = candidate.get("metadata", {})
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(message_metadata),
+ )
+ generations.append(gen)
+
+ response_metadata = response.get("metadata", {})
+ return ChatResult(generations=generations, llm_output=response_metadata)
diff --git a/libs/community/langchain_community/chat_models/jinachat.py b/libs/community/langchain_community/chat_models/jinachat.py
new file mode 100644
index 00000000000..b234c1e01db
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/jinachat.py
@@ -0,0 +1,410 @@
+"""JinaChat wrapper."""
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ FunctionMessage,
+ HumanMessage,
+ HumanMessageChunk,
+ SystemMessage,
+ SystemMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import (
+ convert_to_secret_str,
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
+ import openai
+
+ min_seconds = 1
+ max_seconds = 60
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(llm.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
+ """Use tenacity to retry the async completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+ return await llm.client.acreate(**kwargs)
+
+ return await _completion_with_retry(**kwargs)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ content = _dict.get("content") or ""
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content)
+ elif role == "system" or default_class == SystemMessageChunk:
+ return SystemMessageChunk(content=content)
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ content = _dict["content"] or ""
+ return AIMessage(content=content)
+ elif role == "system":
+ return SystemMessage(content=_dict["content"])
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "name": message.name,
+ "content": message.content,
+ }
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
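+
+
+# For illustration: _convert_message_to_dict(HumanMessage(content="Hi"))
+# returns {"role": "user", "content": "Hi"}.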
+
+
+class JinaChat(BaseChatModel):
+ """`Jina AI` Chat models API.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``JINACHAT_API_KEY`` set to your API key, which you
+ can generate at https://chat.jina.ai/api.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import JinaChat
+ chat = JinaChat()
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"jinachat_api_key": "JINACHAT_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return False
+
+ client: Any #: :meta private:
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ jinachat_api_key: Optional[SecretStr] = None
+ """Base URL path for API requests,
+ leave blank if not using a proxy or service emulator."""
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+ """Timeout for requests to JinaChat completion API. Default is 600 seconds."""
+ max_retries: int = 6
+ """Maximum number of retries to make when generating."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ max_tokens: Optional[int] = None
+ """Maximum number of tokens to generate."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["jinachat_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY")
+ )
+ try:
+ import openai
+
+ except ImportError:
+ raise ValueError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ try:
+ values["client"] = openai.ChatCompletion
+ except AttributeError:
+ raise ValueError(
+ "`openai` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the openai package. Try upgrading it "
+ "with `pip install --upgrade openai`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling JinaChat API."""
+ return {
+ "request_timeout": self.request_timeout,
+ "max_tokens": self.max_tokens,
+ "stream": self.streaming,
+ "temperature": self.temperature,
+ **self.model_kwargs,
+ }
+
+ def _create_retry_decorator(self) -> Callable[[Any], Any]:
+ import openai
+
+ min_seconds = 1
+ max_seconds = 60
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(self.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+ def completion_with_retry(self, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = self._create_retry_decorator()
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return self.client.create(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+ overall_token_usage: dict = {}
+ for output in llm_outputs:
+ if output is None:
+ # Happens in streaming
+ continue
+ token_usage = output["token_usage"]
+ for k, v in token_usage.items():
+ if k in overall_token_usage:
+ overall_token_usage[k] += v
+ else:
+ overall_token_usage[k] = v
+ return {"token_usage": overall_token_usage}
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ for chunk in self.completion_with_retry(messages=message_dicts, **params):
+ delta = chunk["choices"][0]["delta"]
+ chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.content)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs}
+ response = self.completion_with_retry(messages=message_dicts, **params)
+ return self._create_chat_result(response)
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage], stop: Optional[List[str]]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params = dict(self._invocation_params)
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ message_dicts = [_convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for res in response["choices"]:
+ message = _convert_dict_to_message(res["message"])
+ gen = ChatGeneration(message=message)
+ generations.append(gen)
+ llm_output = {"token_usage": response["usage"]}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ async for chunk in await acompletion_with_retry(
+ self, messages=message_dicts, **params
+ ):
+ delta = chunk["choices"][0]["delta"]
+ chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.content)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._astream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs}
+ response = await acompletion_with_retry(self, messages=message_dicts, **params)
+ return self._create_chat_result(response)
+
+ @property
+ def _invocation_params(self) -> Mapping[str, Any]:
+ """Get the parameters used to invoke the model."""
+ jinachat_creds: Dict[str, Any] = {
+ "api_key": self.jinachat_api_key
+ and self.jinachat_api_key.get_secret_value(),
+ "api_base": "https://api.chat.jina.ai/v1",
+ "model": "jinachat",
+ }
+ return {**jinachat_creds, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "jinachat"
diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py
new file mode 100644
index 00000000000..ff88bd417f3
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/konko.py
@@ -0,0 +1,295 @@
+"""KonkoAI chat wrapper."""
+from __future__ import annotations
+
+import logging
+import os
+from typing import (
+ Any,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
+import requests
+from langchain_core.callbacks import (
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.adapters.openai import (
+ convert_dict_to_message,
+ convert_message_to_dict,
+)
+from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
+
+DEFAULT_API_BASE = "https://api.konko.ai/v1"
+DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
+
+logger = logging.getLogger(__name__)
+
+
+class ChatKonko(BaseChatModel):
+ """`ChatKonko` Chat large language models API.
+
+ To use, you should have the ``konko`` python package installed, and the
+    environment variables ``KONKO_API_KEY`` and ``OPENAI_API_KEY`` set with your
+    API keys.
+
+ Any parameters that are valid to be passed to the konko.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatKonko
+ llm = ChatKonko(model="meta-llama/Llama-2-13b-chat-hf")
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"konko_api_key": "KONKO_API_KEY", "openai_api_key": "OPENAI_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return False
+
+ client: Any = None #: :meta private:
+ model: str = Field(default=DEFAULT_MODEL, alias="model")
+ """Model name to use."""
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ openai_api_key: Optional[str] = None
+ konko_api_key: Optional[str] = None
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+ """Timeout for requests to Konko completion API."""
+ max_retries: int = 6
+ """Maximum number of retries to make when generating."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ n: int = 1
+ """Number of chat completions to generate for each prompt."""
+ max_tokens: int = 20
+ """Maximum number of tokens to generate."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["konko_api_key"] = get_from_dict_or_env(
+ values, "konko_api_key", "KONKO_API_KEY"
+ )
+ try:
+ import konko
+
+ except ImportError:
+ raise ValueError(
+ "Could not import konko python package. "
+ "Please install it with `pip install konko`."
+ )
+ try:
+ values["client"] = konko.ChatCompletion
+ except AttributeError:
+ raise ValueError(
+ "`konko` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the konko package. Try upgrading it "
+ "with `pip install --upgrade konko`."
+ )
+ if values["n"] < 1:
+ raise ValueError("n must be at least 1.")
+ if values["n"] > 1 and values["streaming"]:
+ raise ValueError("n must be 1 when streaming.")
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Konko API."""
+ return {
+ "model": self.model,
+ "request_timeout": self.request_timeout,
+ "max_tokens": self.max_tokens,
+ "stream": self.streaming,
+ "n": self.n,
+ "temperature": self.temperature,
+ **self.model_kwargs,
+ }
+
+ @staticmethod
+ def get_available_models(
+ konko_api_key: Optional[str] = None,
+ openai_api_key: Optional[str] = None,
+ konko_api_base: str = DEFAULT_API_BASE,
+ ) -> Set[str]:
+ """Get available models from Konko API."""
+
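+        # For illustration (hypothetical key):
+        #   ChatKonko.get_available_models(konko_api_key="konko-...")
+        # returns a set of model id strings, e.g. {"meta-llama/Llama-2-13b-chat-hf"}.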
+ # Try to retrieve the OpenAI API key if it's not passed as an argument
+ if not openai_api_key:
+ try:
+ openai_api_key = os.environ["OPENAI_API_KEY"]
+ except KeyError:
+ pass # It's okay if it's not set, we just won't use it
+
+ # Try to retrieve the Konko API key if it's not passed as an argument
+ if not konko_api_key:
+ try:
+ konko_api_key = os.environ["KONKO_API_KEY"]
+ except KeyError:
+ raise ValueError(
+ "Konko API key must be passed as keyword argument or "
+ "set in environment variable KONKO_API_KEY."
+ )
+
+ models_url = f"{konko_api_base}/models"
+
+ headers = {
+ "Authorization": f"Bearer {konko_api_key}",
+ }
+
+ if openai_api_key:
+ headers["X-OpenAI-Api-Key"] = openai_api_key
+
+ models_response = requests.get(models_url, headers=headers)
+
+ if models_response.status_code != 200:
+ raise ValueError(
+ f"Error getting models from {models_url}: "
+ f"{models_response.status_code}"
+ )
+
+ return {model["id"] for model in models_response.json()["data"]}
+
+ def completion_with_retry(
+ self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+ ) -> Any:
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return self.client.create(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+ overall_token_usage: dict = {}
+ for output in llm_outputs:
+ if output is None:
+ # Happens in streaming
+ continue
+ token_usage = output["token_usage"]
+ for k, v in token_usage.items():
+ if k in overall_token_usage:
+ overall_token_usage[k] += v
+ else:
+ overall_token_usage[k] = v
+ return {"token_usage": overall_token_usage, "model_name": self.model}
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ for chunk in self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if len(chunk["choices"]) == 0:
+ continue
+ choice = chunk["choices"][0]
+ chunk = _convert_delta_to_message_chunk(
+ choice["delta"], default_chunk_class
+ )
+ finish_reason = choice.get("finish_reason")
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs}
+ response = self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage], stop: Optional[List[str]]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params = self._client_params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ message_dicts = [convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for res in response["choices"]:
+ message = convert_dict_to_message(res["message"])
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(finish_reason=res.get("finish_reason")),
+ )
+ generations.append(gen)
+ token_usage = response.get("usage", {})
+ llm_output = {"token_usage": token_usage, "model_name": self.model}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model}, **self._default_params}
+
+ @property
+ def _client_params(self) -> Dict[str, Any]:
+ """Get the parameters used for the konko client."""
+ return {**self._default_params}
+
+ def _get_invocation_params(
+ self, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model."""
+ return {
+ "model": self.model,
+ **super()._get_invocation_params(stop=stop),
+ **self._default_params,
+ **kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "konko-chat"
diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
new file mode 100644
index 00000000000..fb30d7463c1
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -0,0 +1,422 @@
+"""Wrapper around LiteLLM's model I/O library."""
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.language_models.llms import create_base_retry_decorator
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ FunctionMessage,
+ FunctionMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+ SystemMessage,
+ SystemMessageChunk,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatGenerationChunk,
+ ChatResult,
+)
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class ChatLiteLLMException(Exception):
+ """Error with the `LiteLLM I/O` library"""
+
+
+def _create_retry_decorator(
+ llm: ChatLiteLLM,
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+ import litellm
+
+ errors = [
+ litellm.Timeout,
+ litellm.APIError,
+ litellm.APIConnectionError,
+ litellm.RateLimitError,
+ ]
+ return create_base_retry_decorator(
+ error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+ )
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ # Fix for azure
+ # Also OpenAI returns None for tool invocations
+ content = _dict.get("content", "") or ""
+ if _dict.get("function_call"):
+ additional_kwargs = {"function_call": dict(_dict["function_call"])}
+ else:
+ additional_kwargs = {}
+ return AIMessage(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system":
+ return SystemMessage(content=_dict["content"])
+ elif role == "function":
+ return FunctionMessage(content=_dict["content"], name=_dict["name"])
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+async def acompletion_with_retry(
+ llm: ChatLiteLLM,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the async completion call."""
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+ return await llm.client.acreate(**kwargs)
+
+ return await _completion_with_retry(**kwargs)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ content = _dict.get("content") or ""
+ if _dict.get("function_call"):
+ additional_kwargs = {"function_call": dict(_dict["function_call"])}
+ else:
+ additional_kwargs = {}
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system" or default_class == SystemMessageChunk:
+ return SystemMessageChunk(content=content)
+ elif role == "function" or default_class == FunctionMessageChunk:
+ return FunctionMessageChunk(content=content, name=_dict["name"])
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ if "function_call" in message.additional_kwargs:
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.name,
+ }
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
+
+
+class ChatLiteLLM(BaseChatModel):
+ """A chat model that uses the LiteLLM API."""
+
+ client: Any #: :meta private:
+ model: str = "gpt-3.5-turbo"
+ model_name: Optional[str] = None
+ """Model name to use."""
+ openai_api_key: Optional[str] = None
+ azure_api_key: Optional[str] = None
+ anthropic_api_key: Optional[str] = None
+ replicate_api_key: Optional[str] = None
+ cohere_api_key: Optional[str] = None
+ openrouter_api_key: Optional[str] = None
+ streaming: bool = False
+ api_base: Optional[str] = None
+ organization: Optional[str] = None
+ custom_llm_provider: Optional[str] = None
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    temperature: Optional[float] = 1
+    """Run inference with this temperature. Must be in the closed
+    interval [0.0, 1.0]."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for the API call not explicitly specified."""
+ top_p: Optional[float] = None
+ """Decode using nucleus sampling: consider the smallest set of tokens whose
+ probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+ top_k: Optional[int] = None
+ """Decode using top-k sampling: consider the set of top_k most probable tokens.
+ Must be positive."""
+ n: int = 1
+ """Number of chat completions to generate for each prompt. Note that the API may
+ not return the full n completions if duplicates are generated."""
+ max_tokens: int = 256
+
+ max_retries: int = 6
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ set_model_value = self.model
+ if self.model_name is not None:
+ set_model_value = self.model_name
+ return {
+ "model": set_model_value,
+ "force_timeout": self.request_timeout,
+ "max_tokens": self.max_tokens,
+ "stream": self.streaming,
+ "n": self.n,
+ "temperature": self.temperature,
+ "custom_llm_provider": self.custom_llm_provider,
+ **self.model_kwargs,
+ }
+
+ @property
+ def _client_params(self) -> Dict[str, Any]:
+ """Get the parameters used for the openai client."""
+ set_model_value = self.model
+ if self.model_name is not None:
+ set_model_value = self.model_name
+ self.client.api_base = self.api_base
+ self.client.organization = self.organization
+ creds: Dict[str, Any] = {
+ "model": set_model_value,
+ "force_timeout": self.request_timeout,
+ }
+ return {**self._default_params, **creds}
+
+ def completion_with_retry(
+ self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+ ) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return self.client.completion(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate api key, python package exists, temperature, top_p, and top_k."""
+ try:
+ import litellm
+ except ImportError:
+ raise ChatLiteLLMException(
+ "Could not import google.generativeai python package. "
+ "Please install it with `pip install google-generativeai`"
+ )
+
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY", default=""
+ )
+ values["azure_api_key"] = get_from_dict_or_env(
+ values, "azure_api_key", "AZURE_API_KEY", default=""
+ )
+ values["anthropic_api_key"] = get_from_dict_or_env(
+ values, "anthropic_api_key", "ANTHROPIC_API_KEY", default=""
+ )
+ values["replicate_api_key"] = get_from_dict_or_env(
+ values, "replicate_api_key", "REPLICATE_API_KEY", default=""
+ )
+ values["openrouter_api_key"] = get_from_dict_or_env(
+ values, "openrouter_api_key", "OPENROUTER_API_KEY", default=""
+ )
+ values["cohere_api_key"] = get_from_dict_or_env(
+ values, "cohere_api_key", "COHERE_API_KEY", default=""
+ )
+ values["huggingface_api_key"] = get_from_dict_or_env(
+ values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default=""
+ )
+ values["together_ai_api_key"] = get_from_dict_or_env(
+ values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
+ )
+ values["client"] = litellm
+
+ if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+ raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+ if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+ raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+ if values["top_k"] is not None and values["top_k"] <= 0:
+ raise ValueError("top_k must be positive")
+
+ return values
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs}
+ response = self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for res in response["choices"]:
+ message = _convert_dict_to_message(res["message"])
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(finish_reason=res.get("finish_reason")),
+ )
+ generations.append(gen)
+ token_usage = response.get("usage", {})
+ set_model_value = self.model
+ if self.model_name is not None:
+ set_model_value = self.model_name
+ llm_output = {"token_usage": token_usage, "model": set_model_value}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage], stop: Optional[List[str]]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params = self._client_params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ message_dicts = [_convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ for chunk in self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if len(chunk["choices"]) == 0:
+ continue
+ delta = chunk["choices"][0]["delta"]
+ chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.content)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ async for chunk in await acompletion_with_retry(
+ self, messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if len(chunk["choices"]) == 0:
+ continue
+ delta = chunk["choices"][0]["delta"]
+ chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+ default_chunk_class = chunk.__class__
+ yield ChatGenerationChunk(message=chunk)
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.content)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._astream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs}
+ response = await acompletion_with_retry(
+ self, messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ set_model_value = self.model
+ if self.model_name is not None:
+ set_model_value = self.model_name
+ return {
+ "model": set_model_value,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "n": self.n,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "litellm-chat"
diff --git a/libs/community/langchain_community/chat_models/meta.py b/libs/community/langchain_community/chat_models/meta.py
new file mode 100644
index 00000000000..038561c7f63
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/meta.py
@@ -0,0 +1,29 @@
+from typing import List
+
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+
+
+def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
+ if isinstance(message, ChatMessage):
+ message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+ elif isinstance(message, HumanMessage):
+ message_text = f"[INST] {message.content} [/INST]"
+ elif isinstance(message, AIMessage):
+ message_text = f"{message.content}"
+ elif isinstance(message, SystemMessage):
+ message_text = f"<> {message.content} <>"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_text
+
+
+def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
+ return "\n".join(
+ [_convert_one_message_to_text_llama(message) for message in messages]
+ )
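+
+
+# For illustration (hypothetical messages):
+#   convert_messages_to_prompt_llama(
+#       [SystemMessage(content="Be brief."), HumanMessage(content="Hi")]
+#   )
+# yields the prompt "<<SYS>> Be brief. <</SYS>>\n[INST] Hi [/INST]".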
diff --git a/libs/community/langchain_community/chat_models/minimax.py b/libs/community/langchain_community/chat_models/minimax.py
new file mode 100644
index 00000000000..f2385510a43
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/minimax.py
@@ -0,0 +1,95 @@
+"""Wrapper around Minimax chat models."""
+import logging
+from typing import Any, Dict, List, Optional, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ HumanMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+
+from langchain_community.llms.minimax import MinimaxCommon
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_message(msg_type: str, text: str) -> Dict:
+ return {"sender_type": msg_type, "text": text}
+
+
+def _parse_chat_history(history: List[BaseMessage]) -> List:
+ """Parse a sequence of messages into history."""
+ chat_history = []
+ for message in history:
+ content = cast(str, message.content)
+ if isinstance(message, HumanMessage):
+ chat_history.append(_parse_message("USER", content))
+ if isinstance(message, AIMessage):
+ chat_history.append(_parse_message("BOT", content))
+ return chat_history
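+
+
+# For illustration (hypothetical history):
+#   _parse_chat_history([HumanMessage(content="Hi"), AIMessage(content="Hello!")])
+# returns [{"sender_type": "USER", "text": "Hi"},
+#          {"sender_type": "BOT", "text": "Hello!"}].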
+
+
+class MiniMaxChat(MinimaxCommon, BaseChatModel):
+ """Wrapper around Minimax large language models.
+
+    To use, you should have the environment variables ``MINIMAX_GROUP_ID`` and
+    ``MINIMAX_API_KEY`` set with your group id and API key, or pass them as named
+    parameters to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import MiniMaxChat
+ llm = MiniMaxChat(model_name="abab5-chat")
+
+ """
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Generate next turn in the conversation.
+ Args:
+            messages: The history of the conversation as a list of messages.
+ stop: The list of stop words (optional).
+ run_manager: The CallbackManager for LLM run, it's not used at the moment.
+
+ Returns:
+ The ChatResult that contains outputs generated by the model.
+
+ Raises:
+            ValueError: if no messages are provided.
+ """
+ if not messages:
+ raise ValueError(
+ "You should provide at least one message to start the chat!"
+ )
+ history = _parse_chat_history(messages)
+ payload = self._default_params
+ payload["messages"] = history
+        text = self._client.post(payload)
+
+        # Stop sequences are not enforced by the model, so apply them here and wrap
+        # the completion text so the return type matches the declared ChatResult.
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        generation = ChatGeneration(message=AIMessage(content=text))
+        return ChatResult(generations=[generation])
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ raise NotImplementedError(
+ """Minimax AI doesn't support async requests at the moment."""
+ )
diff --git a/libs/community/langchain_community/chat_models/mlflow.py b/libs/community/langchain_community/chat_models/mlflow.py
new file mode 100644
index 00000000000..ee289527bb0
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/mlflow.py
@@ -0,0 +1,219 @@
+import asyncio
+import logging
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional
+from urllib.parse import urlparse
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.pydantic_v1 import (
+ Field,
+ PrivateAttr,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ChatMlflow(BaseChatModel):
+ """`MLflow` chat models API.
+
+ To use, you should have the `mlflow[genai]` python package installed.
+ For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatMlflow
+
+ chat = ChatMlflow(
+ target_uri="http://localhost:5000",
+ endpoint="chat",
+            temperature=0.1,
+ )
+ """
+
+ endpoint: str
+ """The endpoint to use."""
+ target_uri: str
+ """The target URI to use."""
+ temperature: float = 0.0
+ """The sampling temperature."""
+ n: int = 1
+ """The number of completion choices to generate."""
+ stop: Optional[List[str]] = None
+ """The stop sequence."""
+ max_tokens: Optional[int] = None
+ """The maximum number of tokens to generate."""
+ extra_params: dict = Field(default_factory=dict)
+ """Any extra parameters to pass to the endpoint."""
+ _client: Any = PrivateAttr()
+
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
+ self._validate_uri()
+ try:
+ from mlflow.deployments import get_deploy_client
+
+ self._client = get_deploy_client(self.target_uri)
+ except ImportError as e:
+ raise ImportError(
+ "Failed to create the client. "
+ f"Please run `pip install mlflow{self._mlflow_extras}` to install "
+ "required dependencies."
+ ) from e
+
+ @property
+ def _mlflow_extras(self) -> str:
+ return "[genai]"
+
+ def _validate_uri(self) -> None:
+ if self.target_uri == "databricks":
+ return
+ allowed = ["http", "https", "databricks"]
+ if urlparse(self.target_uri).scheme not in allowed:
+ raise ValueError(
+ f"Invalid target URI: {self.target_uri}. "
+ f"The scheme must be one of {allowed}."
+ )
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ params: Dict[str, Any] = {
+ "target_uri": self.target_uri,
+ "endpoint": self.endpoint,
+ "temperature": self.temperature,
+ "n": self.n,
+ "stop": self.stop,
+ "max_tokens": self.max_tokens,
+ "extra_params": self.extra_params,
+ }
+ return params
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ message_dicts = [
+ ChatMlflow._convert_message_to_dict(message) for message in messages
+ ]
+ data: Dict[str, Any] = {
+ "messages": message_dicts,
+ "temperature": self.temperature,
+ "n": self.n,
+ **self.extra_params,
+ **kwargs,
+ }
+ if stop := self.stop or stop:
+ data["stop"] = stop
+ if self.max_tokens is not None:
+ data["max_tokens"] = self.max_tokens
+ resp = self._client.predict(endpoint=self.endpoint, inputs=data)
+ return ChatMlflow._create_chat_result(resp)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ func = partial(
+ self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return self._default_params
+
+ def _get_invocation_params(
+ self, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+ return {
+ **self._default_params,
+ **super()._get_invocation_params(stop=stop, **kwargs),
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "mlflow-chat"
+
+ @staticmethod
+ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ content = _dict["content"]
+ if role == "user":
+ return HumanMessage(content=content)
+ elif role == "assistant":
+ return AIMessage(content=content)
+ elif role == "system":
+ return SystemMessage(content=content)
+ else:
+ return ChatMessage(content=content, role=role)
+
+ @staticmethod
+ def _raise_functions_not_supported() -> None:
+ raise ValueError(
+ "Function messages are not supported by Databricks. Please"
+ " create a feature request at https://github.com/mlflow/mlflow/issues."
+ )
+
+ @staticmethod
+ def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ raise ValueError(
+ "Function messages are not supported by Databricks. Please"
+ " create a feature request at https://github.com/mlflow/mlflow/issues."
+ )
+ else:
+ raise ValueError(f"Got unknown message type: {message}")
+
+ if "function_call" in message.additional_kwargs:
+ ChatMlflow._raise_functions_not_supported()
+ if message.additional_kwargs:
+ logger.warning(
+ "Additional message arguments are unsupported by Databricks"
+ " and will be ignored: %s",
+ message.additional_kwargs,
+ )
+ return message_dict
+
+ @staticmethod
+ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for choice in response["choices"]:
+ message = ChatMlflow._convert_dict_to_message(choice["message"])
+ usage = choice.get("usage", {})
+ gen = ChatGeneration(
+ message=message,
+ generation_info=usage,
+ )
+ generations.append(gen)
+
+ usage = response.get("usage", {})
+ return ChatResult(generations=generations, llm_output=usage)
diff --git a/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py b/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py
new file mode 100644
index 00000000000..5674f69fc2c
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py
@@ -0,0 +1,210 @@
+import asyncio
+import logging
+import warnings
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatResult,
+)
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+logger = logging.getLogger(__name__)
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
+class ChatParams(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Parameters for the `MLflow AI Gateway` LLM."""
+
+ temperature: float = 0.0
+ candidate_count: int = 1
+ """The number of candidates to return."""
+ stop: Optional[List[str]] = None
+ max_tokens: Optional[int] = None
+
+
+class ChatMLflowAIGateway(BaseChatModel):
+ """`MLflow AI Gateway` chat models API.
+
+ To use, you should have the ``mlflow[gateway]`` python package installed.
+ For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatMLflowAIGateway
+
+ chat = ChatMLflowAIGateway(
+ gateway_uri="",
+ route="",
+ params={
+ "temperature": 0.1
+ }
+ )
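+
+            # Hedged usage sketch (assumes a running MLflow AI Gateway with a
+            # chat route configured at the gateway_uri above):
+            from langchain_core.messages import HumanMessage
+            chat([HumanMessage(content="Hello!")])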
+ """
+
+ def __init__(self, **kwargs: Any):
+ warnings.warn(
+ "`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
+ "`ChatDatabricks` instead.",
+ DeprecationWarning,
+ )
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ mlflow.gateway.set_gateway_uri(self.gateway_uri)
+
+ route: str
+ gateway_uri: Optional[str] = None
+ params: Optional[ChatParams] = None
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ params: Dict[str, Any] = {
+ "gateway_uri": self.gateway_uri,
+ "route": self.route,
+ **(self.params.dict() if self.params else {}),
+ }
+ return params
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ message_dicts = [
+ ChatMLflowAIGateway._convert_message_to_dict(message)
+ for message in messages
+ ]
+ data: Dict[str, Any] = {
+ "messages": message_dicts,
+ **(self.params.dict() if self.params else {}),
+ }
+
+ resp = mlflow.gateway.query(self.route, data=data)
+ return ChatMLflowAIGateway._create_chat_result(resp)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ func = partial(
+ self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return self._default_params
+
+ def _get_invocation_params(
+ self, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+ return {
+ **self._default_params,
+ **super()._get_invocation_params(stop=stop, **kwargs),
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "mlflow-ai-gateway-chat"
+
+ @staticmethod
+ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ content = _dict["content"]
+ if role == "user":
+ return HumanMessage(content=content)
+ elif role == "assistant":
+ return AIMessage(content=content)
+ elif role == "system":
+ return SystemMessage(content=content)
+ else:
+ return ChatMessage(content=content, role=role)
+
+ @staticmethod
+ def _raise_functions_not_supported() -> None:
+ raise ValueError(
+ "Function messages are not supported by the MLflow AI Gateway. Please"
+ " create a feature request at https://github.com/mlflow/mlflow/issues."
+ )
+
+ @staticmethod
+ def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ raise ValueError(
+ "Function messages are not supported by the MLflow AI Gateway. Please"
+ " create a feature request at https://github.com/mlflow/mlflow/issues."
+ )
+ else:
+ raise ValueError(f"Got unknown message type: {message}")
+
+ if "function_call" in message.additional_kwargs:
+ ChatMLflowAIGateway._raise_functions_not_supported()
+ if message.additional_kwargs:
+ logger.warning(
+ "Additional message arguments are unsupported by MLflow AI Gateway "
+ " and will be ignored: %s",
+ message.additional_kwargs,
+ )
+ return message_dict
+
+ @staticmethod
+ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for candidate in response["candidates"]:
+ message = ChatMLflowAIGateway._convert_dict_to_message(candidate["message"])
+ message_metadata = candidate.get("metadata", {})
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(message_metadata),
+ )
+ generations.append(gen)
+
+ response_metadata = response.get("metadata", {})
+ return ChatResult(generations=generations, llm_output=response_metadata)
diff --git a/libs/community/langchain_community/chat_models/ollama.py b/libs/community/langchain_community/chat_models/ollama.py
new file mode 100644
index 00000000000..91dda64e45e
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/ollama.py
@@ -0,0 +1,123 @@
+import json
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from langchain_community.llms.ollama import _OllamaCommon
+
+
+def _stream_response_to_chat_generation_chunk(
+ stream_response: str,
+) -> ChatGenerationChunk:
+ """Convert a stream response to a generation chunk."""
+ parsed_response = json.loads(stream_response)
+ generation_info = parsed_response if parsed_response.get("done") is True else None
+ return ChatGenerationChunk(
+ message=AIMessageChunk(content=parsed_response.get("response", "")),
+ generation_info=generation_info,
+ )
+
+
+class ChatOllama(BaseChatModel, _OllamaCommon):
+ """Ollama locally runs large language models.
+
+ To use, follow the instructions at https://ollama.ai/.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatOllama
+ ollama = ChatOllama(model="llama2")
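+
+            # Hedged usage sketch (assumes a local Ollama server with the llama2
+            # model pulled); stream the reply chunk by chunk:
+            from langchain_core.messages import HumanMessage
+            for chunk in ollama.stream([HumanMessage(content="Hi!")]):
+                print(chunk.content, end="", flush=True)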
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "ollama-chat"
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return False
+
+ def _format_message_as_text(self, message: BaseMessage) -> str:
+ if isinstance(message, ChatMessage):
+ message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+ elif isinstance(message, HumanMessage):
+ message_text = f"[INST] {message.content} [/INST]"
+ elif isinstance(message, AIMessage):
+ message_text = f"{message.content}"
+ elif isinstance(message, SystemMessage):
+ message_text = f"<> {message.content} <>"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_text
+
+ def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
+ return "\n".join(
+ [self._format_message_as_text(message) for message in messages]
+ )
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Call out to Ollama's generate endpoint.
+
+ Args:
+ messages: The list of base messages to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ Chat generations from the model
+
+ Example:
+ .. code-block:: python
+
+ response = ollama([
+ HumanMessage(content="Tell me about the history of AI")
+ ])
+ """
+
+ prompt = self._format_messages_as_text(messages)
+ final_chunk = super()._stream_with_aggregation(
+ prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
+ )
+ chat_generation = ChatGeneration(
+ message=AIMessage(content=final_chunk.text),
+ generation_info=final_chunk.generation_info,
+ )
+ return ChatResult(generations=[chat_generation])
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ prompt = self._format_messages_as_text(messages)
+ for stream_resp in self._create_stream(prompt, stop, **kwargs):
+ if stream_resp:
+ chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ chunk.text,
+ verbose=self.verbose,
+ )
diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py
new file mode 100644
index 00000000000..1bb86bf0851
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/openai.py
@@ -0,0 +1,675 @@
+"""OpenAI chat wrapper."""
+from __future__ import annotations
+
+import logging
+import os
+import sys
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncIterator,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ agenerate_from_stream,
+ generate_from_stream,
+)
+from langchain_core.language_models.llms import create_base_retry_decorator
+from langchain_core.messages import (
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessageChunk,
+ FunctionMessageChunk,
+ HumanMessageChunk,
+ SystemMessageChunk,
+ ToolMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.utils import (
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+
+from langchain_community.adapters.openai import (
+ convert_dict_to_message,
+ convert_message_to_dict,
+)
+from langchain_community.utils.openai import is_openai_v1
+
+if TYPE_CHECKING:
+ import tiktoken
+
+
+logger = logging.getLogger(__name__)
+
+
+def _import_tiktoken() -> Any:
+ try:
+ import tiktoken
+ except ImportError:
+ raise ValueError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to calculate get_token_ids. "
+ "Please install it with `pip install tiktoken`."
+ )
+ return tiktoken
+
+
+def _create_retry_decorator(
+ llm: ChatOpenAI,
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ import openai
+
+ errors = [
+ openai.error.Timeout,
+ openai.error.APIError,
+ openai.error.APIConnectionError,
+ openai.error.RateLimitError,
+ openai.error.ServiceUnavailableError,
+ ]
+ return create_base_retry_decorator(
+ error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+ )
+
+
+async def acompletion_with_retry(
+ llm: ChatOpenAI,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the async completion call."""
+ if is_openai_v1():
+ return await llm.async_client.create(**kwargs)
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+ return await llm.client.acreate(**kwargs)
+
+ return await _completion_with_retry(**kwargs)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ content = _dict.get("content") or ""
+ additional_kwargs: Dict = {}
+ if _dict.get("function_call"):
+ function_call = dict(_dict["function_call"])
+ if "name" in function_call and function_call["name"] is None:
+ function_call["name"] = ""
+ additional_kwargs["function_call"] = function_call
+ if _dict.get("tool_calls"):
+ additional_kwargs["tool_calls"] = _dict["tool_calls"]
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system" or default_class == SystemMessageChunk:
+ return SystemMessageChunk(content=content)
+ elif role == "function" or default_class == FunctionMessageChunk:
+ return FunctionMessageChunk(content=content, name=_dict["name"])
+ elif role == "tool" or default_class == ToolMessageChunk:
+ return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+class ChatOpenAI(BaseChatModel):
+ """`OpenAI` Chat large language models API.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``OPENAI_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatOpenAI
+ openai = ChatOpenAI(model_name="gpt-3.5-turbo")
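+
+            # Hedged usage sketch (assumes OPENAI_API_KEY is set in the
+            # environment):
+            from langchain_core.messages import HumanMessage, SystemMessage
+            openai(
+                [
+                    SystemMessage(content="You are a helpful assistant."),
+                    HumanMessage(content="Say hello."),
+                ]
+            )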
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"openai_api_key": "OPENAI_API_KEY"}
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "openai"]
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ attributes: Dict[str, Any] = {}
+
+ if self.openai_organization:
+ attributes["openai_organization"] = self.openai_organization
+
+ if self.openai_api_base:
+ attributes["openai_api_base"] = self.openai_api_base
+
+ if self.openai_proxy:
+ attributes["openai_proxy"] = self.openai_proxy
+
+ return attributes
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return True
+
+ client: Any = Field(default=None, exclude=True) #: :meta private:
+ async_client: Any = Field(default=None, exclude=True) #: :meta private:
+ model_name: str = Field(default="gpt-3.5-turbo", alias="model")
+ """Model name to use."""
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ # When updating this to use a SecretStr
+ # Check for classes that derive from this class (as some of them
+ # may assume openai_api_key is a str)
+ openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+ openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+ """Base URL path for API requests, leave blank if not using a proxy or service
+ emulator."""
+ openai_organization: Optional[str] = Field(default=None, alias="organization")
+ """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+ # to support explicit proxy for OpenAI
+ openai_proxy: Optional[str] = None
+ request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
+ default=None, alias="timeout"
+ )
+ """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
+ None."""
+ max_retries: int = 2
+ """Maximum number of retries to make when generating."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ n: int = 1
+ """Number of chat completions to generate for each prompt."""
+ max_tokens: Optional[int] = None
+ """Maximum number of tokens to generate."""
+ tiktoken_model_name: Optional[str] = None
+ """The model name to pass to tiktoken when using this class.
+ Tiktoken is used to count the number of tokens in documents to constrain
+ them to be under a certain limit. By default, when set to None, this will
+    be the same as the model name. However, there are some cases
+    where you may want to use this class with a model name not
+    supported by tiktoken. This can include when using Azure OpenAI or
+ when using one of the many model providers that expose an OpenAI-like
+ API but with different models. In those cases, in order to avoid erroring
+ when tiktoken is called, you can specify a model name to use here."""
+ default_headers: Union[Mapping[str, str], None] = None
+ default_query: Union[Mapping[str, object], None] = None
+ # Configure a custom httpx client. See the
+ # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
+ http_client: Union[Any, None] = None
+ """Optional httpx.Client."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ if values["n"] < 1:
+ raise ValueError("n must be at least 1.")
+ if values["n"] > 1 and values["streaming"]:
+ raise ValueError("n must be 1 when streaming.")
+
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ # Check OPENAI_ORGANIZATION for backwards compatibility.
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ try:
+ import openai
+
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+
+ if is_openai_v1():
+ client_params = {
+ "api_key": values["openai_api_key"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+
+ if not values.get("client"):
+ values["client"] = openai.OpenAI(**client_params).chat.completions
+ if not values.get("async_client"):
+ values["async_client"] = openai.AsyncOpenAI(
+ **client_params
+ ).chat.completions
+ elif not values.get("client"):
+ values["client"] = openai.ChatCompletion
+ else:
+ pass
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ params = {
+ "model": self.model_name,
+ "stream": self.streaming,
+ "n": self.n,
+ "temperature": self.temperature,
+ **self.model_kwargs,
+ }
+ if self.max_tokens is not None:
+ params["max_tokens"] = self.max_tokens
+ if self.request_timeout is not None and not is_openai_v1():
+ params["request_timeout"] = self.request_timeout
+ return params
+
+ def completion_with_retry(
+ self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+ ) -> Any:
+ """Use tenacity to retry the completion call."""
+ if is_openai_v1():
+ return self.client.create(**kwargs)
+
+ retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return self.client.create(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+ overall_token_usage: dict = {}
+ system_fingerprint = None
+ for output in llm_outputs:
+ if output is None:
+ # Happens in streaming
+ continue
+ token_usage = output["token_usage"]
+ for k, v in token_usage.items():
+ if k in overall_token_usage:
+ overall_token_usage[k] += v
+ else:
+ overall_token_usage[k] = v
+ if system_fingerprint is None:
+ system_fingerprint = output.get("system_fingerprint")
+ combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+ if system_fingerprint:
+ combined["system_fingerprint"] = system_fingerprint
+ return combined
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ for chunk in self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if not isinstance(chunk, dict):
+ chunk = chunk.dict()
+ if len(chunk["choices"]) == 0:
+ continue
+ choice = chunk["choices"][0]
+ chunk = _convert_delta_to_message_chunk(
+ choice["delta"], default_chunk_class
+ )
+ finish_reason = choice.get("finish_reason")
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {
+ **params,
+ **({"stream": stream} if stream is not None else {}),
+ **kwargs,
+ }
+ response = self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage], stop: Optional[List[str]]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params = self._client_params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ message_dicts = [convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
+
+ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+ generations = []
+ if not isinstance(response, dict):
+ response = response.dict()
+ for res in response["choices"]:
+ message = convert_dict_to_message(res["message"])
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(finish_reason=res.get("finish_reason")),
+ )
+ generations.append(gen)
+ token_usage = response.get("usage", {})
+ llm_output = {
+ "token_usage": token_usage,
+ "model_name": self.model_name,
+ "system_fingerprint": response.get("system_fingerprint", ""),
+ }
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+
+ default_chunk_class = AIMessageChunk
+ async for chunk in await acompletion_with_retry(
+ self, messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if not isinstance(chunk, dict):
+ chunk = chunk.dict()
+ if len(chunk["choices"]) == 0:
+ continue
+ choice = chunk["choices"][0]
+ chunk = _convert_delta_to_message_chunk(
+ choice["delta"], default_chunk_class
+ )
+ finish_reason = choice.get("finish_reason")
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._astream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await agenerate_from_stream(stream_iter)
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {
+ **params,
+ **({"stream": stream} if stream is not None else {}),
+ **kwargs,
+ }
+ response = await acompletion_with_retry(
+ self, messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _client_params(self) -> Dict[str, Any]:
+ """Get the parameters used for the openai client."""
+ openai_creds: Dict[str, Any] = {
+ "model": self.model_name,
+ }
+ if not is_openai_v1():
+ openai_creds.update(
+ {
+ "api_key": self.openai_api_key,
+ "api_base": self.openai_api_base,
+ "organization": self.openai_organization,
+ }
+ )
+ if self.openai_proxy:
+ import openai
+
+ openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
+ return {**self._default_params, **openai_creds}
+
+ def _get_invocation_params(
+ self, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model."""
+ return {
+ "model": self.model_name,
+ **super()._get_invocation_params(stop=stop),
+ **self._default_params,
+ **kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "openai-chat"
+
+ def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
+ tiktoken_ = _import_tiktoken()
+ if self.tiktoken_model_name is not None:
+ model = self.tiktoken_model_name
+ else:
+ model = self.model_name
+ if model == "gpt-3.5-turbo":
+ # gpt-3.5-turbo may change over time.
+ # Returning num tokens assuming gpt-3.5-turbo-0301.
+ model = "gpt-3.5-turbo-0301"
+ elif model == "gpt-4":
+ # gpt-4 may change over time.
+ # Returning num tokens assuming gpt-4-0314.
+ model = "gpt-4-0314"
+ # Returns the number of tokens used by a list of messages.
+ try:
+ encoding = tiktoken_.encoding_for_model(model)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken_.get_encoding(model)
+ return model, encoding
+
+ def get_token_ids(self, text: str) -> List[int]:
+ """Get the tokens present in the text with tiktoken package."""
+ # tiktoken NOT supported for Python 3.7 or below
+ if sys.version_info[1] <= 7:
+ return super().get_token_ids(text)
+ _, encoding_model = self._get_encoding_model()
+ return encoding_model.encode(text)
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+ Official documentation: https://github.com/openai/openai-cookbook/blob/
+ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+ if sys.version_info[1] <= 7:
+ return super().get_num_tokens_from_messages(messages)
+ model, encoding = self._get_encoding_model()
+ if model.startswith("gpt-3.5-turbo-0301"):
+ # every message follows {role/name}\n{content}\n
+ tokens_per_message = 4
+ # if there's a name, the role is omitted
+ tokens_per_name = -1
+ elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+ tokens_per_message = 3
+ tokens_per_name = 1
+ else:
+ raise NotImplementedError(
+ f"get_num_tokens_from_messages() is not presently implemented "
+ f"for model {model}."
+ "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+ "information on how messages are converted to tokens."
+ )
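+        # Illustrative arithmetic (gpt-3.5-turbo style): a single message such as
+        # {"role": "user", "content": "hi"} costs tokens_per_message (3) plus the
+        # encoded lengths of "user" and "hi"; the final `num_tokens += 3` below
+        # accounts for the tokens that prime the assistant's reply.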
+ num_tokens = 0
+ messages_dict = [convert_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+ num_tokens += len(encoding.encode(str(value)))
+ if key == "name":
+ num_tokens += tokens_per_name
+ # every reply is primed with assistant
+ num_tokens += 3
+ return num_tokens
+
+ def bind_functions(
+ self,
+ functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+ function_call: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
+ """Bind functions (and other objects) to this chat model.
+
+ Args:
+ functions: A list of function definitions to bind to this chat model.
+ Can be a dictionary, pydantic model, or callable. Pydantic
+ models and callables will be automatically converted to
+ their schema dictionary representation.
+ function_call: Which function to require the model to call.
+ Must be the name of the single provided function or
+ "auto" to automatically determine which function to call
+ (if any).
+ kwargs: Any additional parameters to pass to the
+ :class:`~langchain.runnable.Runnable` constructor.
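+
+        Example:
+            .. code-block:: python
+
+                # Hedged sketch: bind a single function definition (the
+                # get_weather schema below is illustrative) and force the
+                # model to call it.
+                model = ChatOpenAI().bind_functions(
+                    functions=[
+                        {
+                            "name": "get_weather",
+                            "description": "Get the current weather for a city.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                        }
+                    ],
+                    function_call="get_weather",
+                )
+                model.invoke("What is the weather in Paris?")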
+ """
+ from langchain.chains.openai_functions.base import convert_to_openai_function
+
+ formatted_functions = [convert_to_openai_function(fn) for fn in functions]
+ if function_call is not None:
+ if len(formatted_functions) != 1:
+ raise ValueError(
+ "When specifying `function_call`, you must provide exactly one "
+ "function."
+ )
+ if formatted_functions[0]["name"] != function_call:
+ raise ValueError(
+ f"Function call {function_call} was specified, but the only "
+ f"provided function was {formatted_functions[0]['name']}."
+ )
+ function_call_ = {"name": function_call}
+ kwargs = {**kwargs, "function_call": function_call_}
+ return super().bind(
+ functions=formatted_functions,
+ **kwargs,
+ )
diff --git a/libs/community/langchain_community/chat_models/pai_eas_endpoint.py b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py
new file mode 100644
index 00000000000..85f13246817
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py
@@ -0,0 +1,324 @@
+import asyncio
+import json
+import logging
+from functools import partial
+from typing import Any, AsyncIterator, Dict, List, Optional, cast
+
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class PaiEasChatEndpoint(BaseChatModel):
+ """Eas LLM Service chat model API.
+
+ To use, must have a deployed eas chat llm service on AliCloud. One can set the
+ environment variable ``eas_service_url`` and ``eas_service_token`` set with your eas
+ service url and service token.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import PaiEasChatEndpoint
+ eas_chat_endpoint = PaiEasChatEndpoint(
+ eas_service_url="your_service_url",
+ eas_service_token="your_service_token"
+ )
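+
+            # Hedged usage sketch (assumes the service URL and token above point
+            # at a deployed PAI-EAS chat service):
+            from langchain_core.messages import HumanMessage
+            eas_chat_endpoint([HumanMessage(content="Hello!")])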
+ """
+
+ """PAI-EAS Service URL"""
+ eas_service_url: str
+
+ """PAI-EAS Service TOKEN"""
+ eas_service_token: str
+
+ """PAI-EAS Service Infer Params"""
+ max_new_tokens: Optional[int] = 512
+ temperature: Optional[float] = 0.8
+ top_p: Optional[float] = 0.1
+ top_k: Optional[int] = 10
+ do_sample: Optional[bool] = False
+ use_cache: Optional[bool] = True
+ stop_sequences: Optional[List[str]] = None
+
+ """Enable stream chat mode."""
+ streaming: bool = False
+
+ """Key/value arguments to pass to the model. Reserved for future use"""
+ model_kwargs: Optional[dict] = None
+
+ version: Optional[str] = "2.0"
+
+ timeout: Optional[int] = 5000
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["eas_service_url"] = get_from_dict_or_env(
+ values, "eas_service_url", "EAS_SERVICE_URL"
+ )
+ values["eas_service_token"] = get_from_dict_or_env(
+ values, "eas_service_token", "EAS_SERVICE_TOKEN"
+ )
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ "eas_service_url": self.eas_service_url,
+ "eas_service_token": self.eas_service_token,
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "pai_eas_chat_endpoint"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Cohere API."""
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "temperature": self.temperature,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "stop_sequences": [],
+ "do_sample": self.do_sample,
+ "use_cache": self.use_cache,
+ }
+
+ def _invocation_params(
+ self, stop_sequences: Optional[List[str]], **kwargs: Any
+ ) -> dict:
+ params = self._default_params
+ if self.model_kwargs:
+ params.update(self.model_kwargs)
+ if self.stop_sequences is not None and stop_sequences is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop_sequences is not None:
+ params["stop"] = self.stop_sequences
+ else:
+ params["stop"] = stop_sequences
+ return {**params, **kwargs}
+
+ def format_request_payload(
+ self, messages: List[BaseMessage], **model_kwargs: Any
+ ) -> dict:
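+        # Illustrative shape of the payload this method builds (assuming one
+        # system message followed by alternating user/assistant turns):
+        #   {"system_prompt": "...", "prompt": "<latest user message>",
+        #    "history": [("user turn 1", "assistant reply 1"), ...], **model_kwargs}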
+ prompt: Dict[str, Any] = {}
+ user_content: List[str] = []
+ assistant_content: List[str] = []
+
+ for message in messages:
+ """Converts message to a dict according to role"""
+ content = cast(str, message.content)
+ if isinstance(message, HumanMessage):
+ user_content = user_content + [content]
+ elif isinstance(message, AIMessage):
+ assistant_content = assistant_content + [content]
+ elif isinstance(message, SystemMessage):
+ prompt["system_prompt"] = content
+ elif isinstance(message, ChatMessage) and message.role in [
+ "user",
+ "assistant",
+ "system",
+ ]:
+ if message.role == "system":
+ prompt["system_prompt"] = content
+ elif message.role == "user":
+ user_content = user_content + [content]
+ elif message.role == "assistant":
+ assistant_content = assistant_content + [content]
+ else:
+ supported = ",".join([role for role in ["user", "assistant", "system"]])
+ raise ValueError(
+ f"""Received unsupported role.
+ Supported roles for the LLaMa Foundation Model: {supported}"""
+ )
+ prompt["prompt"] = user_content[len(user_content) - 1]
+ history = [
+ history_item
+ for _, history_item in enumerate(zip(user_content[:-1], assistant_content))
+ ]
+
+ prompt["history"] = history
+
+ return {**prompt, **model_kwargs}
+
+ def _format_response_payload(
+ self, output: bytes, stop_sequences: Optional[List[str]]
+ ) -> str:
+ """Formats response"""
+ try:
+ text = json.loads(output)["response"]
+ if stop_sequences:
+ text = enforce_stop_tokens(text, stop_sequences)
+ return text
+ except Exception as e:
+ if isinstance(e, json.decoder.JSONDecodeError):
+ return output.decode("utf-8")
+ raise e
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
+ message = AIMessage(content=output_str)
+ generation = ChatGeneration(message=message)
+ return ChatResult(generations=[generation])
+
+ def _call(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ params = self._invocation_params(stop, **kwargs)
+
+ request_payload = self.format_request_payload(messages, **params)
+ response_payload = self._call_eas(request_payload)
+ generated_text = self._format_response_payload(response_payload, params["stop"])
+
+ if run_manager:
+ run_manager.on_llm_new_token(generated_text)
+
+ return generated_text
+
+ def _call_eas(self, query_body: dict) -> Any:
+ """Generate text from the eas service."""
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": f"{self.eas_service_token}",
+ }
+
+ # make request
+ response = requests.post(
+ self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed with status code {response.status_code}"
+ f" and message {response.text}"
+ )
+
+ return response.text
+
+ def _call_eas_stream(self, query_body: dict) -> Any:
+ """Generate text from the eas service."""
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": f"{self.eas_service_token}",
+ }
+
+ # make request
+ response = requests.post(
+ self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed with status code {response.status_code}"
+ f" and message {response.text}"
+ )
+
+ return response
+
+ def _convert_chunk_to_message_message(
+ self,
+ chunk: str,
+ ) -> AIMessageChunk:
+ data = json.loads(chunk.encode("utf-8"))
+ return AIMessageChunk(content=data.get("response", ""))
+
+ async def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ params = self._invocation_params(stop, **kwargs)
+
+ request_payload = self.format_request_payload(messages, **params)
+ request_payload["use_stream_chat"] = True
+
+ response = self._call_eas_stream(request_payload)
+ for chunk in response.iter_lines(
+ chunk_size=8192, decode_unicode=False, delimiter=b"\0"
+ ):
+ if chunk:
+ content = self._convert_chunk_to_message_message(chunk)
+
+ # identify stop sequence in generated text, if any
+ stop_seq_found: Optional[str] = None
+ for stop_seq in params["stop"]:
+ if stop_seq in content.content:
+ stop_seq_found = stop_seq
+
+                # truncate the generated text at the stop sequence, if found
+                if stop_seq_found:
+                    content.content = content.content[
+                        : content.content.index(stop_seq_found)
+                    ]
+
+                # yield the chunk, if any content remains
+                if content.content:
+ if run_manager:
+ await run_manager.on_llm_new_token(cast(str, content.content))
+ yield ChatGenerationChunk(message=content)
+
+ # break if stop sequence found
+ if stop_seq_found:
+ break
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+ generation: Optional[ChatGenerationChunk] = None
+ async for chunk in self._astream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ generation = chunk
+ assert generation is not None
+ return ChatResult(generations=[generation])
+
+ func = partial(
+ self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
diff --git a/libs/community/langchain_community/chat_models/promptlayer_openai.py b/libs/community/langchain_community/chat_models/promptlayer_openai.py
new file mode 100644
index 00000000000..551655e4c7d
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/promptlayer_openai.py
@@ -0,0 +1,140 @@
+"""PromptLayer wrapper."""
+import datetime
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
+
+from langchain_community.chat_models import ChatOpenAI
+
+
+class PromptLayerChatOpenAI(ChatOpenAI):
+ """`PromptLayer` and `OpenAI` Chat large language models API.
+
+ To use, you should have the ``openai`` and ``promptlayer`` python
+ package installed, and the environment variable ``OPENAI_API_KEY``
+ and ``PROMPTLAYER_API_KEY`` set with your openAI API key and
+ promptlayer key respectively.
+
+ All parameters that can be passed to the OpenAI LLM can also
+    be passed here. The PromptLayerChatOpenAI adds two optional
+    parameters:
+ ``pl_tags``: List of strings to tag the request with.
+ ``return_pl_id``: If True, the PromptLayer request ID will be
+ returned in the ``generation_info`` field of the
+ ``Generation`` object.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import PromptLayerChatOpenAI
+ openai = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo")
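+
+            # Hedged sketch of the PromptLayer-specific parameters (assumes
+            # OPENAI_API_KEY and PROMPTLAYER_API_KEY are set):
+            openai = PromptLayerChatOpenAI(
+                model_name="gpt-3.5-turbo",
+                pl_tags=["langchain"],
+                return_pl_id=True,
+            )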
+ """
+
+ pl_tags: Optional[List[str]]
+ return_pl_id: Optional[bool] = False
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
+ from promptlayer.utils import get_api_key, promptlayer_api_request
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = super()._generate(
+ messages, stop, run_manager, stream=stream, **kwargs
+ )
+ request_end_time = datetime.datetime.now().timestamp()
+ message_dicts, params = super()._create_message_dicts(messages, stop)
+ for i, generation in enumerate(generated_responses.generations):
+ response_dict, params = super()._create_message_dicts(
+ [generation.message], stop
+ )
+ params = {**params, **kwargs}
+ pl_request_id = promptlayer_api_request(
+ "langchain.PromptLayerChatOpenAI",
+ "langchain",
+ message_dicts,
+ params,
+ self.pl_tags,
+ response_dict,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Call ChatOpenAI agenerate and then call PromptLayer to log."""
+ from promptlayer.utils import get_api_key, promptlayer_api_request_async
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = await super()._agenerate(
+ messages, stop, run_manager, stream=stream, **kwargs
+ )
+ request_end_time = datetime.datetime.now().timestamp()
+ message_dicts, params = super()._create_message_dicts(messages, stop)
+ for i, generation in enumerate(generated_responses.generations):
+ response_dict, params = super()._create_message_dicts(
+ [generation.message], stop
+ )
+ params = {**params, **kwargs}
+ pl_request_id = await promptlayer_api_request_async(
+ "langchain.PromptLayerChatOpenAI.async",
+ "langchain",
+ message_dicts,
+ params,
+ self.pl_tags,
+ response_dict,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
+
+ @property
+ def _llm_type(self) -> str:
+ return "promptlayer-openai-chat"
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ **super()._identifying_params,
+ "pl_tags": self.pl_tags,
+ "return_pl_id": self.return_pl_id,
+ }
diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py
new file mode 100644
index 00000000000..59ea6bec915
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/tongyi.py
@@ -0,0 +1,404 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+)
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ FunctionMessage,
+ FunctionMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+ SystemMessage,
+ SystemMessageChunk,
+)
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatGenerationChunk,
+ ChatResult,
+ GenerationChunk,
+)
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from requests.exceptions import HTTPError
+from tenacity import (
+ RetryCallState,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["role"]
+ if role == "user":
+ return HumanMessage(content=_dict["content"])
+ elif role == "assistant":
+ content = _dict.get("content", "") or ""
+ if _dict.get("function_call"):
+ additional_kwargs = {"function_call": dict(_dict["function_call"])}
+ else:
+ additional_kwargs = {}
+ return AIMessage(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system":
+ return SystemMessage(content=_dict["content"])
+ elif role == "function":
+ return FunctionMessage(content=_dict["content"], name=_dict["name"])
+ else:
+ return ChatMessage(content=_dict["content"], role=role)
+
+
+def convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"role": message.role, "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ if "function_call" in message.additional_kwargs:
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
+ # If function call only, content is None not empty string
+ if message_dict["content"] == "":
+ message_dict["content"] = None
+ elif isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.name,
+ }
+ else:
+ raise TypeError(f"Got unknown type {message}")
+ if "name" in message.additional_kwargs:
+ message_dict["name"] = message.additional_kwargs["name"]
+ return message_dict
+
+
+def _stream_response_to_generation_chunk(
+ stream_response: Dict[str, Any],
+ length: int,
+) -> GenerationChunk:
+ """Convert a stream response to a generation chunk.
+
+    Tongyi's low-level streaming API differs from OpenAI and most other LLMs:
+    each stream response is not an incremental delta but the full text generated
+    so far. For example, for the answer 'Hi Pickle Rick! How can I assist you today?'
+    other LLMs would stream:
+    'Hi Pickle',
+    ' Rick!',
+    ' How can I assist you today?'.
+
+    Tongyi streams:
+    'Hi Pickle',
+    'Hi Pickle Rick!',
+    'Hi Pickle Rick! How can I assist you today?'.
+
+    Since chunks are concatenated when building a GenerationChunk, only
+    full_text[length:] is returned for each new chunk.
+ """
+ full_text = stream_response["output"]["text"]
+ text = full_text[length:]
+ finish_reason = stream_response["output"].get("finish_reason", None)
+
+ return GenerationChunk(
+ text=text,
+ generation_info=dict(
+ finish_reason=finish_reason,
+ ),
+ )
+
+
+def _create_retry_decorator(
+ llm: ChatTongyi,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+) -> Callable[[Any], Any]:
+ def _before_sleep(retry_state: RetryCallState) -> None:
+ if run_manager:
+ run_manager.on_retry(retry_state)
+ return None
+
+ min_seconds = 1
+ max_seconds = 4
+    # Wait 2^x * 1 second between each retry, starting at
+    # min_seconds (1s) and capping at max_seconds (4s)
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(llm.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(retry_if_exception_type(HTTPError)),
+ before_sleep=_before_sleep,
+ )
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any],
+ default_class: Type[BaseMessageChunk],
+ length: int,
+) -> BaseMessageChunk:
+ role = _dict.get("role")
+ full_content = _dict.get("content") or ""
+ content = full_content[length:]
+ if _dict.get("function_call"):
+ additional_kwargs = {"function_call": dict(_dict["function_call"])}
+ else:
+ additional_kwargs = {}
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+ elif role == "system" or default_class == SystemMessageChunk:
+ return SystemMessageChunk(content=content)
+ elif role == "function" or default_class == FunctionMessageChunk:
+ return FunctionMessageChunk(content=content, name=_dict["name"])
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role)
+ else:
+ return default_class(content=content)
+
+
+class ChatTongyi(BaseChatModel):
+ """Alibaba Tongyi Qwen chat models API.
+
+ To use, you should have the ``dashscope`` python package installed,
+ and set env ``DASHSCOPE_API_KEY`` with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatTongyi
+ chat = ChatTongyi()
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"dashscope_api_key": "DASHSCOPE_API_KEY"}
+
+ @property
+ def lc_serializable(self) -> bool:
+ return True
+
+ client: Any #: :meta private:
+ model_name: str = Field(default="qwen-turbo", alias="model")
+ """Model name to use."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ top_p: float = 0.8
+ """Total probability mass of tokens to consider at each step."""
+
+ dashscope_api_key: Optional[str] = None
+ """Dashscope api key provide by alicloud."""
+
+ n: int = 1
+ """How many completions to generate for each prompt."""
+
+ streaming: bool = False
+ """Whether to stream the results or not."""
+
+ max_retries: int = 10
+ """Maximum number of retries to make when generating."""
+
+ prefix_messages: List = Field(default_factory=list)
+ """Series of messages for Chat input."""
+
+ result_format: str = Field(default="message")
+ """Return result format"""
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "tongyi"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ get_from_dict_or_env(values, "dashscope_api_key", "DASHSCOPE_API_KEY")
+ try:
+ import dashscope
+ except ImportError:
+ raise ImportError(
+ "Could not import dashscope python package. "
+ "Please install it with `pip install dashscope --upgrade`."
+ )
+ try:
+ values["client"] = dashscope.Generation
+ except AttributeError:
+ raise ValueError(
+ "`dashscope` has no `Generation` attribute, this is likely "
+ "due to an old version of the dashscope package. Try upgrading it "
+ "with `pip install --upgrade dashscope`."
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ return {
+ "model": self.model_name,
+ "top_p": self.top_p,
+ "stream": self.streaming,
+ "n": self.n,
+ "result_format": self.result_format,
+ **self.model_kwargs,
+ }
+
+ def completion_with_retry(
+ self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+ ) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+ @retry_decorator
+ def _completion_with_retry(**_kwargs: Any) -> Any:
+ resp = self.client.call(**_kwargs)
+ if resp.status_code == 200:
+ return resp
+ elif resp.status_code in [400, 401]:
+ raise ValueError(
+ f"status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}"
+ )
+ else:
+ raise HTTPError(
+ f"HTTP error occurred: status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}",
+ response=resp,
+ )
+
+ return _completion_with_retry(**kwargs)
+
+ def stream_completion_with_retry(
+ self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+ ) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+ @retry_decorator
+ def _stream_completion_with_retry(**_kwargs: Any) -> Any:
+ return self.client.call(**_kwargs)
+
+ return _stream_completion_with_retry(**kwargs)
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ if not messages:
+ raise ValueError("No messages provided.")
+
+ message_dicts, params = self._create_message_dicts(messages, stop)
+
+ if message_dicts[-1]["role"] != "user":
+ raise ValueError("Last message should be user message.")
+
+ params = {**params, **kwargs}
+ response = self.completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ )
+ return self._create_chat_result(response)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ message_dicts, params = self._create_message_dicts(messages, stop)
+ params = {**params, **kwargs, "stream": True}
+ # Track the length of text already emitted (Tongyi streams cumulative output)
+ length = 0
+ default_chunk_class = AIMessageChunk
+ for chunk in self.stream_completion_with_retry(
+ messages=message_dicts, run_manager=run_manager, **params
+ ):
+ if len(chunk["output"]["choices"]) == 0:
+ continue
+ choice = chunk["output"]["choices"][0]
+
+ chunk = _convert_delta_to_message_chunk(
+ choice["message"], default_chunk_class, length
+ )
+ finish_reason = choice.get("finish_reason")
+ generation_info = (
+ dict(finish_reason=finish_reason) if finish_reason is not None else None
+ )
+ default_chunk_class = chunk.__class__
+ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+ length = len(choice["message"]["content"])
+
+ def _create_message_dicts(
+ self, messages: List[BaseMessage], stop: Optional[List[str]]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ params = self._client_params()
+
+ # Honor `stop` if given, but error if it is also set in the default params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+
+ message_dicts = [convert_message_to_dict(m) for m in messages]
+ return message_dicts, params
+
+ def _client_params(self) -> Dict[str, Any]:
+ """Get the parameters used for the openai client."""
+ creds: Dict[str, Any] = {
+ "api_key": self.dashscope_api_key,
+ }
+ return {**self._default_params, **creds}
+
+ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for res in response["output"]["choices"]:
+ message = convert_dict_to_message(res["message"])
+ gen = ChatGeneration(
+ message=message,
+ generation_info=dict(finish_reason=res.get("finish_reason")),
+ )
+ generations.append(gen)
+ token_usage = response.get("usage", {})
+ llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+ return ChatResult(generations=generations, llm_output=llm_output)
diff --git a/libs/community/langchain_community/chat_models/vertexai.py b/libs/community/langchain_community/chat_models/vertexai.py
new file mode 100644
index 00000000000..11f14e3d7fa
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/vertexai.py
@@ -0,0 +1,273 @@
+"""Wrapper around Google VertexAI chat-based models."""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import root_validator
+
+from langchain_community.llms.vertexai import _VertexAICommon, is_codey_model
+from langchain_community.utilities.vertexai import raise_vertex_import_error
+
+if TYPE_CHECKING:
+ from vertexai.language_models import (
+ ChatMessage,
+ ChatSession,
+ CodeChatSession,
+ InputOutputTextPair,
+ )
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ChatHistory:
+ """Represents a context and a history of messages."""
+
+ history: List["ChatMessage"] = field(default_factory=list)
+ context: Optional[str] = None
+
+
+def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
+ """Parse a sequence of messages into history.
+
+ Args:
+ history: The list of messages to re-create the history of the chat.
+ Returns:
+ A parsed chat history.
+ Raises:
+ ValueError: If the sequence of messages has a SystemMessage anywhere
+ other than the first position.
+ """
+ from vertexai.language_models import ChatMessage
+
+ vertex_messages, context = [], None
+ for i, message in enumerate(history):
+ content = cast(str, message.content)
+ if i == 0 and isinstance(message, SystemMessage):
+ context = content
+ elif isinstance(message, AIMessage):
+ vertex_message = ChatMessage(content=message.content, author="bot")
+ vertex_messages.append(vertex_message)
+ elif isinstance(message, HumanMessage):
+ vertex_message = ChatMessage(content=message.content, author="user")
+ vertex_messages.append(vertex_message)
+ else:
+ raise ValueError(
+ f"Unexpected message with type {type(message)} at the position {i}."
+ )
+ chat_history = _ChatHistory(context=context, history=vertex_messages)
+ return chat_history
+
+
+def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
+ from vertexai.language_models import InputOutputTextPair
+
+ if len(examples) % 2 != 0:
+ raise ValueError(
+ f"Expect examples to have an even amount of messages, got {len(examples)}."
+ )
+ example_pairs = []
+ input_text = None
+ for i, example in enumerate(examples):
+ if i % 2 == 0:
+ if not isinstance(example, HumanMessage):
+ raise ValueError(
+ f"Expected the first message in a part to be from human, got "
+ f"{type(example)} for the {i}th message."
+ )
+ input_text = example.content
+ if i % 2 == 1:
+ if not isinstance(example, AIMessage):
+ raise ValueError(
+ f"Expected the second message in a part to be from AI, got "
+ f"{type(example)} for the {i}th message."
+ )
+ pair = InputOutputTextPair(
+ input_text=input_text, output_text=example.content
+ )
+ example_pairs.append(pair)
+ return example_pairs
+
+
+def _get_question(messages: List[BaseMessage]) -> HumanMessage:
+ """Get the human message at the end of a list of input messages to a chat model."""
+ if not messages:
+ raise ValueError("You should provide at least one message to start the chat!")
+ question = messages[-1]
+ if not isinstance(question, HumanMessage):
+ raise ValueError(
+ f"Last message in the list should be from human, got {question.type}."
+ )
+ return question
+
+
+class ChatVertexAI(_VertexAICommon, BaseChatModel):
+ """`Vertex AI` Chat large language models API."""
+
+ model_name: str = "chat-bison"
+ "Underlying model name."
+ examples: Optional[List[BaseMessage]] = None
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "chat_models", "vertexai"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ cls._try_init_vertexai(values)
+ try:
+ from vertexai.language_models import ChatModel, CodeChatModel
+ except ImportError:
+ raise_vertex_import_error()
+ if is_codey_model(values["model_name"]):
+ model_cls = CodeChatModel
+ else:
+ model_cls = ChatModel
+ values["client"] = model_cls.from_pretrained(values["model_name"])
+ return values
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Generate next turn in the conversation.
+
+ Args:
+ messages: The history of the conversation as a list of messages. Code chat
+ does not support context.
+ stop: The list of stop words (optional).
+ run_manager: The CallbackManager for the LLM run; it is not used at the moment.
+ stream: Whether to use the streaming endpoint.
+
+ Returns:
+ The ChatResult that contains outputs generated by the model.
+
+ Raises:
+ ValueError: if the last message in the list is not from human.
+ """
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ stream_iter = self._stream(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ question = _get_question(messages)
+ history = _parse_chat_history(messages[:-1])
+ params = self._prepare_params(stop=stop, stream=False, **kwargs)
+ examples = kwargs.get("examples") or self.examples
+ if examples:
+ params["examples"] = _parse_examples(examples)
+
+ msg_params = {}
+ if "candidate_count" in params:
+ msg_params["candidate_count"] = params.pop("candidate_count")
+
+ chat = self._start_chat(history, **params)
+ response = chat.send_message(question.content, **msg_params)
+ generations = [
+ ChatGeneration(message=AIMessage(content=r.text))
+ for r in response.candidates
+ ]
+ return ChatResult(generations=generations)
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Asynchronously generate next turn in the conversation.
+
+ Args:
+ messages: The history of the conversation as a list of messages. Code chat
+ does not support context.
+ stop: The list of stop words (optional).
+ run_manager: The CallbackManager for the LLM run; it is not used at the moment.
+
+ Returns:
+ The ChatResult that contains outputs generated by the model.
+
+ Raises:
+ ValueError: if the last message in the list is not from human.
+ """
+ if "stream" in kwargs:
+ kwargs.pop("stream")
+ logger.warning("ChatVertexAI does not currently support async streaming.")
+ question = _get_question(messages)
+ history = _parse_chat_history(messages[:-1])
+ params = self._prepare_params(stop=stop, **kwargs)
+ examples = kwargs.get("examples", None)
+ if examples:
+ params["examples"] = _parse_examples(examples)
+
+ msg_params = {}
+ if "candidate_count" in params:
+ msg_params["candidate_count"] = params.pop("candidate_count")
+ chat = self._start_chat(history, **params)
+ response = await chat.send_message_async(question.content, **msg_params)
+ generations = [
+ ChatGeneration(message=AIMessage(content=r.text))
+ for r in response.candidates
+ ]
+ return ChatResult(generations=generations)
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ question = _get_question(messages)
+ history = _parse_chat_history(messages[:-1])
+ params = self._prepare_params(stop=stop, stream=True, **kwargs)
+ examples = kwargs.get("examples", None)
+ if examples:
+ params["examples"] = _parse_examples(examples)
+
+ chat = self._start_chat(history, **params)
+ responses = chat.send_message_streaming(question.content, **params)
+ for response in responses:
+ if run_manager:
+ run_manager.on_llm_new_token(response.text)
+ yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
+
+ def _start_chat(
+ self, history: _ChatHistory, **kwargs: Any
+ ) -> Union[ChatSession, CodeChatSession]:
+ if not self.is_codey_model:
+ return self.client.start_chat(
+ context=history.context, message_history=history.history, **kwargs
+ )
+ else:
+ return self.client.start_chat(message_history=history.history, **kwargs)
diff --git a/libs/community/langchain_community/chat_models/volcengine_maas.py b/libs/community/langchain_community/chat_models/volcengine_maas.py
new file mode 100644
index 00000000000..d684aece066
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/volcengine_maas.py
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ FunctionMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from langchain_community.llms.volcengine_maas import VolcEngineMaasBase
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ if isinstance(message, SystemMessage):
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"role": "user", "content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, FunctionMessage):
+ message_dict = {"role": "function", "content": message.content}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_dict
+
+
+def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
+ content = _dict.get("choice", {}).get("message", {}).get("content", "")
+ return AIMessage(content=content)
+
+
+class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
+ """Volc Engine MaaS hosts a wide range of models.
+ You can utilize these models through this class.
+
+ To use, you should have the ``volcengine`` python package installed,
+ and set the access key and secret key either via environment variables or
+ by passing them directly to this class.
+ The access key and secret key are required parameters; see
+ https://www.volcengine.com/docs/6291/65568 for help obtaining them.
+
+ The two methods are as follows:
+ * Environment Variable
+ Set the environment variables 'VOLC_ACCESSKEY' and 'VOLC_SECRETKEY' with your
+ access key and secret key.
+
+ * Pass Directly to Class
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import VolcEngineMaasChat
+ model = VolcEngineMaasChat(model="skylark-lite-public",
+ volc_engine_maas_ak="your_ak",
+ volc_engine_maas_sk="your_sk")
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "volc-engine-maas-chat"
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return False
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ **{"endpoint": self.endpoint, "model": self.model},
+ **super()._identifying_params,
+ }
+
+ def _convert_prompt_msg_params(
+ self,
+ messages: List[BaseMessage],
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ model_req = {
+ "model": {
+ "name": self.model,
+ }
+ }
+ if self.model_version is not None:
+ model_req["model"]["version"] = self.model_version
+ return {
+ **model_req,
+ "messages": [_convert_message_to_dict(message) for message in messages],
+ "parameters": {**self._default_params, **kwargs},
+ }
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ for res in self.client.stream_chat(params):
+ if res:
+ msg = convert_dict_to_message(res)
+ yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
+ if run_manager:
+ run_manager.on_llm_new_token(cast(str, msg.content))
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ completion = ""
+ if self.streaming:
+ for chunk in self._stream(messages, stop, run_manager, **kwargs):
+ completion += chunk.text
+ else:
+ params = self._convert_prompt_msg_params(messages, **kwargs)
+ res = self.client.chat(params)
+ msg = convert_dict_to_message(res)
+ completion = cast(str, msg.content)
+
+ message = AIMessage(content=completion)
+ return ChatResult(generations=[ChatGeneration(message=message)])
diff --git a/libs/community/langchain_community/chat_models/yandex.py b/libs/community/langchain_community/chat_models/yandex.py
new file mode 100644
index 00000000000..f94be83a899
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/yandex.py
@@ -0,0 +1,132 @@
+"""Wrapper around YandexGPT chat models."""
+import logging
+from typing import Any, Dict, List, Optional, Tuple, cast
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ HumanMessage,
+ SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+
+from langchain_community.llms.utils import enforce_stop_tokens
+from langchain_community.llms.yandex import _BaseYandexGPT
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_message(role: str, text: str) -> Dict:
+ return {"role": role, "text": text}
+
+
+def _parse_chat_history(history: List[BaseMessage]) -> Tuple[List[Dict[str, str]], str]:
+ """Parse a sequence of messages into history.
+
+ Returns:
+ A tuple of a list of parsed messages and an instruction message for the model.
+ """
+ chat_history = []
+ instruction = ""
+ for message in history:
+ content = cast(str, message.content)
+ if isinstance(message, HumanMessage):
+ chat_history.append(_parse_message("user", content))
+ if isinstance(message, AIMessage):
+ chat_history.append(_parse_message("assistant", content))
+ if isinstance(message, SystemMessage):
+ instruction = content
+ return chat_history, instruction
+
+
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
+ """Wrapper around YandexGPT large language models.
+
+ There are two authentication options for the service account
+ with the ``ai.languageModels.user`` role:
+ - You can specify the token in a constructor parameter `iam_token`
+ or in an environment variable `YC_IAM_TOKEN`.
+ - You can specify the key in a constructor parameter `api_key`
+ or in an environment variable `YC_API_KEY`.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatYandexGPT
+ chat_model = ChatYandexGPT(iam_token="t1.9eu...")
+
+ """
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Generate next turn in the conversation.
+ Args:
+ messages: The history of the conversation as a list of messages.
+ stop: The list of stop words (optional).
+ run_manager: The CallbackManager for the LLM run; it is not used at the moment.
+
+ Returns:
+ The ChatResult that contains outputs generated by the model.
+
+ Raises:
+ ValueError: If no messages are provided.
+ """
+ try:
+ import grpc
+ from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+ from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions, Message
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import ChatRequest
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
+ TextGenerationServiceStub,
+ )
+ except ImportError as e:
+ raise ImportError(
+ "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+ ) from e
+ if not messages:
+ raise ValueError(
+ "You should provide at least one message to start the chat!"
+ )
+ message_history, instruction = _parse_chat_history(messages)
+ channel_credentials = grpc.ssl_channel_credentials()
+ channel = grpc.secure_channel(self.url, channel_credentials)
+ request = ChatRequest(
+ model=self.model_name,
+ generation_options=GenerationOptions(
+ temperature=DoubleValue(value=self.temperature),
+ max_tokens=Int64Value(value=self.max_tokens),
+ ),
+ instruction_text=instruction,
+ messages=[Message(**message) for message in message_history],
+ )
+ stub = TextGenerationServiceStub(channel)
+ if self.iam_token:
+ metadata = (("authorization", f"Bearer {self.iam_token}"),)
+ else:
+ metadata = (("authorization", f"Api-Key {self.api_key}"),)
+ res = stub.Chat(request, metadata=metadata)
+ text = list(res)[0].message.text
+ text = text if stop is None else enforce_stop_tokens(text, stop)
+ message = AIMessage(content=text)
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ raise NotImplementedError(
+ """YandexGPT doesn't support async requests at the moment."""
+ )
diff --git a/libs/community/langchain_community/docstore/__init__.py b/libs/community/langchain_community/docstore/__init__.py
new file mode 100644
index 00000000000..1de54381661
--- /dev/null
+++ b/libs/community/langchain_community/docstore/__init__.py
@@ -0,0 +1,21 @@
+"""**Docstores** are classes to store and load Documents.
+
+The **Docstore** is a simplified version of the Document Loader.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ Docstore --> # Examples: InMemoryDocstore, Wikipedia
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, AddableMixin
+"""
+from langchain_community.docstore.arbitrary_fn import DocstoreFn
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.docstore.wikipedia import Wikipedia
+
+__all__ = ["DocstoreFn", "InMemoryDocstore", "Wikipedia"]
diff --git a/libs/community/langchain_community/docstore/arbitrary_fn.py b/libs/community/langchain_community/docstore/arbitrary_fn.py
new file mode 100644
index 00000000000..6495d37b5eb
--- /dev/null
+++ b/libs/community/langchain_community/docstore/arbitrary_fn.py
@@ -0,0 +1,38 @@
+from typing import Callable, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.docstore.base import Docstore
+
+
+class DocstoreFn(Docstore):
+ """Langchain Docstore via arbitrary lookup function.
+
+ This is useful when:
+ * it's expensive to construct an InMemoryDocstore/dict
+ * you retrieve documents from remote sources
+ * you just want to reuse existing objects
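+
+ Example (a minimal sketch; the in-memory ``pages`` dict below is purely
+ illustrative):
+ .. code-block:: python
+
+ from langchain_community.docstore import DocstoreFn
+
+ pages = {"doc-1": "LangChain community integrations live here."}
+ docstore = DocstoreFn(lookup_fn=lambda doc_id: pages[doc_id])
+ doc = docstore.search("doc-1")
+ # -> Document(page_content="LangChain ...", metadata={"source": "doc-1"})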
+ """
+
+ def __init__(
+ self,
+ lookup_fn: Callable[[str], Union[Document, str]],
+ ):
+ self._lookup_fn = lookup_fn
+
+ def search(self, search: str) -> Document:
+ """Search for a document.
+
+ Args:
+ search: search string
+
+ Returns:
+ A Document built from the result of the lookup function.
+ """
+ r = self._lookup_fn(search)
+ if isinstance(r, str):
+ # NOTE: assume the search string is the source ID
+ return Document(page_content=r, metadata={"source": search})
+ elif isinstance(r, Document):
+ return r
+ raise ValueError(f"Unexpected type of document {type(r)}")
diff --git a/libs/community/langchain_community/docstore/base.py b/libs/community/langchain_community/docstore/base.py
new file mode 100644
index 00000000000..709c443f01f
--- /dev/null
+++ b/libs/community/langchain_community/docstore/base.py
@@ -0,0 +1,29 @@
+"""Interface to access to place that stores documents."""
+from abc import ABC, abstractmethod
+from typing import Dict, List, Union
+
+from langchain_core.documents import Document
+
+
+class Docstore(ABC):
+ """Interface to access to place that stores documents."""
+
+ @abstractmethod
+ def search(self, search: str) -> Union[str, Document]:
+ """Search for document.
+
+ If the document exists, return it (or a summary of it).
+ If it does not exist, return a string describing similar entries.
+ """
+
+ def delete(self, ids: List) -> None:
+ """Deleting IDs from in memory dictionary."""
+ raise NotImplementedError
+
+
+class AddableMixin(ABC):
+ """Mixin class that supports adding texts."""
+
+ @abstractmethod
+ def add(self, texts: Dict[str, Document]) -> None:
+ """Add more documents."""
diff --git a/libs/community/langchain_community/docstore/document.py b/libs/community/langchain_community/docstore/document.py
new file mode 100644
index 00000000000..88aebd27950
--- /dev/null
+++ b/libs/community/langchain_community/docstore/document.py
@@ -0,0 +1,3 @@
+from langchain_core.documents import Document
+
+__all__ = ["Document"]
diff --git a/libs/community/langchain_community/docstore/in_memory.py b/libs/community/langchain_community/docstore/in_memory.py
new file mode 100644
index 00000000000..a52204d0b8f
--- /dev/null
+++ b/libs/community/langchain_community/docstore/in_memory.py
@@ -0,0 +1,50 @@
+"""Simple in memory docstore in the form of a dict."""
+from typing import Dict, List, Optional, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.docstore.base import AddableMixin, Docstore
+
+
+class InMemoryDocstore(Docstore, AddableMixin):
+ """Simple in memory docstore in the form of a dict."""
+
+ def __init__(self, _dict: Optional[Dict[str, Document]] = None):
+ """Initialize with dict."""
+ self._dict = _dict if _dict is not None else {}
+
+ def add(self, texts: Dict[str, Document]) -> None:
+ """Add texts to in memory dictionary.
+
+ Args:
+ texts: dictionary of id -> document.
+
+ Returns:
+ None
+ """
+ overlapping = set(texts).intersection(self._dict)
+ if overlapping:
+ raise ValueError(f"Tried to add ids that already exist: {overlapping}")
+ self._dict = {**self._dict, **texts}
+
+ def delete(self, ids: List) -> None:
+ """Deleting IDs from in memory dictionary."""
+ overlapping = set(ids).intersection(self._dict)
+ if not overlapping:
+ raise ValueError(f"Tried to delete ids that does not exist: {ids}")
+ for _id in ids:
+ self._dict.pop(_id)
+
+ def search(self, search: str) -> Union[str, Document]:
+ """Search via direct lookup.
+
+ Args:
+ search: id of a document to search for.
+
+ Returns:
+ Document if found, else error message.
+ """
+ if search not in self._dict:
+ return f"ID {search} not found."
+ else:
+ return self._dict[search]
diff --git a/libs/community/langchain_community/docstore/wikipedia.py b/libs/community/langchain_community/docstore/wikipedia.py
new file mode 100644
index 00000000000..cc7b6ae20e9
--- /dev/null
+++ b/libs/community/langchain_community/docstore/wikipedia.py
@@ -0,0 +1,47 @@
+"""Wrapper around wikipedia API."""
+
+
+from typing import Union
+
+from langchain_core.documents import Document
+
+from langchain_community.docstore.base import Docstore
+
+
+class Wikipedia(Docstore):
+ """Wrapper around wikipedia API."""
+
+ def __init__(self) -> None:
+ """Check that wikipedia package is installed."""
+ try:
+ import wikipedia # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import wikipedia python package. "
+ "Please install it with `pip install wikipedia`."
+ )
+
+ def search(self, search: str) -> Union[str, Document]:
+ """Try to search for wiki page.
+
+ If the page exists, return a Document with the page content.
+ If the page does not exist, return a string listing similar entries.
+
+ Args:
+ search: search string.
+
+ Returns: a Document object or error message.
+ """
+ import wikipedia
+
+ try:
+ # Fetch the page once instead of issuing two separate API calls.
+ page = wikipedia.page(search)
+ result: Union[str, Document] = Document(
+ page_content=page.content, metadata={"page": page.url}
+ )
+ except (wikipedia.PageError, wikipedia.DisambiguationError):
+ result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
+ return result
diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py
new file mode 100644
index 00000000000..48143a30b66
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/__init__.py
@@ -0,0 +1,396 @@
+"""**Document Loaders** are classes to load Documents.
+
+**Document Loaders** are usually used to load a lot of Documents in a single run.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseLoader --> Loader # Examples: TextLoader, UnstructuredFileLoader
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, TextSplitter
+"""
+
+from langchain_community.document_loaders.acreom import AcreomLoader
+from langchain_community.document_loaders.airbyte import (
+ AirbyteCDKLoader,
+ AirbyteGongLoader,
+ AirbyteHubspotLoader,
+ AirbyteSalesforceLoader,
+ AirbyteShopifyLoader,
+ AirbyteStripeLoader,
+ AirbyteTypeformLoader,
+ AirbyteZendeskSupportLoader,
+)
+from langchain_community.document_loaders.airbyte_json import AirbyteJSONLoader
+from langchain_community.document_loaders.airtable import AirtableLoader
+from langchain_community.document_loaders.apify_dataset import ApifyDatasetLoader
+from langchain_community.document_loaders.arcgis_loader import ArcGISLoader
+from langchain_community.document_loaders.arxiv import ArxivLoader
+from langchain_community.document_loaders.assemblyai import (
+ AssemblyAIAudioTranscriptLoader,
+)
+from langchain_community.document_loaders.async_html import AsyncHtmlLoader
+from langchain_community.document_loaders.azlyrics import AZLyricsLoader
+from langchain_community.document_loaders.azure_ai_data import (
+ AzureAIDataLoader,
+)
+from langchain_community.document_loaders.azure_blob_storage_container import (
+ AzureBlobStorageContainerLoader,
+)
+from langchain_community.document_loaders.azure_blob_storage_file import (
+ AzureBlobStorageFileLoader,
+)
+from langchain_community.document_loaders.bibtex import BibtexLoader
+from langchain_community.document_loaders.bigquery import BigQueryLoader
+from langchain_community.document_loaders.bilibili import BiliBiliLoader
+from langchain_community.document_loaders.blackboard import BlackboardLoader
+from langchain_community.document_loaders.blob_loaders import (
+ Blob,
+ BlobLoader,
+ FileSystemBlobLoader,
+ YoutubeAudioLoader,
+)
+from langchain_community.document_loaders.blockchain import BlockchainDocumentLoader
+from langchain_community.document_loaders.brave_search import BraveSearchLoader
+from langchain_community.document_loaders.browserless import BrowserlessLoader
+from langchain_community.document_loaders.chatgpt import ChatGPTLoader
+from langchain_community.document_loaders.chromium import AsyncChromiumLoader
+from langchain_community.document_loaders.college_confidential import (
+ CollegeConfidentialLoader,
+)
+from langchain_community.document_loaders.concurrent import ConcurrentLoader
+from langchain_community.document_loaders.confluence import ConfluenceLoader
+from langchain_community.document_loaders.conllu import CoNLLULoader
+from langchain_community.document_loaders.couchbase import CouchbaseLoader
+from langchain_community.document_loaders.csv_loader import (
+ CSVLoader,
+ UnstructuredCSVLoader,
+)
+from langchain_community.document_loaders.cube_semantic import CubeSemanticLoader
+from langchain_community.document_loaders.datadog_logs import DatadogLogsLoader
+from langchain_community.document_loaders.dataframe import DataFrameLoader
+from langchain_community.document_loaders.diffbot import DiffbotLoader
+from langchain_community.document_loaders.directory import DirectoryLoader
+from langchain_community.document_loaders.discord import DiscordChatLoader
+from langchain_community.document_loaders.docugami import DocugamiLoader
+from langchain_community.document_loaders.docusaurus import DocusaurusLoader
+from langchain_community.document_loaders.dropbox import DropboxLoader
+from langchain_community.document_loaders.duckdb_loader import DuckDBLoader
+from langchain_community.document_loaders.email import (
+ OutlookMessageLoader,
+ UnstructuredEmailLoader,
+)
+from langchain_community.document_loaders.epub import UnstructuredEPubLoader
+from langchain_community.document_loaders.etherscan import EtherscanLoader
+from langchain_community.document_loaders.evernote import EverNoteLoader
+from langchain_community.document_loaders.excel import UnstructuredExcelLoader
+from langchain_community.document_loaders.facebook_chat import FacebookChatLoader
+from langchain_community.document_loaders.fauna import FaunaLoader
+from langchain_community.document_loaders.figma import FigmaFileLoader
+from langchain_community.document_loaders.gcs_directory import GCSDirectoryLoader
+from langchain_community.document_loaders.gcs_file import GCSFileLoader
+from langchain_community.document_loaders.geodataframe import GeoDataFrameLoader
+from langchain_community.document_loaders.git import GitLoader
+from langchain_community.document_loaders.gitbook import GitbookLoader
+from langchain_community.document_loaders.github import GitHubIssuesLoader
+from langchain_community.document_loaders.google_speech_to_text import (
+ GoogleSpeechToTextLoader,
+)
+from langchain_community.document_loaders.googledrive import GoogleDriveLoader
+from langchain_community.document_loaders.gutenberg import GutenbergLoader
+from langchain_community.document_loaders.hn import HNLoader
+from langchain_community.document_loaders.html import UnstructuredHTMLLoader
+from langchain_community.document_loaders.html_bs import BSHTMLLoader
+from langchain_community.document_loaders.hugging_face_dataset import (
+ HuggingFaceDatasetLoader,
+)
+from langchain_community.document_loaders.ifixit import IFixitLoader
+from langchain_community.document_loaders.image import UnstructuredImageLoader
+from langchain_community.document_loaders.image_captions import ImageCaptionLoader
+from langchain_community.document_loaders.imsdb import IMSDbLoader
+from langchain_community.document_loaders.iugu import IuguLoader
+from langchain_community.document_loaders.joplin import JoplinLoader
+from langchain_community.document_loaders.json_loader import JSONLoader
+from langchain_community.document_loaders.lakefs import LakeFSLoader
+from langchain_community.document_loaders.larksuite import LarkSuiteDocLoader
+from langchain_community.document_loaders.markdown import UnstructuredMarkdownLoader
+from langchain_community.document_loaders.mastodon import MastodonTootsLoader
+from langchain_community.document_loaders.max_compute import MaxComputeLoader
+from langchain_community.document_loaders.mediawikidump import MWDumpLoader
+from langchain_community.document_loaders.merge import MergedDataLoader
+from langchain_community.document_loaders.mhtml import MHTMLLoader
+from langchain_community.document_loaders.modern_treasury import ModernTreasuryLoader
+from langchain_community.document_loaders.mongodb import MongodbLoader
+from langchain_community.document_loaders.news import NewsURLLoader
+from langchain_community.document_loaders.notebook import NotebookLoader
+from langchain_community.document_loaders.notion import NotionDirectoryLoader
+from langchain_community.document_loaders.notiondb import NotionDBLoader
+from langchain_community.document_loaders.obs_directory import OBSDirectoryLoader
+from langchain_community.document_loaders.obs_file import OBSFileLoader
+from langchain_community.document_loaders.obsidian import ObsidianLoader
+from langchain_community.document_loaders.odt import UnstructuredODTLoader
+from langchain_community.document_loaders.onedrive import OneDriveLoader
+from langchain_community.document_loaders.onedrive_file import OneDriveFileLoader
+from langchain_community.document_loaders.open_city_data import OpenCityDataLoader
+from langchain_community.document_loaders.org_mode import UnstructuredOrgModeLoader
+from langchain_community.document_loaders.pdf import (
+ AmazonTextractPDFLoader,
+ MathpixPDFLoader,
+ OnlinePDFLoader,
+ PDFMinerLoader,
+ PDFMinerPDFasHTMLLoader,
+ PDFPlumberLoader,
+ PyMuPDFLoader,
+ PyPDFDirectoryLoader,
+ PyPDFium2Loader,
+ PyPDFLoader,
+ UnstructuredPDFLoader,
+)
+from langchain_community.document_loaders.polars_dataframe import PolarsDataFrameLoader
+from langchain_community.document_loaders.powerpoint import UnstructuredPowerPointLoader
+from langchain_community.document_loaders.psychic import PsychicLoader
+from langchain_community.document_loaders.pubmed import PubMedLoader
+from langchain_community.document_loaders.pyspark_dataframe import (
+ PySparkDataFrameLoader,
+)
+from langchain_community.document_loaders.python import PythonLoader
+from langchain_community.document_loaders.readthedocs import ReadTheDocsLoader
+from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
+from langchain_community.document_loaders.reddit import RedditPostsLoader
+from langchain_community.document_loaders.roam import RoamLoader
+from langchain_community.document_loaders.rocksetdb import RocksetLoader
+from langchain_community.document_loaders.rss import RSSFeedLoader
+from langchain_community.document_loaders.rst import UnstructuredRSTLoader
+from langchain_community.document_loaders.rtf import UnstructuredRTFLoader
+from langchain_community.document_loaders.s3_directory import S3DirectoryLoader
+from langchain_community.document_loaders.s3_file import S3FileLoader
+from langchain_community.document_loaders.sharepoint import SharePointLoader
+from langchain_community.document_loaders.sitemap import SitemapLoader
+from langchain_community.document_loaders.slack_directory import SlackDirectoryLoader
+from langchain_community.document_loaders.snowflake_loader import SnowflakeLoader
+from langchain_community.document_loaders.spreedly import SpreedlyLoader
+from langchain_community.document_loaders.srt import SRTLoader
+from langchain_community.document_loaders.stripe import StripeLoader
+from langchain_community.document_loaders.telegram import (
+ TelegramChatApiLoader,
+ TelegramChatFileLoader,
+)
+from langchain_community.document_loaders.tencent_cos_directory import (
+ TencentCOSDirectoryLoader,
+)
+from langchain_community.document_loaders.tencent_cos_file import TencentCOSFileLoader
+from langchain_community.document_loaders.tensorflow_datasets import (
+ TensorflowDatasetLoader,
+)
+from langchain_community.document_loaders.text import TextLoader
+from langchain_community.document_loaders.tomarkdown import ToMarkdownLoader
+from langchain_community.document_loaders.toml import TomlLoader
+from langchain_community.document_loaders.trello import TrelloLoader
+from langchain_community.document_loaders.tsv import UnstructuredTSVLoader
+from langchain_community.document_loaders.twitter import TwitterTweetLoader
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredAPIFileIOLoader,
+ UnstructuredAPIFileLoader,
+ UnstructuredFileIOLoader,
+ UnstructuredFileLoader,
+)
+from langchain_community.document_loaders.url import UnstructuredURLLoader
+from langchain_community.document_loaders.url_playwright import PlaywrightURLLoader
+from langchain_community.document_loaders.url_selenium import SeleniumURLLoader
+from langchain_community.document_loaders.weather import WeatherDataLoader
+from langchain_community.document_loaders.web_base import WebBaseLoader
+from langchain_community.document_loaders.whatsapp_chat import WhatsAppChatLoader
+from langchain_community.document_loaders.wikipedia import WikipediaLoader
+from langchain_community.document_loaders.word_document import (
+ Docx2txtLoader,
+ UnstructuredWordDocumentLoader,
+)
+from langchain_community.document_loaders.xml import UnstructuredXMLLoader
+from langchain_community.document_loaders.xorbits import XorbitsLoader
+from langchain_community.document_loaders.youtube import (
+ GoogleApiClient,
+ GoogleApiYoutubeLoader,
+ YoutubeLoader,
+)
+
+# Legacy: only for backwards compatibility. Use PyPDFLoader instead
+PagedPDFSplitter = PyPDFLoader
+
+# For backwards compatibility
+TelegramChatLoader = TelegramChatFileLoader
+
+__all__ = [
+ "AcreomLoader",
+ "AsyncHtmlLoader",
+ "AsyncChromiumLoader",
+ "AZLyricsLoader",
+ "AcreomLoader",
+ "AirbyteCDKLoader",
+ "AirbyteGongLoader",
+ "AirbyteJSONLoader",
+ "AirbyteHubspotLoader",
+ "AirbyteSalesforceLoader",
+ "AirbyteShopifyLoader",
+ "AirbyteStripeLoader",
+ "AirbyteTypeformLoader",
+ "AirbyteZendeskSupportLoader",
+ "AirtableLoader",
+ "AmazonTextractPDFLoader",
+ "ApifyDatasetLoader",
+ "ArcGISLoader",
+ "ArxivLoader",
+ "AssemblyAIAudioTranscriptLoader",
+ "AsyncHtmlLoader",
+ "AzureAIDataLoader",
+ "AzureBlobStorageContainerLoader",
+ "AzureBlobStorageFileLoader",
+ "BSHTMLLoader",
+ "BibtexLoader",
+ "BigQueryLoader",
+ "BiliBiliLoader",
+ "BlackboardLoader",
+ "Blob",
+ "BlobLoader",
+ "BlockchainDocumentLoader",
+ "BraveSearchLoader",
+ "BrowserlessLoader",
+ "CSVLoader",
+ "ChatGPTLoader",
+ "CoNLLULoader",
+ "CollegeConfidentialLoader",
+ "ConcurrentLoader",
+ "ConfluenceLoader",
+ "CouchbaseLoader",
+ "CubeSemanticLoader",
+ "DataFrameLoader",
+ "DatadogLogsLoader",
+ "DiffbotLoader",
+ "DirectoryLoader",
+ "DiscordChatLoader",
+ "DocugamiLoader",
+ "DocusaurusLoader",
+ "Docx2txtLoader",
+ "DropboxLoader",
+ "DuckDBLoader",
+ "EtherscanLoader",
+ "EverNoteLoader",
+ "FacebookChatLoader",
+ "FaunaLoader",
+ "FigmaFileLoader",
+ "FileSystemBlobLoader",
+ "GCSDirectoryLoader",
+ "GCSFileLoader",
+ "GeoDataFrameLoader",
+ "GitHubIssuesLoader",
+ "GitLoader",
+ "GitbookLoader",
+ "GoogleApiClient",
+ "GoogleApiYoutubeLoader",
+ "GoogleSpeechToTextLoader",
+ "GoogleDriveLoader",
+ "GutenbergLoader",
+ "HNLoader",
+ "HuggingFaceDatasetLoader",
+ "IFixitLoader",
+ "IMSDbLoader",
+ "ImageCaptionLoader",
+ "IuguLoader",
+ "JSONLoader",
+ "JoplinLoader",
+ "LarkSuiteDocLoader",
+ "LakeFSLoader",
+ "MHTMLLoader",
+ "MWDumpLoader",
+ "MastodonTootsLoader",
+ "MathpixPDFLoader",
+ "MaxComputeLoader",
+ "MergedDataLoader",
+ "ModernTreasuryLoader",
+ "MongodbLoader",
+ "NewsURLLoader",
+ "NotebookLoader",
+ "NotionDBLoader",
+ "NotionDirectoryLoader",
+ "OBSDirectoryLoader",
+ "OBSFileLoader",
+ "ObsidianLoader",
+ "OneDriveFileLoader",
+ "OneDriveLoader",
+ "OnlinePDFLoader",
+ "OpenCityDataLoader",
+ "OutlookMessageLoader",
+ "PDFMinerLoader",
+ "PDFMinerPDFasHTMLLoader",
+ "PDFPlumberLoader",
+ "PagedPDFSplitter",
+ "PlaywrightURLLoader",
+ "PolarsDataFrameLoader",
+ "PsychicLoader",
+ "PubMedLoader",
+ "PyMuPDFLoader",
+ "PyPDFDirectoryLoader",
+ "PyPDFLoader",
+ "PyPDFium2Loader",
+ "PySparkDataFrameLoader",
+ "PythonLoader",
+ "RSSFeedLoader",
+ "ReadTheDocsLoader",
+ "RecursiveUrlLoader",
+ "RedditPostsLoader",
+ "RoamLoader",
+ "RocksetLoader",
+ "S3DirectoryLoader",
+ "S3FileLoader",
+ "SRTLoader",
+ "SeleniumURLLoader",
+ "SharePointLoader",
+ "SitemapLoader",
+ "SlackDirectoryLoader",
+ "SnowflakeLoader",
+ "SpreedlyLoader",
+ "StripeLoader",
+ "TelegramChatApiLoader",
+ "TelegramChatFileLoader",
+ "TelegramChatLoader",
+ "TensorflowDatasetLoader",
+ "TencentCOSDirectoryLoader",
+ "TencentCOSFileLoader",
+ "TextLoader",
+ "ToMarkdownLoader",
+ "TomlLoader",
+ "TrelloLoader",
+ "TwitterTweetLoader",
+ "UnstructuredAPIFileIOLoader",
+ "UnstructuredAPIFileLoader",
+ "UnstructuredCSVLoader",
+ "UnstructuredEPubLoader",
+ "UnstructuredEmailLoader",
+ "UnstructuredExcelLoader",
+ "UnstructuredFileIOLoader",
+ "UnstructuredFileLoader",
+ "UnstructuredHTMLLoader",
+ "UnstructuredImageLoader",
+ "UnstructuredMarkdownLoader",
+ "UnstructuredODTLoader",
+ "UnstructuredOrgModeLoader",
+ "UnstructuredPDFLoader",
+ "UnstructuredPowerPointLoader",
+ "UnstructuredRSTLoader",
+ "UnstructuredRTFLoader",
+ "UnstructuredTSVLoader",
+ "UnstructuredURLLoader",
+ "UnstructuredWordDocumentLoader",
+ "UnstructuredXMLLoader",
+ "WeatherDataLoader",
+ "WebBaseLoader",
+ "WhatsAppChatLoader",
+ "WikipediaLoader",
+ "XorbitsLoader",
+ "YoutubeAudioLoader",
+ "YoutubeLoader",
+]
diff --git a/libs/community/langchain_community/document_loaders/acreom.py b/libs/community/langchain_community/document_loaders/acreom.py
new file mode 100644
index 00000000000..618883270da
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/acreom.py
@@ -0,0 +1,79 @@
+import re
+from pathlib import Path
+from typing import Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class AcreomLoader(BaseLoader):
+ """Load `acreom` vault from a directory."""
+
+ FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL)
+ """Regex to match front matter metadata in markdown files."""
+
+ def __init__(
+ self, path: str, encoding: str = "UTF-8", collect_metadata: bool = True
+ ):
+ """Initialize the loader."""
+ self.file_path = path
+ """Path to the directory containing the markdown files."""
+ self.encoding = encoding
+ """Encoding to use when reading the files."""
+ self.collect_metadata = collect_metadata
+ """Whether to collect metadata from the front matter."""
+
+ def _parse_front_matter(self, content: str) -> dict:
+ """Parse front matter metadata from the content and return it as a dict."""
+ if not self.collect_metadata:
+ return {}
+ match = self.FRONT_MATTER_REGEX.search(content)
+ front_matter = {}
+ if match:
+ lines = match.group(1).split("\n")
+ for line in lines:
+ if ":" in line:
+ key, value = line.split(":", 1)
+ front_matter[key.strip()] = value.strip()
+ else:
+ # Skip lines without a colon
+ continue
+ return front_matter
+
+ def _remove_front_matter(self, content: str) -> str:
+ """Remove front matter metadata from the given content."""
+ if not self.collect_metadata:
+ return content
+ return self.FRONT_MATTER_REGEX.sub("", content)
+
+ def _process_acreom_content(self, content: str) -> str:
+ # Remove acreom-specific elements that do not contribute to
+ # the content of the current document.
+ content = re.sub(r"\s*-\s\[\s\]\s.*|\s*\[\s\]\s.*", "", content) # rm tasks
+ content = re.sub(r"#", "", content) # rm hashtags
+ content = re.sub(r"\[\[.*?\]\]", "", content) # rm doclinks
+ return content
+
+ def lazy_load(self) -> Iterator[Document]:
+ ps = list(Path(self.file_path).glob("**/*.md"))
+
+ for p in ps:
+ with open(p, encoding=self.encoding) as f:
+ text = f.read()
+
+ front_matter = self._parse_front_matter(text)
+ text = self._remove_front_matter(text)
+
+ text = self._process_acreom_content(text)
+
+ metadata = {
+ "source": str(p.name),
+ "path": str(p),
+ **front_matter,
+ }
+
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/airbyte.py b/libs/community/langchain_community/document_loaders/airbyte.py
new file mode 100644
index 00000000000..a9609c55b64
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/airbyte.py
@@ -0,0 +1,290 @@
+from typing import Any, Callable, Iterator, List, Mapping, Optional
+
+from langchain_core.documents import Document
+from langchain_core.utils.utils import guard_import
+
+from langchain_community.document_loaders.base import BaseLoader
+
+RecordHandler = Callable[[Any, Optional[str]], Document]
+
+
+class AirbyteCDKLoader(BaseLoader):
+ """Load with an `Airbyte` source connector implemented using the `CDK`."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ source_class: Any,
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ source_class: The source connector class.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
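+
+ Example of a custom record handler (a minimal sketch; the ``title`` field
+ is hypothetical and depends on the configured stream):
+ .. code-block:: python
+
+ from langchain_core.documents import Document
+
+ def record_handler(record, id):
+ # Use one field as the page content, keep everything else as metadata.
+ return Document(page_content=record.data.get("title", ""), metadata=record.data)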
+ """
+ from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage
+ from airbyte_cdk.sources.embedded.base_integration import (
+ BaseEmbeddedIntegration,
+ )
+ from airbyte_cdk.sources.embedded.runner import CDKRunner
+
+ class CDKIntegration(BaseEmbeddedIntegration):
+ """A wrapper around the CDK integration."""
+
+ def _handle_record(
+ self, record: AirbyteRecordMessage, id: Optional[str]
+ ) -> Document:
+ if record_handler:
+ return record_handler(record, id)
+ return Document(page_content="", metadata=record.data)
+
+ self._integration = CDKIntegration(
+ config=config,
+ runner=CDKRunner(source=source_class(), name=source_class.__name__),
+ )
+ self._stream_name = stream_name
+ self._state = state
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ return self._integration._load_data(
+ stream_name=self._stream_name, state=self._state
+ )
+
+ @property
+ def last_state(self) -> Any:
+ return self._integration.last_state
+
+
+class AirbyteHubspotLoader(AirbyteCDKLoader):
+ """Load from `Hubspot` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_hubspot", pip_name="airbyte-source-hubspot"
+ ).SourceHubspot
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteStripeLoader(AirbyteCDKLoader):
+ """Load from `Stripe` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_stripe", pip_name="airbyte-source-stripe"
+ ).SourceStripe
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteTypeformLoader(AirbyteCDKLoader):
+ """Load from `Typeform` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_typeform", pip_name="airbyte-source-typeform"
+ ).SourceTypeform
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
+ """Load from `Zendesk Support` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_zendesk_support", pip_name="airbyte-source-zendesk-support"
+ ).SourceZendeskSupport
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteShopifyLoader(AirbyteCDKLoader):
+ """Load from `Shopify` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_shopify", pip_name="airbyte-source-shopify"
+ ).SourceShopify
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteSalesforceLoader(AirbyteCDKLoader):
+ """Load from `Salesforce` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_salesforce", pip_name="airbyte-source-salesforce"
+ ).SourceSalesforce
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
+
+
+class AirbyteGongLoader(AirbyteCDKLoader):
+ """Load from `Gong` using an `Airbyte` source connector."""
+
+ def __init__(
+ self,
+ config: Mapping[str, Any],
+ stream_name: str,
+ record_handler: Optional[RecordHandler] = None,
+ state: Optional[Any] = None,
+ ) -> None:
+ """Initializes the loader.
+
+ Args:
+ config: The config to pass to the source connector.
+ stream_name: The name of the stream to load.
+ record_handler: A function that takes in a record and an optional id and
+ returns a Document. If None, the record will be used as the document.
+ Defaults to None.
+ state: The state to pass to the source connector. Defaults to None.
+ """
+ source_class = guard_import(
+ "source_gong", pip_name="airbyte-source-gong"
+ ).SourceGong
+ super().__init__(
+ config=config,
+ source_class=source_class,
+ stream_name=stream_name,
+ record_handler=record_handler,
+ state=state,
+ )
diff --git a/libs/community/langchain_community/document_loaders/airbyte_json.py b/libs/community/langchain_community/document_loaders/airbyte_json.py
new file mode 100644
index 00000000000..b3a0e2fc0cc
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/airbyte_json.py
@@ -0,0 +1,24 @@
+import json
+from typing import List
+
+from langchain_core.documents import Document
+from langchain_core.utils import stringify_dict
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class AirbyteJSONLoader(BaseLoader):
+ """Load local `Airbyte` json files."""
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path. This should start with '/tmp/airbyte_local/'."""
+ self.file_path = file_path
+ """Path to the directory containing the json files."""
+
+ def load(self) -> List[Document]:
+ text = ""
+ # Use a context manager so the file handle is closed after reading.
+ with open(self.file_path, "r") as f:
+ for line in f:
+ data = json.loads(line)["_airbyte_data"]
+ text += stringify_dict(data)
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/airtable.py b/libs/community/langchain_community/document_loaders/airtable.py
new file mode 100644
index 00000000000..1ca7fb1de92
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/airtable.py
@@ -0,0 +1,40 @@
+from typing import Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class AirtableLoader(BaseLoader):
+ """Load the `Airtable` tables."""
+
+ def __init__(self, api_token: str, table_id: str, base_id: str):
+ """Initialize with API token and the IDs for table and base"""
+ self.api_token = api_token
+ """Airtable API token."""
+ self.table_id = table_id
+ """Airtable table ID."""
+ self.base_id = base_id
+ """Airtable base ID."""
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load Documents from table."""
+
+ from pyairtable import Table
+
+ table = Table(self.api_token, self.base_id, self.table_id)
+ records = table.all()
+ for record in records:
+ # Need to convert record from dict to str
+ yield Document(
+ page_content=str(record),
+ metadata={
+ "source": self.base_id + "_" + self.table_id,
+ "base_id": self.base_id,
+ "table_id": self.table_id,
+ },
+ )
+
+ def load(self) -> List[Document]:
+ """Load Documents from table."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/apify_dataset.py b/libs/community/langchain_community/document_loaders/apify_dataset.py
new file mode 100644
index 00000000000..1d7d7ce1863
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/apify_dataset.py
@@ -0,0 +1,77 @@
+from typing import Any, Callable, Dict, List
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class ApifyDatasetLoader(BaseLoader, BaseModel):
+ """Load datasets from `Apify` web scraping, crawling, and data extraction platform.
+
+ For details, see https://docs.apify.com/platform/integrations/langchain
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import ApifyDatasetLoader
+ from langchain_core.documents import Document
+
+ loader = ApifyDatasetLoader(
+ dataset_id="YOUR-DATASET-ID",
+ dataset_mapping_function=lambda dataset_item: Document(
+ page_content=dataset_item["text"], metadata={"source": dataset_item["url"]}
+ ),
+ )
+ documents = loader.load()
+ """ # noqa: E501
+
+ apify_client: Any
+ """An instance of the ApifyClient class from the apify-client Python package."""
+ dataset_id: str
+ """The ID of the dataset on the Apify platform."""
+ dataset_mapping_function: Callable[[Dict], Document]
+ """A custom function that takes a single dictionary (an Apify dataset item)
+ and converts it to an instance of the Document class."""
+
+ def __init__(
+ self, dataset_id: str, dataset_mapping_function: Callable[[Dict], Document]
+ ):
+ """Initialize the loader with an Apify dataset ID and a mapping function.
+
+ Args:
+ dataset_id (str): The ID of the dataset on the Apify platform.
+ dataset_mapping_function (Callable): A function that takes a single
+ dictionary (an Apify dataset item) and converts it to an instance
+ of the Document class.
+ """
+ super().__init__(
+ dataset_id=dataset_id, dataset_mapping_function=dataset_mapping_function
+ )
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate environment.
+
+ Args:
+ values: The values to validate.
+ """
+
+ try:
+ from apify_client import ApifyClient
+
+ values["apify_client"] = ApifyClient()
+ except ImportError:
+ raise ImportError(
+ "Could not import apify-client Python package. "
+ "Please install it with `pip install apify-client`."
+ )
+
+ return values
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ dataset_items = (
+ self.apify_client.dataset(self.dataset_id).list_items(clean=True).items
+ )
+ return list(map(self.dataset_mapping_function, dataset_items))
diff --git a/libs/community/langchain_community/document_loaders/arcgis_loader.py b/libs/community/langchain_community/document_loaders/arcgis_loader.py
new file mode 100644
index 00000000000..4958ef2ce6a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/arcgis_loader.py
@@ -0,0 +1,154 @@
+"""Document Loader for ArcGIS FeatureLayers."""
+
+from __future__ import annotations
+
+import json
+import re
+import warnings
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import arcgis
+
+_NOT_PROVIDED = "(Not Provided)"
+
+
+class ArcGISLoader(BaseLoader):
+ """Load records from an ArcGIS FeatureLayer."""
+
+ def __init__(
+ self,
+ layer: Union[str, arcgis.features.FeatureLayer],
+ gis: Optional[arcgis.gis.GIS] = None,
+ where: str = "1=1",
+ out_fields: Optional[Union[List[str], str]] = None,
+ return_geometry: bool = False,
+ result_record_count: Optional[int] = None,
+ lyr_desc: Optional[str] = None,
+ **kwargs: Any,
+ ):
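+ """Initialize the loader with a FeatureLayer or its URL.
+
+ Example (illustrative; the service URL and query values are placeholders):
+ .. code-block:: python
+
+ loader = ArcGISLoader(
+ layer="https://services.arcgis.com/<org>/arcgis/rest/services/<name>/FeatureServer/0",
+ where="POPULATION > 100000",
+ out_fields=["NAME", "POPULATION"],
+ )
+ docs = loader.load()
+ """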
+ try:
+ import arcgis
+ except ImportError as e:
+ raise ImportError(
+ "arcgis is required to use the ArcGIS Loader. "
+ "Install it with pip or conda."
+ ) from e
+
+ try:
+ from bs4 import BeautifulSoup # type: ignore
+
+ self.BEAUTIFULSOUP = BeautifulSoup
+ except ImportError:
+ warnings.warn("BeautifulSoup not found. HTML will not be parsed.")
+ self.BEAUTIFULSOUP = None
+
+ self.gis = gis or arcgis.gis.GIS()
+
+ if isinstance(layer, str):
+ self.url = layer
+ self.layer = arcgis.features.FeatureLayer(layer, gis=gis)
+ else:
+ self.url = layer.url
+ self.layer = layer
+
+ self.layer_properties = self._get_layer_properties(lyr_desc)
+
+ self.where = where
+
+ if isinstance(out_fields, str):
+ self.out_fields = out_fields
+ elif out_fields is None:
+ self.out_fields = "*"
+ else:
+ self.out_fields = ",".join(out_fields)
+
+ self.return_geometry = return_geometry
+
+ self.result_record_count = result_record_count
+ self.return_all_records = not isinstance(result_record_count, int)
+
+ query_params = dict(
+ where=self.where,
+ out_fields=self.out_fields,
+ return_geometry=self.return_geometry,
+ return_all_records=self.return_all_records,
+ result_record_count=self.result_record_count,
+ )
+ query_params.update(kwargs)
+ self.query_params = query_params
+
+ def _get_layer_properties(self, lyr_desc: Optional[str] = None) -> dict:
+ """Get the layer properties from the FeatureLayer."""
+ import arcgis
+
+ layer_number_pattern = re.compile(r"/\d+$")
+ props = self.layer.properties
+
+ if lyr_desc is None:
+ # retrieve description from the FeatureLayer if not provided
+ try:
+ if self.BEAUTIFULSOUP:
+ lyr_desc = self.BEAUTIFULSOUP(props["description"]).text
+ else:
+ lyr_desc = props["description"]
+ lyr_desc = lyr_desc or _NOT_PROVIDED
+ except KeyError:
+ lyr_desc = _NOT_PROVIDED
+ try:
+ item_id = props["serviceItemId"]
+ item = self.gis.content.get(item_id) or arcgis.features.FeatureLayer(
+ re.sub(layer_number_pattern, "", self.url),
+ )
+ try:
+ raw_desc = item.description
+ except AttributeError:
+ raw_desc = item.properties.description
+ if self.BEAUTIFULSOUP:
+ item_desc = self.BEAUTIFULSOUP(raw_desc).text
+ else:
+ item_desc = raw_desc
+ item_desc = item_desc or _NOT_PROVIDED
+ except KeyError:
+ item_desc = _NOT_PROVIDED
+ return {
+ "layer_description": lyr_desc,
+ "item_description": item_desc,
+ "layer_properties": props,
+ }
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records from FeatureLayer."""
+ query_response = self.layer.query(**self.query_params)
+ features = (feature.as_dict for feature in query_response)
+ for feature in features:
+ attributes = feature["attributes"]
+ page_content = json.dumps(attributes)
+
+ metadata = {
+ "accessed": f"{datetime.now(timezone.utc).isoformat()}Z",
+ "name": self.layer_properties["layer_properties"]["name"],
+ "url": self.url,
+ "layer_description": self.layer_properties["layer_description"],
+ "item_description": self.layer_properties["item_description"],
+ "layer_properties": self.layer_properties["layer_properties"],
+ }
+
+ if self.return_geometry:
+ try:
+ metadata["geometry"] = feature["geometry"]
+ except KeyError:
+ warnings.warn(
+ "Geometry could not be retrieved from the feature layer."
+ )
+
+ yield Document(page_content=page_content, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load all records from FeatureLayer."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/arxiv.py b/libs/community/langchain_community/document_loaders/arxiv.py
new file mode 100644
index 00000000000..968d5dcfc34
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/arxiv.py
@@ -0,0 +1,27 @@
+from typing import Any, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.arxiv import ArxivAPIWrapper
+
+
+class ArxivLoader(BaseLoader):
+ """Load a query result from `Arxiv`.
+
+ The loader converts the original PDF format into text.
+
+ Supports all arguments of `ArxivAPIWrapper`.
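+
+ Example (illustrative; the query is a placeholder and ``load_max_docs``
+ is assumed to be an ``ArxivAPIWrapper`` argument):
+ .. code-block:: python
+
+ loader = ArxivLoader(query="2305.10601", load_max_docs=1)
+ docs = loader.load()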
+ """
+
+ def __init__(
+ self, query: str, doc_content_chars_max: Optional[int] = None, **kwargs: Any
+ ):
+ self.query = query
+ self.client = ArxivAPIWrapper(
+ doc_content_chars_max=doc_content_chars_max, **kwargs
+ )
+
+ def load(self) -> List[Document]:
+ return self.client.load(self.query)
diff --git a/libs/community/langchain_community/document_loaders/assemblyai.py b/libs/community/langchain_community/document_loaders/assemblyai.py
new file mode 100644
index 00000000000..0dd64256ab2
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/assemblyai.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import assemblyai
+
+
+class TranscriptFormat(Enum):
+ """Transcript format to use for the document loader."""
+
+ TEXT = "text"
+ """One document with the transcription text"""
+ SENTENCES = "sentences"
+ """Multiple documents, splits the transcription by each sentence"""
+ PARAGRAPHS = "paragraphs"
+ """Multiple documents, splits the transcription by each paragraph"""
+ SUBTITLES_SRT = "subtitles_srt"
+ """One document with the transcript exported in SRT subtitles format"""
+ SUBTITLES_VTT = "subtitles_vtt"
+ """One document with the transcript exported in VTT subtitles format"""
+
+
+class AssemblyAIAudioTranscriptLoader(BaseLoader):
+ """
+ Loader for AssemblyAI audio transcripts.
+
+ It uses the AssemblyAI API to transcribe audio files
+ and loads the transcribed text into one or more Documents,
+ depending on the specified format.
+
+ To use, you should have the ``assemblyai`` python package installed, and the
+ environment variable ``ASSEMBLYAI_API_KEY`` set with your API key.
+ Alternatively, the API key can also be passed as an argument.
+
+ Audio files can be specified via a URL or a local file path.
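+
+ Example (illustrative; the audio URL and API key are placeholders):
+ .. code-block:: python
+
+ loader = AssemblyAIAudioTranscriptLoader(
+ file_path="https://example.org/audio.mp3",
+ transcript_format=TranscriptFormat.SENTENCES,
+ api_key="<assemblyai-api-key>",
+ )
+ docs = loader.load()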
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ transcript_format: TranscriptFormat = TranscriptFormat.TEXT,
+ config: Optional[assemblyai.TranscriptionConfig] = None,
+ api_key: Optional[str] = None,
+ ):
+ """
+ Initializes the AssemblyAI AudioTranscriptLoader.
+
+ Args:
+ file_path: A URL or a local file path.
+ transcript_format: Transcript format to use.
+ See class ``TranscriptFormat`` for more info.
+ config: Transcription options and features. If ``None`` is given,
+ the Transcriber's default configuration will be used.
+ api_key: AssemblyAI API key.
+ """
+ try:
+ import assemblyai
+ except ImportError:
+ raise ImportError(
+ "Could not import assemblyai python package. "
+ "Please install it with `pip install assemblyai`."
+ )
+ if api_key is not None:
+ assemblyai.settings.api_key = api_key
+
+ self.file_path = file_path
+ self.transcript_format = transcript_format
+ self.transcriber = assemblyai.Transcriber(config=config)
+
+ def load(self) -> List[Document]:
+ """Transcribes the audio file and loads the transcript into documents.
+
+ It uses the AssemblyAI API to transcribe the audio file and blocks until
+ the transcription is finished.
+ """
+ transcript = self.transcriber.transcribe(self.file_path)
+ # This will raise a ValueError if no API key is set.
+
+ if transcript.error:
+ raise ValueError(f"Could not transcribe file: {transcript.error}")
+
+ if self.transcript_format == TranscriptFormat.TEXT:
+ return [
+ Document(
+ page_content=transcript.text, metadata=transcript.json_response
+ )
+ ]
+ elif self.transcript_format == TranscriptFormat.SENTENCES:
+ sentences = transcript.get_sentences()
+ return [
+ Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
+ for s in sentences
+ ]
+ elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
+ paragraphs = transcript.get_paragraphs()
+ return [
+ Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
+ for p in paragraphs
+ ]
+ elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
+ return [Document(page_content=transcript.export_subtitles_srt())]
+ elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
+ return [Document(page_content=transcript.export_subtitles_vtt())]
+ else:
+ raise ValueError("Unknown transcript format.")
diff --git a/libs/community/langchain_community/document_loaders/async_html.py b/libs/community/langchain_community/document_loaders/async_html.py
new file mode 100644
index 00000000000..9d23a12abc0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/async_html.py
@@ -0,0 +1,222 @@
+import asyncio
+import logging
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Iterator, List, Optional, Union, cast
+
+import aiohttp
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+default_header_template = {
+ "User-Agent": "",
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*"
+ ";q=0.8",
+ "Accept-Language": "en-US,en;q=0.5",
+ "Referer": "https://www.google.com/",
+ "DNT": "1",
+ "Connection": "keep-alive",
+ "Upgrade-Insecure-Requests": "1",
+}
+
+
+def _build_metadata(soup: Any, url: str) -> dict:
+ """Build metadata from BeautifulSoup output."""
+ metadata = {"source": url}
+ if title := soup.find("title"):
+ metadata["title"] = title.get_text()
+ if description := soup.find("meta", attrs={"name": "description"}):
+ metadata["description"] = description.get("content", "No description found.")
+ if html := soup.find("html"):
+ metadata["language"] = html.get("lang", "No language found.")
+ return metadata
+
+
+class AsyncHtmlLoader(BaseLoader):
+ """Load `HTML` asynchronously."""
+
+ def __init__(
+ self,
+ web_path: Union[str, List[str]],
+ header_template: Optional[dict] = None,
+ verify_ssl: Optional[bool] = True,
+ proxies: Optional[dict] = None,
+ autoset_encoding: bool = True,
+ encoding: Optional[str] = None,
+ default_parser: str = "html.parser",
+ requests_per_second: int = 2,
+ requests_kwargs: Optional[Dict[str, Any]] = None,
+ raise_for_status: bool = False,
+ ignore_load_errors: bool = False,
+ ):
+ """Initialize with a webpage path."""
+
+ # TODO: Deprecate web_path in favor of web_paths, and remove this
+ # left like this because there are a number of loaders that expect single
+ # urls
+ if isinstance(web_path, str):
+ self.web_paths = [web_path]
+ elif isinstance(web_path, List):
+ self.web_paths = web_path
+
+ headers = header_template or default_header_template
+ if not headers.get("User-Agent"):
+ try:
+ from fake_useragent import UserAgent
+
+ headers["User-Agent"] = UserAgent().random
+ except ImportError:
+ logger.info(
+ "fake_useragent not found, using default user agent."
+ "To get a realistic header for requests, "
+ "`pip install fake_useragent`."
+ )
+
+ self.session = requests.Session()
+ self.session.headers = dict(headers)
+ self.session.verify = verify_ssl
+
+ if proxies:
+ self.session.proxies.update(proxies)
+
+ self.requests_per_second = requests_per_second
+ self.default_parser = default_parser
+ self.requests_kwargs = requests_kwargs or {}
+ self.raise_for_status = raise_for_status
+ self.autoset_encoding = autoset_encoding
+ self.encoding = encoding
+ self.ignore_load_errors = ignore_load_errors
+
+ def _fetch_valid_connection_docs(self, url: str) -> Any:
+ if self.ignore_load_errors:
+ try:
+ return self.session.get(url, **self.requests_kwargs)
+ except Exception as e:
+ warnings.warn(str(e))
+ return None
+
+ return self.session.get(url, **self.requests_kwargs)
+
+ @staticmethod
+ def _check_parser(parser: str) -> None:
+ """Check that parser is valid for bs4."""
+ valid_parsers = ["html.parser", "lxml", "xml", "lxml-xml", "html5lib"]
+ if parser not in valid_parsers:
+ raise ValueError(
+ "`parser` must be one of " + ", ".join(valid_parsers) + "."
+ )
+
+ def _scrape(
+ self,
+ url: str,
+ parser: Union[str, None] = None,
+ bs_kwargs: Optional[dict] = None,
+ ) -> Any:
+ from bs4 import BeautifulSoup
+
+ if parser is None:
+ if url.endswith(".xml"):
+ parser = "xml"
+ else:
+ parser = self.default_parser
+
+ self._check_parser(parser)
+
+ html_doc = self._fetch_valid_connection_docs(url)
+ if not getattr(html_doc, "ok", False):
+ return None
+
+ if self.raise_for_status:
+ html_doc.raise_for_status()
+
+ if self.encoding is not None:
+ html_doc.encoding = self.encoding
+ elif self.autoset_encoding:
+ html_doc.encoding = html_doc.apparent_encoding
+ return BeautifulSoup(html_doc.text, parser, **(bs_kwargs or {}))
+
+ async def _fetch(
+ self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+ ) -> str:
+ async with aiohttp.ClientSession() as session:
+ for i in range(retries):
+ try:
+ async with session.get(
+ url,
+ headers=self.session.headers,
+ ssl=None if self.session.verify else False,
+ ) as response:
+ try:
+ text = await response.text()
+ except UnicodeDecodeError:
+ logger.error(f"Failed to decode content from {url}")
+ text = ""
+ return text
+ except aiohttp.ClientConnectionError as e:
+ if i == retries - 1 and self.ignore_load_errors:
+ logger.warning(f"Error fetching {url} after {retries} retries.")
+ return ""
+ elif i == retries - 1:
+ raise
+ else:
+ logger.warning(
+ f"Error fetching {url} with attempt "
+ f"{i + 1}/{retries}: {e}. Retrying..."
+ )
+ await asyncio.sleep(cooldown * backoff**i)
+ raise ValueError("retry count exceeded")
+
+ async def _fetch_with_rate_limit(
+ self, url: str, semaphore: asyncio.Semaphore
+ ) -> str:
+ async with semaphore:
+ return await self._fetch(url)
+
+ async def fetch_all(self, urls: List[str]) -> Any:
+ """Fetch all urls concurrently with rate limiting."""
+ semaphore = asyncio.Semaphore(self.requests_per_second)
+ tasks = []
+ for url in urls:
+ task = asyncio.ensure_future(self._fetch_with_rate_limit(url, semaphore))
+ tasks.append(task)
+ try:
+ from tqdm.asyncio import tqdm_asyncio
+
+ return await tqdm_asyncio.gather(
+ *tasks, desc="Fetching pages", ascii=True, mininterval=1
+ )
+ except ImportError:
+ warnings.warn("For better logging of progress, `pip install tqdm`")
+ return await asyncio.gather(*tasks)
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load text from the url(s) in web_path."""
+ for doc in self.load():
+ yield doc
+
+ def load(self) -> List[Document]:
+ """Load text from the url(s) in web_path."""
+
+ try:
+ # Raises RuntimeError if there is no current event loop.
+ asyncio.get_running_loop()
+ # If there is a current event loop, we need to run the async code
+ # in a separate loop, in a separate thread.
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(asyncio.run, self.fetch_all(self.web_paths))
+ results = future.result()
+ except RuntimeError:
+ results = asyncio.run(self.fetch_all(self.web_paths))
+ docs = []
+ for i, text in enumerate(cast(List[str], results)):
+ soup = self._scrape(self.web_paths[i])
+ if not soup:
+ continue
+ metadata = _build_metadata(soup, self.web_paths[i])
+ docs.append(Document(page_content=text, metadata=metadata))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/azlyrics.py b/libs/community/langchain_community/document_loaders/azlyrics.py
new file mode 100644
index 00000000000..b763c8fb4b3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/azlyrics.py
@@ -0,0 +1,18 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class AZLyricsLoader(WebBaseLoader):
+ """Load `AZLyrics` webpages."""
+
+ def load(self) -> List[Document]:
+ """Load webpages into Documents."""
+ soup = self.scrape()
+ title = soup.title.text
+ lyrics = soup.find_all("div", {"class": ""})[2].text
+ text = title + lyrics
+ metadata = {"source": self.web_path}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/azure_ai_data.py b/libs/community/langchain_community/document_loaders/azure_ai_data.py
new file mode 100644
index 00000000000..77ac93d3e07
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/azure_ai_data.py
@@ -0,0 +1,43 @@
+from typing import Iterator, List, Optional
+
+from langchain_community.docstore.document import Document
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileIOLoader
+
+
+class AzureAIDataLoader(BaseLoader):
+ """Load from Azure AI Data."""
+
+ def __init__(self, url: str, glob: Optional[str] = None):
+ """Initialize with URL to a data asset or storage location
+ ."""
+ self.url = url
+ """URL to the data asset or storage location."""
+ self.glob_pattern = glob
+ """Optional glob pattern to select files. Defaults to None."""
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """A lazy loader for Documents."""
+ try:
+ from azureml.fsspec import AzureMachineLearningFileSystem
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import azureml-fspec package."
+ "Please install it with `pip install azureml-fsspec`."
+ ) from exc
+
+ fs = AzureMachineLearningFileSystem(self.url)
+
+ if self.glob_pattern:
+ remote_paths_list = fs.glob(self.glob_pattern)
+ else:
+ remote_paths_list = fs.ls()
+
+ for remote_path in remote_paths_list:
+ with fs.open(remote_path) as f:
+ loader = UnstructuredFileIOLoader(file=f)
+ yield from loader.load()
diff --git a/libs/community/langchain_community/document_loaders/azure_blob_storage_container.py b/libs/community/langchain_community/document_loaders/azure_blob_storage_container.py
new file mode 100644
index 00000000000..4f3bc78b442
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/azure_blob_storage_container.py
@@ -0,0 +1,45 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.azure_blob_storage_file import (
+ AzureBlobStorageFileLoader,
+)
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class AzureBlobStorageContainerLoader(BaseLoader):
+ """Load from `Azure Blob Storage` container."""
+
+ def __init__(self, conn_str: str, container: str, prefix: str = ""):
+ """Initialize with connection string, container and blob prefix."""
+ self.conn_str = conn_str
+ """Connection string for Azure Blob Storage."""
+ self.container = container
+ """Container name."""
+ self.prefix = prefix
+ """Prefix for blob names."""
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ from azure.storage.blob import ContainerClient
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import azure storage blob python package. "
+ "Please install it with `pip install azure-storage-blob`."
+ ) from exc
+
+ container = ContainerClient.from_connection_string(
+ conn_str=self.conn_str, container_name=self.container
+ )
+ docs = []
+ blob_list = container.list_blobs(name_starts_with=self.prefix)
+ for blob in blob_list:
+ loader = AzureBlobStorageFileLoader(
+ self.conn_str,
+ self.container,
+ blob.name, # type: ignore
+ )
+ docs.extend(loader.load())
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/azure_blob_storage_file.py b/libs/community/langchain_community/document_loaders/azure_blob_storage_file.py
new file mode 100644
index 00000000000..9151f75a006
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/azure_blob_storage_file.py
@@ -0,0 +1,44 @@
+import os
+import tempfile
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class AzureBlobStorageFileLoader(BaseLoader):
+ """Load from `Azure Blob Storage` files."""
+
+ def __init__(self, conn_str: str, container: str, blob_name: str):
+ """Initialize with connection string, container and blob name."""
+ self.conn_str = conn_str
+ """Connection string for Azure Blob Storage."""
+ self.container = container
+ """Container name."""
+ self.blob = blob_name
+ """Blob name."""
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ from azure.storage.blob import BlobClient
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import azure storage blob python package. "
+ "Please install it with `pip install azure-storage-blob`."
+ ) from exc
+
+ client = BlobClient.from_connection_string(
+ conn_str=self.conn_str, container_name=self.container, blob_name=self.blob
+ )
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.container}/{self.blob}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ with open(f"{file_path}", "wb") as file:
+ blob_data = client.download_blob()
+ blob_data.readinto(file)
+ loader = UnstructuredFileLoader(file_path)
+ return loader.load()
diff --git a/libs/community/langchain_community/document_loaders/baiducloud_bos_directory.py b/libs/community/langchain_community/document_loaders/baiducloud_bos_directory.py
new file mode 100644
index 00000000000..05e41fbaaa8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/baiducloud_bos_directory.py
@@ -0,0 +1,55 @@
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BaiduBOSDirectoryLoader(BaseLoader):
+ """Load from `Baidu BOS directory`."""
+
+ def __init__(self, conf: Any, bucket: str, prefix: str = ""):
+ """Initialize with BOS config, bucket and prefix.
+ :param conf(BosConfig): BOS config.
+ :param bucket(str): BOS bucket.
+ :param prefix(str): prefix.
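+
+ Example (illustrative; the endpoint, keys, and bucket are placeholders and
+ assume the standard bce-python-sdk configuration classes):
+ .. code-block:: python
+
+ from baidubce.auth.bce_credentials import BceCredentials
+ from baidubce.bce_client_configuration import BceClientConfiguration
+
+ conf = BceClientConfiguration(
+ credentials=BceCredentials("<access-key>", "<secret-key>"),
+ endpoint="https://bj.bcebos.com",
+ )
+ loader = BaiduBOSDirectoryLoader(conf=conf, bucket="my-bucket", prefix="docs/")
+ docs = loader.load()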
+ """
+ self.conf = conf
+ self.bucket = bucket
+ self.prefix = prefix
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents."""
+ try:
+ from baidubce.services.bos.bos_client import BosClient
+ except ImportError:
+ raise ImportError(
+ "Please install bce-python-sdk with `pip install bce-python-sdk`."
+ )
+ client = BosClient(self.conf)
+ contents = []
+ marker = ""
+ while True:
+ response = client.list_objects(
+ bucket_name=self.bucket,
+ prefix=self.prefix,
+ marker=marker,
+ max_keys=1000,
+ )
+ contents_len = len(response.contents)
+ contents.extend(response.contents)
+ if response.is_truncated or contents_len < int(str(response.max_keys)):
+ break
+ marker = response.next_marker
+ from langchain_community.document_loaders.baiducloud_bos_file import (
+ BaiduBOSFileLoader,
+ )
+
+ for content in contents:
+ if str(content.key).endswith("/"):
+ continue
+ loader = BaiduBOSFileLoader(self.conf, self.bucket, str(content.key))
+ yield loader.load()[0]
diff --git a/libs/community/langchain_community/document_loaders/baiducloud_bos_file.py b/libs/community/langchain_community/document_loaders/baiducloud_bos_file.py
new file mode 100644
index 00000000000..4f853ba3ddf
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/baiducloud_bos_file.py
@@ -0,0 +1,54 @@
+import logging
+import os
+import tempfile
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+logger = logging.getLogger(__name__)
+
+
+class BaiduBOSFileLoader(BaseLoader):
+ """Load from `Baidu Cloud BOS` file."""
+
+ def __init__(self, conf: Any, bucket: str, key: str):
+ """Initialize with BOS config, bucket and key name.
+ :param conf(BceClientConfiguration): BOS config.
+ :param bucket(str): BOS bucket.
+ :param key(str): BOS file key.
+ """
+ self.conf = conf
+ self.bucket = bucket
+ self.key = key
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents."""
+ try:
+ from baidubce.services.bos.bos_client import BosClient
+ except ImportError:
+ raise ImportError(
+ "Please using `pip install bce-python-sdk`"
+ + " before import bos related package."
+ )
+
+ # Initialize BOS Client
+ client = BosClient(self.conf)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.bucket}/{self.key}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ # Download the file to a destination
+ logger.debug(f"get object key {self.key} to file {file_path}")
+ client.get_object_to_file(self.bucket, self.key, file_path)
+ try:
+ loader = UnstructuredFileLoader(file_path)
+ documents = loader.load()
+ return iter(documents)
+ except Exception as ex:
+ logger.error(f"load document error = {ex}")
+ return iter([Document(page_content="")])
diff --git a/libs/community/langchain_community/document_loaders/base.py b/libs/community/langchain_community/document_loaders/base.py
new file mode 100644
index 00000000000..8474fa0a579
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/base.py
@@ -0,0 +1,102 @@
+"""Abstract interface for document loader implementations."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.blob_loaders import Blob
+
+if TYPE_CHECKING:
+ from langchain.text_splitter import TextSplitter
+
+
+class BaseLoader(ABC):
+ """Interface for Document Loader.
+
+ Implementations should implement the lazy-loading method using generators
+ to avoid loading all Documents into memory at once.
+
+ The `load` method will remain as is for backwards compatibility, but its
+ implementation should be just `list(self.lazy_load())`.
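+
+ Example (a minimal custom loader sketch; the class and its data are
+ illustrative only):
+ .. code-block:: python
+
+ from typing import Iterator, List
+
+ from langchain_core.documents import Document
+
+ class InMemoryLoader(BaseLoader):
+ def __init__(self, texts: List[str]) -> None:
+ self.texts = texts
+
+ def lazy_load(self) -> Iterator[Document]:
+ for text in self.texts:
+ yield Document(page_content=text)
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())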
+ """
+
+ # Sub-classes should implement this method
+ # as return list(self.lazy_load()).
+ # This method returns a List which is materialized in memory.
+ @abstractmethod
+ def load(self) -> List[Document]:
+ """Load data into Document objects."""
+
+ def load_and_split(
+ self, text_splitter: Optional[TextSplitter] = None
+ ) -> List[Document]:
+ """Load Documents and split into chunks. Chunks are returned as Documents.
+
+ Args:
+ text_splitter: TextSplitter instance to use for splitting documents.
+ Defaults to RecursiveCharacterTextSplitter.
+
+ Returns:
+ List of Documents.
+ """
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ if text_splitter is None:
+ _text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
+ else:
+ _text_splitter = text_splitter
+ docs = self.load()
+ return _text_splitter.split_documents(docs)
+
+ # Attention: This method will be upgraded into an abstractmethod once it's
+ # implemented in all the existing subclasses.
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """A lazy loader for Documents."""
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not implement lazy_load()"
+ )
+
+
+class BaseBlobParser(ABC):
+ """Abstract interface for blob parsers.
+
+ A blob parser provides a way to parse raw data stored in a blob into one
+ or more documents.
+
+ The parser can be composed with blob loaders, making it easy to reuse
+ a parser independent of how the blob was originally loaded.
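+
+ Example (a minimal parser sketch; it assumes ``Blob.as_string()`` returns
+ the blob contents as text):
+ .. code-block:: python
+
+ class LineBlobParser(BaseBlobParser):
+ # Hypothetical parser: emits one Document per non-empty line.
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ for line in blob.as_string().splitlines():
+ if line.strip():
+ yield Document(
+ page_content=line,
+ metadata={"source": blob.source},
+ )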
+ """
+
+ @abstractmethod
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazy parsing interface.
+
+ Subclasses are required to implement this method.
+
+ Args:
+ blob: Blob instance
+
+ Returns:
+ Generator of documents
+ """
+
+ def parse(self, blob: Blob) -> List[Document]:
+ """Eagerly parse the blob into a document or documents.
+
+ This is a convenience method for interactive development environment.
+
+ Production applications should favor the lazy_parse method instead.
+
+ Subclasses should generally not override this parse method.
+
+ Args:
+ blob: Blob instance
+
+ Returns:
+ List of documents
+ """
+ return list(self.lazy_parse(blob))
diff --git a/libs/community/langchain_community/document_loaders/base_o365.py b/libs/community/langchain_community/document_loaders/base_o365.py
new file mode 100644
index 00000000000..1400f36d8d3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/base_o365.py
@@ -0,0 +1,194 @@
+"""Base class for all loaders that uses O365 Package"""
+from __future__ import annotations
+
+import logging
+import os
+import tempfile
+from abc import abstractmethod
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, Iterable, List, Sequence, Union
+
+from langchain_core.pydantic_v1 import (
+ BaseModel,
+ BaseSettings,
+ Field,
+ FilePath,
+ SecretStr,
+)
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.blob_loaders.file_system import (
+ FileSystemBlobLoader,
+)
+from langchain_community.document_loaders.blob_loaders.schema import Blob
+
+if TYPE_CHECKING:
+ from O365 import Account
+ from O365.drive import Drive, Folder
+
+logger = logging.getLogger(__name__)
+
+CHUNK_SIZE = 1024 * 1024 * 5
+
+
+class _O365Settings(BaseSettings):
+ client_id: str = Field(..., env="O365_CLIENT_ID")
+ client_secret: SecretStr = Field(..., env="O365_CLIENT_SECRET")
+
+ class Config:
+ env_prefix = ""
+ case_sensitive = False
+ env_file = ".env"
+
+
+class _O365TokenStorage(BaseSettings):
+ token_path: FilePath = Path.home() / ".credentials" / "o365_token.txt"
+
+
+class _FileType(str, Enum):
+ DOC = "doc"
+ DOCX = "docx"
+ PDF = "pdf"
+
+
+def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
+ """Fetch the mime types for the specified file types."""
+ mime_types_mapping = {}
+ for file_type in file_types:
+ if file_type.value == "doc":
+ mime_types_mapping[file_type.value] = "application/msword"
+ elif file_type.value == "docx":
+ mime_types_mapping[
+ file_type.value
+ ] = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" # noqa: E501
+ elif file_type.value == "pdf":
+ mime_types_mapping[file_type.value] = "application/pdf"
+ return mime_types_mapping
+
+
+class O365BaseLoader(BaseLoader, BaseModel):
+ """Base class for all loaders that uses O365 Package"""
+
+ settings: _O365Settings = Field(default_factory=_O365Settings)
+ """Settings for the Office365 API client."""
+ auth_with_token: bool = False
+ """Whether to authenticate with a token or not. Defaults to False."""
+ chunk_size: Union[int, str] = CHUNK_SIZE
+ """Number of bytes to retrieve from each api call to the server. int or 'auto'."""
+
+ @property
+ @abstractmethod
+ def _file_types(self) -> Sequence[_FileType]:
+ """Return supported file types."""
+
+ @property
+ def _fetch_mime_types(self) -> Dict[str, str]:
+ """Return a dict of supported file types to corresponding mime types."""
+ return fetch_mime_types(self._file_types)
+
+ @property
+ @abstractmethod
+ def _scopes(self) -> List[str]:
+ """Return required scopes."""
+
+ def _load_from_folder(self, folder: Folder) -> Iterable[Blob]:
+ """Lazily load all files from a specified folder of the configured MIME type.
+
+ Args:
+ folder: The Folder instance from which the files are to be loaded. This
+ Folder instance should represent a directory in a file system where the
+ files are stored.
+
+ Yields:
+ An iterator that yields Blob instances, which are binary representations of
+ the files loaded from the folder.
+ """
+ file_mime_types = self._fetch_mime_types
+ items = folder.get_items()
+ with tempfile.TemporaryDirectory() as temp_dir:
+ os.makedirs(os.path.dirname(temp_dir), exist_ok=True)
+ for file in items:
+ if file.is_file:
+ if file.mime_type in list(file_mime_types.values()):
+ file.download(to_path=temp_dir, chunk_size=self.chunk_size)
+ loader = FileSystemBlobLoader(path=temp_dir)
+ yield from loader.yield_blobs()
+
+ def _load_from_object_ids(
+ self, drive: Drive, object_ids: List[str]
+ ) -> Iterable[Blob]:
+ """Lazily load files specified by their object_ids from a drive.
+
+ Load files into the system as binary large objects (Blobs) and return Iterable.
+
+ Args:
+ drive: The Drive instance from which the files are to be loaded. This Drive
+ instance should represent a cloud storage service or similar storage
+ system where the files are stored.
+ object_ids: A list of object_id strings. Each object_id represents a unique
+ identifier for a file in the drive.
+
+ Yields:
+ An iterator that yields Blob instances, which are binary representations of
+ the files loaded from the drive using the specified object_ids.
+ """
+ file_mime_types = self._fetch_mime_types
+ with tempfile.TemporaryDirectory() as temp_dir:
+ for object_id in object_ids:
+ file = drive.get_item(object_id)
+ if not file:
+ logging.warning(
+ "There isn't a file with"
+ f"object_id {object_id} in drive {drive}."
+ )
+ continue
+ if file.is_file:
+ if file.mime_type in list(file_mime_types.values()):
+ file.download(to_path=temp_dir, chunk_size=self.chunk_size)
+ loader = FileSystemBlobLoader(path=temp_dir)
+ yield from loader.yield_blobs()
+
+ def _auth(self) -> Account:
+ """Authenticates the OneDrive API client
+
+ Returns:
+ The authenticated Account object.
+ """
+ try:
+ from O365 import Account, FileSystemTokenBackend
+ except ImportError:
+ raise ImportError(
+ "O365 package not found, please install it with `pip install o365`"
+ )
+ if self.auth_with_token:
+ token_storage = _O365TokenStorage()
+ token_path = token_storage.token_path
+ token_backend = FileSystemTokenBackend(
+ token_path=token_path.parent, token_filename=token_path.name
+ )
+ account = Account(
+ credentials=(
+ self.settings.client_id,
+ self.settings.client_secret.get_secret_value(),
+ ),
+ scopes=self._scopes,
+ token_backend=token_backend,
+ **{"raise_http_errors": False},
+ )
+ else:
+ token_backend = FileSystemTokenBackend(
+ token_path=Path.home() / ".credentials"
+ )
+ account = Account(
+ credentials=(
+ self.settings.client_id,
+ self.settings.client_secret.get_secret_value(),
+ ),
+ scopes=self._scopes,
+ token_backend=token_backend,
+ **{"raise_http_errors": False},
+ )
+ # make the auth
+ account.authenticate()
+ return account
diff --git a/libs/community/langchain_community/document_loaders/bibtex.py b/libs/community/langchain_community/document_loaders/bibtex.py
new file mode 100644
index 00000000000..1cae90b8d83
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/bibtex.py
@@ -0,0 +1,111 @@
+import logging
+import re
+from pathlib import Path
+from typing import Any, Iterator, List, Mapping, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.bibtex import BibtexparserWrapper
+
+logger = logging.getLogger(__name__)
+
+
+class BibtexLoader(BaseLoader):
+ """Load a `bibtex` file.
+
+ Each document represents one entry from the bibtex file.
+
+ If a PDF file is present in the `file` bibtex field, the original PDF
+ is loaded into the document text. If no such file entry is present,
+ the `abstract` field is used instead.
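+
+ Example (illustrative; the path is a placeholder):
+ .. code-block:: python
+
+ loader = BibtexLoader("./references.bib")
+ docs = loader.load()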
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ parser: Optional[BibtexparserWrapper] = None,
+ max_docs: Optional[int] = None,
+ max_content_chars: Optional[int] = 4_000,
+ load_extra_metadata: bool = False,
+ file_pattern: str = r"[^:]+\.pdf",
+ ):
+ """Initialize the BibtexLoader.
+
+ Args:
+ file_path: Path to the bibtex file.
+ parser: The parser to use. If None, a default parser is used.
+ max_docs: Max number of associated documents to load. Use -1 for
+ no limit.
+ max_content_chars: Maximum number of characters to load from the PDF.
+ load_extra_metadata: Whether to load extra metadata from the PDF.
+ file_pattern: Regex pattern to match the file name in the bibtex.
+ """
+ self.file_path = file_path
+ self.parser = parser or BibtexparserWrapper()
+ self.max_docs = max_docs
+ self.max_content_chars = max_content_chars
+ self.load_extra_metadata = load_extra_metadata
+ self.file_regex = re.compile(file_pattern)
+
+ def _load_entry(self, entry: Mapping[str, Any]) -> Optional[Document]:
+ import fitz
+
+ parent_dir = Path(self.file_path).parent
+ # regex is useful for Zotero flavor bibtex files
+ file_names = self.file_regex.findall(entry.get("file", ""))
+ if not file_names:
+ return None
+ texts: List[str] = []
+ for file_name in file_names:
+ try:
+ with fitz.open(parent_dir / file_name) as f:
+ texts.extend(page.get_text() for page in f)
+ except FileNotFoundError as e:
+ logger.debug(e)
+ content = "\n".join(texts) or entry.get("abstract", "")
+ if self.max_content_chars:
+ content = content[: self.max_content_chars]
+ metadata = self.parser.get_metadata(entry, load_extra=self.load_extra_metadata)
+ return Document(
+ page_content=content,
+ metadata=metadata,
+ )
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load bibtex file using bibtexparser and get the article texts plus the
+ article metadata.
+ See https://bibtexparser.readthedocs.io/en/master/
+
+ Returns:
+ an iterator of documents with the document.page_content in text format
+ """
+ try:
+ import fitz # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "PyMuPDF package not found, please install it with "
+ "`pip install pymupdf`"
+ )
+
+ entries = self.parser.load_bibtex_entries(self.file_path)
+ if self.max_docs:
+ entries = entries[: self.max_docs]
+ for entry in entries:
+ doc = self._load_entry(entry)
+ if doc:
+ yield doc
+
+ def load(self) -> List[Document]:
+ """Load bibtex file documents from the given bibtex file path.
+
+ See https://bibtexparser.readthedocs.io/en/master/
+
+ Returns:
+ a list of documents with the document.page_content in text format
+ """
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/bigquery.py b/libs/community/langchain_community/document_loaders/bigquery.py
new file mode 100644
index 00000000000..e007b1f4954
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/bigquery.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.auth.credentials import Credentials
+
+
+class BigQueryLoader(BaseLoader):
+ """Load from the Google Cloud Platform `BigQuery`.
+
+ Each document represents one row of the result. The `page_content_columns`
+ are written into the `page_content` of the document. The `metadata_columns`
+ are written into the `metadata` of the document. By default, all columns
+ are written into the `page_content` and none into the `metadata`.
+
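+ Example (illustrative; the query, table, and column names are placeholders):
+ .. code-block:: python
+
+ loader = BigQueryLoader(
+ query="SELECT title, body, url FROM `my-project.my_dataset.articles`",
+ page_content_columns=["title", "body"],
+ metadata_columns=["url"],
+ )
+ docs = loader.load()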
+ """
+
+ def __init__(
+ self,
+ query: str,
+ project: Optional[str] = None,
+ page_content_columns: Optional[List[str]] = None,
+ metadata_columns: Optional[List[str]] = None,
+ credentials: Optional[Credentials] = None,
+ ):
+ """Initialize BigQuery document loader.
+
+ Args:
+ query: The query to run in BigQuery.
+ project: Optional. The project to run the query in.
+ page_content_columns: Optional. The columns to write into the `page_content`
+ of the document.
+ metadata_columns: Optional. The columns to write into the `metadata` of the
+ document.
+ credentials : google.auth.credentials.Credentials, optional
+ Credentials for accessing Google APIs. Use this parameter to override
+ default credentials, such as to use Compute Engine
+ (`google.auth.compute_engine.Credentials`) or Service Account
+ (`google.oauth2.service_account.Credentials`) credentials directly.
+ """
+ self.query = query
+ self.project = project
+ self.page_content_columns = page_content_columns
+ self.metadata_columns = metadata_columns
+ self.credentials = credentials
+
+ def load(self) -> List[Document]:
+ try:
+ from google.cloud import bigquery
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import google-cloud-bigquery python package. "
+ "Please install it with `pip install google-cloud-bigquery`."
+ ) from ex
+
+ bq_client = bigquery.Client(
+ credentials=self.credentials,
+ project=self.project,
+ client_info=get_client_info(module="bigquery"),
+ )
+ if not bq_client.project:
+ error_desc = (
+ "GCP project for Big Query is not set! Either provide a "
+ "`project` argument during BigQueryLoader instantiation, "
+ "or set a default project with `gcloud config set project` "
+ "command."
+ )
+ raise ValueError(error_desc)
+ query_result = bq_client.query(self.query).result()
+ docs: List[Document] = []
+
+ page_content_columns = self.page_content_columns
+ metadata_columns = self.metadata_columns
+
+ if page_content_columns is None:
+ page_content_columns = [column.name for column in query_result.schema]
+ if metadata_columns is None:
+ metadata_columns = []
+
+ for row in query_result:
+ page_content = "\n".join(
+ f"{k}: {v}" for k, v in row.items() if k in page_content_columns
+ )
+ metadata = {k: v for k, v in row.items() if k in metadata_columns}
+ doc = Document(page_content=page_content, metadata=metadata)
+ docs.append(doc)
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/bilibili.py b/libs/community/langchain_community/document_loaders/bilibili.py
new file mode 100644
index 00000000000..bdae3a2a1dc
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/bilibili.py
@@ -0,0 +1,83 @@
+import json
+import re
+import warnings
+from typing import List, Tuple
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BiliBiliLoader(BaseLoader):
+ """Load `BiliBili` video transcripts."""
+
+ def __init__(self, video_urls: List[str]):
+ """Initialize with bilibili url.
+
+ Args:
+ video_urls: List of bilibili urls.
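+
+ Example (illustrative; the URL is a placeholder):
+ .. code-block:: python
+
+ loader = BiliBiliLoader(
+ video_urls=["https://www.bilibili.com/video/BV1xt411o7Xu"]
+ )
+ docs = loader.load()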
+ """
+ self.video_urls = video_urls
+
+ def load(self) -> List[Document]:
+ """Load Documents from bilibili url."""
+ results = []
+ for url in self.video_urls:
+ transcript, video_info = self._get_bilibili_subs_and_info(url)
+ doc = Document(page_content=transcript, metadata=video_info)
+ results.append(doc)
+
+ return results
+
+ def _get_bilibili_subs_and_info(self, url: str) -> Tuple[str, dict]:
+ try:
+ from bilibili_api import sync, video
+ except ImportError:
+ raise ImportError(
+ "requests package not found, please install it with "
+ "`pip install bilibili-api-python`"
+ )
+
+ bvid = re.search(r"BV\w+", url)
+ if bvid is not None:
+ v = video.Video(bvid=bvid.group())
+ else:
+ aid = re.search(r"av[0-9]+", url)
+ if aid is not None:
+ try:
+ v = video.Video(aid=int(aid.group()[2:]))
+ except AttributeError:
+ raise ValueError(f"{url} is not bilibili url.")
+ else:
+ raise ValueError(f"{url} is not bilibili url.")
+
+ video_info = sync(v.get_info())
+ video_info.update({"url": url})
+ sub = sync(v.get_subtitle(video_info["cid"]))
+
+ # Get subtitle url
+ sub_list = sub["subtitles"]
+ if sub_list:
+ sub_url = sub_list[0]["subtitle_url"]
+ if not sub_url.startswith("http"):
+ sub_url = "https:" + sub_url
+ result = requests.get(sub_url)
+ raw_sub_titles = json.loads(result.content)["body"]
+ raw_transcript = " ".join([c["content"] for c in raw_sub_titles])
+
+ raw_transcript_with_meta_info = (
+ f"Video Title: {video_info['title']},"
+ f"description: {video_info['desc']}\n\n"
+ f"Transcript: {raw_transcript}"
+ )
+ return raw_transcript_with_meta_info, video_info
+ else:
+ raw_transcript = ""
+ warnings.warn(
+ f"""
+ No subtitles found for video: {url}.
+ Return Empty transcript.
+ """
+ )
+ return raw_transcript, video_info
diff --git a/libs/community/langchain_community/document_loaders/blackboard.py b/libs/community/langchain_community/document_loaders/blackboard.py
new file mode 100644
index 00000000000..cc33db332dd
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blackboard.py
@@ -0,0 +1,298 @@
+import contextlib
+import re
+from pathlib import Path
+from typing import Any, List, Optional, Tuple
+from urllib.parse import unquote
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.directory import DirectoryLoader
+from langchain_community.document_loaders.pdf import PyPDFLoader
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class BlackboardLoader(WebBaseLoader):
+ """Load a `Blackboard` course.
+
+ This loader is not compatible with all Blackboard courses. It is only
+ compatible with courses that use the new Blackboard interface.
+ To use this loader, you must have the BbRouter cookie. You can get this
+ cookie by logging into the course and then copying the value of the
+ BbRouter cookie from the browser's developer tools.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import BlackboardLoader
+
+ loader = BlackboardLoader(
+ blackboard_course_url="https://blackboard.example.com/webapps/blackboard/execute/announcement?method=search&context=course_entry&course_id=_123456_1",
+ bbrouter="expires:12345...",
+ )
+ documents = loader.load()
+
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ blackboard_course_url: str,
+ bbrouter: str,
+ load_all_recursively: bool = True,
+ basic_auth: Optional[Tuple[str, str]] = None,
+ cookies: Optional[dict] = None,
+ continue_on_failure: bool = False,
+ ):
+ """Initialize with blackboard course url.
+
+ The BbRouter cookie is required for most blackboard courses.
+
+ Args:
+ blackboard_course_url: Blackboard course url.
+ bbrouter: BbRouter cookie.
+ load_all_recursively: If True, load all documents recursively.
+ basic_auth: Basic auth credentials.
+ cookies: Cookies.
+ continue_on_failure: whether to continue loading the sitemap if an error
+ occurs loading a url, emitting a warning instead of raising an
+ exception. Setting this to True makes the loader more robust, but also
+ may result in missing data. Default: False
+
+ Raises:
+ ValueError: If blackboard course url is invalid.
+ """
+ super().__init__(
+ web_paths=(blackboard_course_url), continue_on_failure=continue_on_failure
+ )
+ # Get base url
+ try:
+ self.base_url = blackboard_course_url.split("/webapps/blackboard")[0]
+ except IndexError:
+ raise IndexError(
+ "Invalid blackboard course url. "
+ "Please provide a url that starts with "
+ "https:///webapps/blackboard"
+ )
+ if basic_auth is not None:
+ self.session.auth = basic_auth
+ # Combine cookies
+ if cookies is None:
+ cookies = {}
+ cookies.update({"BbRouter": bbrouter})
+ self.session.cookies.update(cookies)
+ self.load_all_recursively = load_all_recursively
+ self.check_bs4()
+
+ def check_bs4(self) -> None:
+ """Check if BeautifulSoup4 is installed.
+
+ Raises:
+ ImportError: If BeautifulSoup4 is not installed.
+ """
+ try:
+ import bs4 # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "BeautifulSoup4 is required for BlackboardLoader. "
+ "Please install it with `pip install beautifulsoup4`."
+ )
+
+ def load(self) -> List[Document]:
+ """Load data into Document objects.
+
+ Returns:
+ List of Documents.
+ """
+ if self.load_all_recursively:
+ soup_info = self.scrape()
+ self.folder_path = self._get_folder_path(soup_info)
+ relative_paths = self._get_paths(soup_info)
+ documents = []
+ for path in relative_paths:
+ url = self.base_url + path
+ print(f"Fetching documents from {url}")
+ soup_info = self._scrape(url)
+ with contextlib.suppress(ValueError):
+ documents.extend(self._get_documents(soup_info))
+ return documents
+ else:
+ print(f"Fetching documents from {self.web_path}")
+ soup_info = self.scrape()
+ self.folder_path = self._get_folder_path(soup_info)
+ return self._get_documents(soup_info)
+
+ def _get_folder_path(self, soup: Any) -> str:
+ """Get the folder path to save the Documents in.
+
+ Args:
+ soup: BeautifulSoup4 soup object.
+
+ Returns:
+ Folder path.
+ """
+ # Get the course name
+ course_name = soup.find("span", {"id": "crumb_1"})
+ if course_name is None:
+ raise ValueError("No course name found.")
+ course_name = course_name.text.strip()
+ # Prepare the folder path
+ course_name_clean = (
+ unquote(course_name)
+ .replace(" ", "_")
+ .replace("/", "_")
+ .replace(":", "_")
+ .replace(",", "_")
+ .replace("?", "_")
+ .replace("'", "_")
+ .replace("!", "_")
+ .replace('"', "_")
+ )
+ # Get the folder path
+ folder_path = Path(".") / course_name_clean
+ return str(folder_path)
+
+ def _get_documents(self, soup: Any) -> List[Document]:
+ """Fetch content from page and return Documents.
+
+ Args:
+ soup: BeautifulSoup4 soup object.
+
+ Returns:
+ List of documents.
+ """
+ attachments = self._get_attachments(soup)
+ self._download_attachments(attachments)
+ documents = self._load_documents()
+ return documents
+
+ def _get_attachments(self, soup: Any) -> List[str]:
+ """Get all attachments from a page.
+
+ Args:
+ soup: BeautifulSoup4 soup object.
+
+ Returns:
+ List of attachments.
+ """
+ from bs4 import BeautifulSoup, Tag
+
+ # Get content list
+ content_list = soup.find("ul", {"class": "contentList"})
+ if content_list is None:
+ raise ValueError("No content list found.")
+ content_list: BeautifulSoup # type: ignore
+ # Get all attachments
+ attachments = []
+ for attachment in content_list.find_all("ul", {"class": "attachments"}):
+ attachment: Tag # type: ignore
+ for link in attachment.find_all("a"):
+ link: Tag # type: ignore
+ href = link.get("href")
+ # Only add if href is not None and does not start with #
+ if href is not None and not href.startswith("#"):
+ attachments.append(href)
+ return attachments
+
+ def _download_attachments(self, attachments: List[str]) -> None:
+ """Download all attachments.
+
+ Args:
+ attachments: List of attachments.
+ """
+ # Make sure the folder exists
+ Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+ # Download all attachments
+ for attachment in attachments:
+ self.download(attachment)
+
+ def _load_documents(self) -> List[Document]:
+ """Load all documents in the folder.
+
+ Returns:
+ List of documents.
+ """
+ # Create the document loader
+ loader = DirectoryLoader(
+ path=self.folder_path,
+ glob="*.pdf",
+ loader_cls=PyPDFLoader, # type: ignore
+ )
+ # Load the documents
+ documents = loader.load()
+ # Return all documents
+ return documents
+
+ def _get_paths(self, soup: Any) -> List[str]:
+ """Get all relative paths in the navbar."""
+ relative_paths = []
+ course_menu = soup.find("ul", {"class": "courseMenu"})
+ if course_menu is None:
+ raise ValueError("No course menu found.")
+ for link in course_menu.find_all("a"):
+ href = link.get("href")
+ if href is not None and href.startswith("/"):
+ relative_paths.append(href)
+ return relative_paths
+
+ def download(self, path: str) -> None:
+ """Download a file from an url.
+
+ Args:
+ path: Path to the file.
+ """
+ # Get the file content
+ response = self.session.get(self.base_url + path, allow_redirects=True)
+ # Get the filename
+ filename = self.parse_filename(response.url)
+ # Write the file to disk
+ with open(Path(self.folder_path) / filename, "wb") as f:
+ f.write(response.content)
+
+ def parse_filename(self, url: str) -> str:
+ """Parse the filename from an url.
+
+ Args:
+ url: Url to parse the filename from.
+
+ Returns:
+ The filename.
+ """
+ if (url_path := Path(url)) and url_path.suffix == ".pdf":
+ return url_path.name
+ else:
+ return self._parse_filename_from_url(url)
+
+ def _parse_filename_from_url(self, url: str) -> str:
+ """Parse the filename from an url.
+
+ Args:
+ url: Url to parse the filename from.
+
+ Returns:
+ The filename.
+
+ Raises:
+ ValueError: If the filename could not be parsed.
+ """
+ filename_matches = re.search(r"filename%2A%3DUTF-8%27%27(.+)", url)
+ if filename_matches:
+ filename = filename_matches.group(1)
+ else:
+ raise ValueError(f"Could not parse filename from {url}")
+ if ".pdf" not in filename:
+ raise ValueError(f"Incorrect file type: {filename}")
+ filename = filename.split(".pdf")[0] + ".pdf"
+ filename = unquote(filename)
+ filename = filename.replace("%20", " ")
+ return filename
+
+
+if __name__ == "__main__":
+ loader = BlackboardLoader(
+ "https:///webapps/blackboard/content/listContent.jsp?course_id=__1&content_id=__1&mode=reset",
+ "",
+ load_all_recursively=True,
+ )
+ documents = loader.load()
+ print(f"Loaded {len(documents)} pages of PDFs from {loader.web_path}")
diff --git a/libs/community/langchain_community/document_loaders/blob_loaders/__init__.py b/libs/community/langchain_community/document_loaders/blob_loaders/__init__.py
new file mode 100644
index 00000000000..174c71de026
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blob_loaders/__init__.py
@@ -0,0 +1,9 @@
+from langchain_community.document_loaders.blob_loaders.file_system import (
+ FileSystemBlobLoader,
+)
+from langchain_community.document_loaders.blob_loaders.schema import Blob, BlobLoader
+from langchain_community.document_loaders.blob_loaders.youtube_audio import (
+ YoutubeAudioLoader,
+)
+
+__all__ = ["BlobLoader", "Blob", "FileSystemBlobLoader", "YoutubeAudioLoader"]
diff --git a/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py b/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py
new file mode 100644
index 00000000000..0fcdd4438ee
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py
@@ -0,0 +1,147 @@
+"""Use to load blobs from the local file system."""
+from pathlib import Path
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union
+
+from langchain_community.document_loaders.blob_loaders.schema import Blob, BlobLoader
+
+T = TypeVar("T")
+
+
+def _make_iterator(
+ length_func: Callable[[], int], show_progress: bool = False
+) -> Callable[[Iterable[T]], Iterator[T]]:
+ """Create a function that optionally wraps an iterable in tqdm."""
+ if show_progress:
+ try:
+ from tqdm.auto import tqdm
+ except ImportError:
+ raise ImportError(
+ "You must install tqdm to use show_progress=True."
+ "You can install tqdm with `pip install tqdm`."
+ )
+
+ # Make sure to provide `total` here so that tqdm can show
+ # a progress bar that takes into account the total number of files.
+ def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
+ """Wrap an iterable in a tqdm progress bar."""
+ return tqdm(iterable, total=length_func())
+
+ iterator = _with_tqdm
+ else:
+ iterator = iter # type: ignore
+
+ return iterator
+
+
+# PUBLIC API
+
+
+class FileSystemBlobLoader(BlobLoader):
+ """Load blobs in the local file system.
+
+ Example:
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader
+ loader = FileSystemBlobLoader("/path/to/directory")
+ for blob in loader.yield_blobs():
+ print(blob)
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ *,
+ glob: str = "**/[!.]*",
+ exclude: Sequence[str] = (),
+ suffixes: Optional[Sequence[str]] = None,
+ show_progress: bool = False,
+ ) -> None:
+ """Initialize with a path to directory and how to glob over it.
+
+ Args:
+ path: Path to directory to load from or path to file to load.
+ If a path to a file is provided, glob/exclude/suffixes are ignored.
+ glob: Glob pattern relative to the specified path
+ by default set to pick up all non-hidden files
+ exclude: patterns to exclude from results, use glob syntax
+ suffixes: Provide to keep only files with these suffixes
+ Useful when wanting to keep files with different suffixes
+ Suffixes must include the dot, e.g. ".txt"
+ show_progress: If true, will show a progress bar as the files are loaded.
+ This forces an iteration through all matching files
+ to count them prior to loading them.
+
+ Examples:
+
+ .. code-block:: python
+ from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader
+
+ # Load a single file.
+ loader = FileSystemBlobLoader("/path/to/file.txt")
+
+ # Recursively load all text files in a directory.
+ loader = FileSystemBlobLoader("/path/to/directory", glob="**/*.txt")
+
+ # Recursively load all non-hidden files in a directory.
+ loader = FileSystemBlobLoader("/path/to/directory", glob="**/[!.]*")
+
+ # Load all files in a directory without recursion.
+ loader = FileSystemBlobLoader("/path/to/directory", glob="*")
+
+ # Recursively load all files in a directory, except for py or pyc files.
+ loader = FileSystemBlobLoader(
+ "/path/to/directory",
+ glob="**/*.txt",
+ exclude=["**/*.py", "**/*.pyc"]
+ )
+ """ # noqa: E501
+ if isinstance(path, Path):
+ _path = path
+ elif isinstance(path, str):
+ _path = Path(path)
+ else:
+ raise TypeError(f"Expected str or Path, got {type(path)}")
+
+ self.path = _path.expanduser() # Expand user to handle ~
+ self.glob = glob
+ self.suffixes = set(suffixes or [])
+ self.show_progress = show_progress
+ self.exclude = exclude
+
+ def yield_blobs(
+ self,
+ ) -> Iterable[Blob]:
+ """Yield blobs that match the requested pattern."""
+ iterator = _make_iterator(
+ length_func=self.count_matching_files, show_progress=self.show_progress
+ )
+
+ for path in iterator(self._yield_paths()):
+ yield Blob.from_path(path)
+
+ def _yield_paths(self) -> Iterable[Path]:
+ """Yield paths that match the requested pattern."""
+ if self.path.is_file():
+ yield self.path
+ return
+
+ paths = self.path.glob(self.glob)
+ for path in paths:
+ if self.exclude:
+ if any(path.match(glob) for glob in self.exclude):
+ continue
+ if path.is_file():
+ if self.suffixes and path.suffix not in self.suffixes:
+ continue
+ yield path
+
+ def count_matching_files(self) -> int:
+ """Count files that match the pattern without loading them."""
+ # Carry out a full iteration to count the files without
+ # materializing anything expensive in memory.
+ num = 0
+ for _ in self._yield_paths():
+ num += 1
+ return num
diff --git a/libs/community/langchain_community/document_loaders/blob_loaders/schema.py b/libs/community/langchain_community/document_loaders/blob_loaders/schema.py
new file mode 100644
index 00000000000..c2f88a14015
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blob_loaders/schema.py
@@ -0,0 +1,195 @@
+"""Schema for Blobs and Blob Loaders.
+
+The goal is to facilitate decoupling of content loading from content parsing code.
+
+In addition, content loading code should provide a lazy loading interface by default.
+"""
+from __future__ import annotations
+
+import contextlib
+import mimetypes
+from abc import ABC, abstractmethod
+from io import BufferedReader, BytesIO
+from pathlib import PurePath
+from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Union, cast
+
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+
+PathLike = Union[str, PurePath]
+
+
+class Blob(BaseModel):
+ """Blob represents raw data by either reference or value.
+
+ Provides an interface to materialize the blob in different representations, and
+ help to decouple the development of data loaders from the downstream parsing of
+ the raw data.
+
+ Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
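+
+ Example (a minimal sketch; the file path below is a placeholder):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.blob_loaders import Blob
+
+ # Reference a file on disk without reading it into memory yet.
+ blob = Blob.from_path("example.txt")  # placeholder path
+ text = blob.as_string()
+
+ # Or wrap data that is already in memory.
+ blob = Blob.from_data("hello world", mime_type="text/plain")
+ raw = blob.as_bytes()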
+ """
+
+ data: Union[bytes, str, None]
+ """Raw data associated with the blob."""
+ mimetype: Optional[str] = None
+ """MimeType not to be confused with a file extension."""
+ encoding: str = "utf-8"
+ """Encoding to use if decoding the bytes into a string.
+
+ Use utf-8 as default encoding, if decoding to string.
+ """
+ path: Optional[PathLike] = None
+ """Location where the original content was found."""
+
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+ """Metadata about the blob (e.g., source)"""
+
+ class Config:
+ arbitrary_types_allowed = True
+ frozen = True
+
+ @property
+ def source(self) -> Optional[str]:
+ """The source location of the blob as string if known otherwise none.
+
+ If a path is associated with the blob, it will default to the path location.
+
+ Unless explicitly set via a metadata field called "source", in which
+ case that value will be used instead.
+ """
+ if self.metadata and "source" in self.metadata:
+ return cast(Optional[str], self.metadata["source"])
+ return str(self.path) if self.path else None
+
+ @root_validator(pre=True)
+ def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
+ """Verify that either data or path is provided."""
+ if "data" not in values and "path" not in values:
+ raise ValueError("Either data or path must be provided")
+ return values
+
+ def as_string(self) -> str:
+ """Read data as a string."""
+ if self.data is None and self.path:
+ with open(str(self.path), "r", encoding=self.encoding) as f:
+ return f.read()
+ elif isinstance(self.data, bytes):
+ return self.data.decode(self.encoding)
+ elif isinstance(self.data, str):
+ return self.data
+ else:
+ raise ValueError(f"Unable to get string for blob {self}")
+
+ def as_bytes(self) -> bytes:
+ """Read data as bytes."""
+ if isinstance(self.data, bytes):
+ return self.data
+ elif isinstance(self.data, str):
+ return self.data.encode(self.encoding)
+ elif self.data is None and self.path:
+ with open(str(self.path), "rb") as f:
+ return f.read()
+ else:
+ raise ValueError(f"Unable to get bytes for blob {self}")
+
+ @contextlib.contextmanager
+ def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
+ """Read data as a byte stream."""
+ if isinstance(self.data, bytes):
+ yield BytesIO(self.data)
+ elif self.data is None and self.path:
+ with open(str(self.path), "rb") as f:
+ yield f
+ else:
+ raise NotImplementedError(f"Unable to convert blob {self}")
+
+ @classmethod
+ def from_path(
+ cls,
+ path: PathLike,
+ *,
+ encoding: str = "utf-8",
+ mime_type: Optional[str] = None,
+ guess_type: bool = True,
+ metadata: Optional[dict] = None,
+ ) -> Blob:
+ """Load the blob from a path like object.
+
+ Args:
+ path: path like object to file to be read
+ encoding: Encoding to use if decoding the bytes into a string
+ mime_type: if provided, will be set as the mime-type of the data
+ guess_type: If True, the mimetype will be guessed from the file extension,
+ if a mime-type was not provided
+ metadata: Metadata to associate with the blob
+
+ Returns:
+ Blob instance
+ """
+ if mime_type is None and guess_type:
+ _mimetype = mimetypes.guess_type(path)[0] if guess_type else None
+ else:
+ _mimetype = mime_type
+ # We do not load the data immediately, instead we treat the blob as a
+ # reference to the underlying data.
+ return cls(
+ data=None,
+ mimetype=_mimetype,
+ encoding=encoding,
+ path=path,
+ metadata=metadata if metadata is not None else {},
+ )
+
+ @classmethod
+ def from_data(
+ cls,
+ data: Union[str, bytes],
+ *,
+ encoding: str = "utf-8",
+ mime_type: Optional[str] = None,
+ path: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Blob:
+ """Initialize the blob from in-memory data.
+
+ Args:
+ data: the in-memory data associated with the blob
+ encoding: Encoding to use if decoding the bytes into a string
+ mime_type: if provided, will be set as the mime-type of the data
+ path: if provided, will be set as the source from which the data came
+ metadata: Metadata to associate with the blob
+
+ Returns:
+ Blob instance
+ """
+ return cls(
+ data=data,
+ mimetype=mime_type,
+ encoding=encoding,
+ path=path,
+ metadata=metadata if metadata is not None else {},
+ )
+
+ def __repr__(self) -> str:
+ """Define the blob representation."""
+ str_repr = f"Blob {id(self)}"
+ if self.source:
+ str_repr += f" {self.source}"
+ return str_repr
+
+
+class BlobLoader(ABC):
+ """Abstract interface for blob loaders implementation.
+
+ Implementer should be able to load raw content from a storage system according
+ to some criteria and return the raw content lazily as a stream of blobs.
+ """
+
+ @abstractmethod
+ def yield_blobs(
+ self,
+ ) -> Iterable[Blob]:
+ """A lazy loader for raw data represented by LangChain's Blob object.
+
+ Returns:
+ A generator over blobs
+ """
diff --git a/libs/community/langchain_community/document_loaders/blob_loaders/youtube_audio.py b/libs/community/langchain_community/document_loaders/blob_loaders/youtube_audio.py
new file mode 100644
index 00000000000..f7313d04a66
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blob_loaders/youtube_audio.py
@@ -0,0 +1,50 @@
+from typing import Iterable, List
+
+from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader
+from langchain_community.document_loaders.blob_loaders.schema import Blob, BlobLoader
+
+
+class YoutubeAudioLoader(BlobLoader):
+
+ """Load YouTube urls as audio file(s)."""
+
+ def __init__(self, urls: List[str], save_dir: str):
+ if not isinstance(urls, list):
+ raise TypeError("urls must be a list")
+
+ self.urls = urls
+ self.save_dir = save_dir
+
+ def yield_blobs(self) -> Iterable[Blob]:
+ """Yield audio blobs for each url."""
+
+ try:
+ import yt_dlp
+ except ImportError:
+ raise ImportError(
+ "yt_dlp package not found, please install it with "
+ "`pip install yt_dlp`"
+ )
+
+ # Use yt_dlp to download audio given a YouTube url
+ ydl_opts = {
+ "format": "m4a/bestaudio/best",
+ "noplaylist": True,
+ "outtmpl": self.save_dir + "/%(title)s.%(ext)s",
+ "postprocessors": [
+ {
+ "key": "FFmpegExtractAudio",
+ "preferredcodec": "m4a",
+ }
+ ],
+ }
+
+ for url in self.urls:
+ # Download file
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ ydl.download(url)
+
+ # Yield the written blobs
+ loader = FileSystemBlobLoader(self.save_dir, glob="*.m4a")
+ for blob in loader.yield_blobs():
+ yield blob
diff --git a/libs/community/langchain_community/document_loaders/blockchain.py b/libs/community/langchain_community/document_loaders/blockchain.py
new file mode 100644
index 00000000000..69a6a344fd3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/blockchain.py
@@ -0,0 +1,168 @@
+import os
+import re
+import time
+from enum import Enum
+from typing import List, Optional
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BlockchainType(Enum):
+ """Enumerator of the supported blockchains."""
+
+ ETH_MAINNET = "eth-mainnet"
+ ETH_GOERLI = "eth-goerli"
+ POLYGON_MAINNET = "polygon-mainnet"
+ POLYGON_MUMBAI = "polygon-mumbai"
+
+
+class BlockchainDocumentLoader(BaseLoader):
+ """Load elements from a blockchain smart contract.
+
+ The supported blockchains are: Ethereum mainnet, Ethereum Goerli testnet,
+ Polygon mainnet, and Polygon Mumbai testnet.
+
+ If no BlockchainType is specified, the default is Ethereum mainnet.
+
+ The Loader uses the Alchemy API to interact with the blockchain.
+ ALCHEMY_API_KEY environment variable must be set to use this loader.
+
+ The API returns 100 NFTs per request and can be paginated using the
+ startToken parameter.
+
+ If get_all_tokens is set to True, the loader will get all tokens
+ on the contract. Note that for contracts with a large number of tokens,
+ this may take a long time (e.g. 10k tokens is 100 requests).
+ Default value is false for this reason.
+
+ The max_execution_time (sec) can be set to limit the execution time
+ of the loader.
+
+ Future versions of this loader can:
+ - Support additional Alchemy APIs (e.g. getTransactions, etc.)
+ - Support additional blockchain APIs (e.g. Infura, Opensea, etc.)
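+
+ Example (a minimal sketch; the contract address below is a placeholder zero
+ address and the API key must be replaced with a real Alchemy key):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.blockchain import (
+ BlockchainDocumentLoader,
+ BlockchainType,
+ )
+
+ loader = BlockchainDocumentLoader(
+ contract_address="0x0000000000000000000000000000000000000000",  # placeholder
+ blockchainType=BlockchainType.ETH_MAINNET,
+ api_key="...",  # or set the ALCHEMY_API_KEY environment variable
+ )
+ documents = loader.load()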
+ """
+
+ def __init__(
+ self,
+ contract_address: str,
+ blockchainType: BlockchainType = BlockchainType.ETH_MAINNET,
+ api_key: str = "docs-demo",
+ startToken: str = "",
+ get_all_tokens: bool = False,
+ max_execution_time: Optional[int] = None,
+ ):
+ """
+
+ Args:
+ contract_address: The address of the smart contract.
+ blockchainType: The blockchain type.
+ api_key: The Alchemy API key.
+ startToken: The start token for pagination.
+ get_all_tokens: Whether to get all tokens on the contract.
+ max_execution_time: The maximum execution time (sec).
+ """
+ self.contract_address = contract_address
+ self.blockchainType = blockchainType.value
+ self.api_key = os.environ.get("ALCHEMY_API_KEY") or api_key
+ self.startToken = startToken
+ self.get_all_tokens = get_all_tokens
+ self.max_execution_time = max_execution_time
+
+ if not self.api_key:
+ raise ValueError("Alchemy API key not provided.")
+
+ if not re.match(r"^0x[a-fA-F0-9]{40}$", self.contract_address):
+ raise ValueError(f"Invalid contract address {self.contract_address}")
+
+ def load(self) -> List[Document]:
+ result = []
+
+ current_start_token = self.startToken
+
+ start_time = time.time()
+
+ while True:
+ url = (
+ f"https://{self.blockchainType}.g.alchemy.com/nft/v2/"
+ f"{self.api_key}/getNFTsForCollection?withMetadata="
+ f"True&contractAddress={self.contract_address}"
+ f"&startToken={current_start_token}"
+ )
+
+ response = requests.get(url)
+
+ if response.status_code != 200:
+ raise ValueError(
+ f"Request failed with status code {response.status_code}"
+ )
+
+ items = response.json()["nfts"]
+
+ if not items:
+ break
+
+ for item in items:
+ content = str(item)
+ tokenId = item["id"]["tokenId"]
+ metadata = {
+ "source": self.contract_address,
+ "blockchain": self.blockchainType,
+ "tokenId": tokenId,
+ }
+ result.append(Document(page_content=content, metadata=metadata))
+
+ # exit after the first API call if get_all_tokens is False
+ if not self.get_all_tokens:
+ break
+
+ # get the start token for the next API call from the last item in array
+ current_start_token = self._get_next_tokenId(result[-1].metadata["tokenId"])
+
+ if (
+ self.max_execution_time is not None
+ and (time.time() - start_time) > self.max_execution_time
+ ):
+ raise RuntimeError("Execution time exceeded the allowed time limit.")
+
+ if not result:
+ raise ValueError(
+ f"No NFTs found for contract address {self.contract_address}"
+ )
+
+ return result
+
+ # add one to the tokenId, ensuring the correct tokenId format is used
+ def _get_next_tokenId(self, tokenId: str) -> str:
+ value_type = self._detect_value_type(tokenId)
+
+ if value_type == "hex_0x":
+ value_int = int(tokenId, 16)
+ elif value_type == "hex_0xbf":
+ value_int = int(tokenId[2:], 16)
+ else:
+ value_int = int(tokenId)
+
+ result = value_int + 1
+
+ if value_type == "hex_0x":
+ return "0x" + format(result, "0" + str(len(tokenId) - 2) + "x")
+ elif value_type == "hex_0xbf":
+ return "0xbf" + format(result, "0" + str(len(tokenId) - 4) + "x")
+ else:
+ return str(result)
+
+ # A smart contract can use different formats for the tokenId
+ @staticmethod
+ def _detect_value_type(tokenId: str) -> str:
+ if isinstance(tokenId, int):
+ return "int"
+ elif tokenId.startswith("0x"):
+ return "hex_0x"
+ elif tokenId.startswith("0xbf"):
+ return "hex_0xbf"
+ else:
+ return "hex_0xbf"
diff --git a/libs/community/langchain_community/document_loaders/brave_search.py b/libs/community/langchain_community/document_loaders/brave_search.py
new file mode 100644
index 00000000000..7fa5ff13624
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/brave_search.py
@@ -0,0 +1,33 @@
+from typing import Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.brave_search import BraveSearchWrapper
+
+
+class BraveSearchLoader(BaseLoader):
+ """Load with `Brave Search` engine."""
+
+ def __init__(self, query: str, api_key: str, search_kwargs: Optional[dict] = None):
+ """Initializes the BraveLoader.
+
+ Args:
+ query: The query to search for.
+ api_key: The API key to use.
+ search_kwargs: The search kwargs to use.
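+
+ Example (a minimal sketch; the API key below is a placeholder):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.brave_search import BraveSearchLoader
+
+ loader = BraveSearchLoader(
+ query="what is langchain",
+ api_key="BSA...",  # placeholder Brave Search API key
+ )
+ docs = loader.load()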
+ """
+ self.query = query
+ self.api_key = api_key
+ self.search_kwargs = search_kwargs or {}
+
+ def load(self) -> List[Document]:
+ brave_client = BraveSearchWrapper(
+ api_key=self.api_key,
+ search_kwargs=self.search_kwargs,
+ )
+ return brave_client.download_documents(self.query)
+
+ def lazy_load(self) -> Iterator[Document]:
+ for doc in self.load():
+ yield doc
diff --git a/libs/community/langchain_community/document_loaders/browserless.py b/libs/community/langchain_community/document_loaders/browserless.py
new file mode 100644
index 00000000000..03d8dcec9ad
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/browserless.py
@@ -0,0 +1,67 @@
+from typing import Iterator, List, Union
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BrowserlessLoader(BaseLoader):
+ """Load webpages with `Browserless` /content endpoint."""
+
+ def __init__(
+ self, api_token: str, urls: Union[str, List[str]], text_content: bool = True
+ ):
+ """Initialize with API token and the URLs to scrape"""
+ self.api_token = api_token
+ """Browserless API token."""
+ # Accept either a single URL or a list of URLs.
+ self.urls = [urls] if isinstance(urls, str) else urls
+ """List of URLs to scrape."""
+ self.text_content = text_content
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load Documents from URLs."""
+
+ for url in self.urls:
+ if self.text_content:
+ response = requests.post(
+ "https://chrome.browserless.io/scrape",
+ params={
+ "token": self.api_token,
+ },
+ json={
+ "url": url,
+ "elements": [
+ {
+ "selector": "body",
+ }
+ ],
+ },
+ )
+ yield Document(
+ page_content=response.json()["data"][0]["results"][0]["text"],
+ metadata={
+ "source": url,
+ },
+ )
+ else:
+ response = requests.post(
+ "https://chrome.browserless.io/content",
+ params={
+ "token": self.api_token,
+ },
+ json={
+ "url": url,
+ },
+ )
+
+ yield Document(
+ page_content=response.text,
+ metadata={
+ "source": url,
+ },
+ )
+
+ def load(self) -> List[Document]:
+ """Load Documents from URLs."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/chatgpt.py b/libs/community/langchain_community/document_loaders/chatgpt.py
new file mode 100644
index 00000000000..b6001be4191
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/chatgpt.py
@@ -0,0 +1,65 @@
+import datetime
+import json
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def concatenate_rows(message: dict, title: str) -> str:
+ """
+ Combine message information in a readable format ready to be used.
+ Args:
+ message: Message to be concatenated
+ title: Title of the conversation
+
+ Returns:
+ Concatenated message
+ """
+ if not message:
+ return ""
+
+ sender = message["author"]["role"] if message["author"] else "unknown"
+ text = message["content"]["parts"][0]
+ date = datetime.datetime.fromtimestamp(message["create_time"]).strftime(
+ "%Y-%m-%d %H:%M:%S"
+ )
+ return f"{title} - {sender} on {date}: {text}\n\n"
+
+
+class ChatGPTLoader(BaseLoader):
+ """Load conversations from exported `ChatGPT` data."""
+
+ def __init__(self, log_file: str, num_logs: int = -1):
+ """Initialize a class object.
+
+ Args:
+ log_file: Path to the log file
+ num_logs: Number of logs to load. If 0, load all logs.
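+
+ Example (a minimal sketch; `conversations.json` is the usual file name in a
+ ChatGPT data export and is used here as a placeholder):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.chatgpt import ChatGPTLoader
+
+ loader = ChatGPTLoader("conversations.json")  # placeholder path
+ documents = loader.load()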
+ """
+ self.log_file = log_file
+ self.num_logs = num_logs
+
+ def load(self) -> List[Document]:
+ with open(self.log_file, encoding="utf8") as f:
+ data = json.load(f)[: self.num_logs] if self.num_logs else json.load(f)
+
+ documents = []
+ for d in data:
+ title = d["title"]
+ messages = d["mapping"]
+ text = "".join(
+ [
+ concatenate_rows(messages[key]["message"], title)
+ for idx, key in enumerate(messages)
+ if not (
+ idx == 0
+ and messages[key]["message"]["author"]["role"] == "system"
+ )
+ ]
+ )
+ metadata = {"source": str(self.log_file)}
+ documents.append(Document(page_content=text, metadata=metadata))
+
+ return documents
diff --git a/libs/community/langchain_community/document_loaders/chromium.py b/libs/community/langchain_community/document_loaders/chromium.py
new file mode 100644
index 00000000000..7af99aed26a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/chromium.py
@@ -0,0 +1,91 @@
+import asyncio
+import logging
+from typing import Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncChromiumLoader(BaseLoader):
+ """Scrape HTML pages from URLs using a
+ headless instance of the Chromium."""
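+
+ Example (a minimal sketch; requires the `playwright` package and a Chromium
+ browser installed via `playwright install chromium`; the URL is a placeholder):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.chromium import AsyncChromiumLoader
+
+ loader = AsyncChromiumLoader(["https://example.com"])  # placeholder URL
+ docs = loader.load()
+ """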
+
+ def __init__(
+ self,
+ urls: List[str],
+ ):
+ """
+ Initialize the loader with a list of URL paths.
+
+ Args:
+ urls (List[str]): A list of URLs to scrape content from.
+
+ Raises:
+ ImportError: If the required 'playwright' package is not installed.
+ """
+ self.urls = urls
+
+ try:
+ import playwright # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "playwright is required for AsyncChromiumLoader. "
+ "Please install it with `pip install playwright`."
+ )
+
+ async def ascrape_playwright(self, url: str) -> str:
+ """
+ Asynchronously scrape the content of a given URL using Playwright's async API.
+
+ Args:
+ url (str): The URL to scrape.
+
+ Returns:
+ str: The scraped HTML content or an error message if an exception occurs.
+
+ """
+ from playwright.async_api import async_playwright
+
+ logger.info("Starting scraping...")
+ results = ""
+ async with async_playwright() as p:
+ browser = await p.chromium.launch(headless=True)
+ try:
+ page = await browser.new_page()
+ await page.goto(url)
+ results = await page.content() # Simply get the HTML content
+ logger.info("Content scraped")
+ except Exception as e:
+ results = f"Error: {e}"
+ await browser.close()
+ return results
+
+ def lazy_load(self) -> Iterator[Document]:
+ """
+ Lazily load text content from the provided URLs.
+
+ This method yields Documents one at a time as they're scraped,
+ instead of waiting to scrape all URLs before returning.
+
+ Yields:
+ Document: The scraped content encapsulated within a Document object.
+
+ """
+ for url in self.urls:
+ html_content = asyncio.run(self.ascrape_playwright(url))
+ metadata = {"source": url}
+ yield Document(page_content=html_content, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """
+ Load and return all Documents from the provided URLs.
+
+ Returns:
+ List[Document]: A list of Document objects
+ containing the scraped content from each URL.
+
+ """
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/college_confidential.py b/libs/community/langchain_community/document_loaders/college_confidential.py
new file mode 100644
index 00000000000..5b9fb3a43c2
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/college_confidential.py
@@ -0,0 +1,16 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class CollegeConfidentialLoader(WebBaseLoader):
+ """Load `College Confidential` webpages."""
+
+ def load(self) -> List[Document]:
+ """Load webpages as Documents."""
+ soup = self.scrape()
+ text = soup.select_one("main[class='skin-handler']").text
+ metadata = {"source": self.web_path}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/concurrent.py b/libs/community/langchain_community/document_loaders/concurrent.py
new file mode 100644
index 00000000000..9a538d498c7
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/concurrent.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import concurrent.futures
+from pathlib import Path
+from typing import Iterator, Literal, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import (
+ BlobLoader,
+ FileSystemBlobLoader,
+)
+from langchain_community.document_loaders.generic import GenericLoader
+from langchain_community.document_loaders.parsers.registry import get_parser
+
+_PathLike = Union[str, Path]
+
+DEFAULT = Literal["default"]
+
+
+class ConcurrentLoader(GenericLoader):
+ """Load and pars Documents concurrently."""
+
+ def __init__(
+ self, blob_loader: BlobLoader, blob_parser: BaseBlobParser, num_workers: int = 4
+ ) -> None:
+ super().__init__(blob_loader, blob_parser)
+ self.num_workers = num_workers
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Load documents lazily with concurrent parsing."""
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=self.num_workers
+ ) as executor:
+ futures = {
+ executor.submit(self.blob_parser.lazy_parse, blob)
+ for blob in self.blob_loader.yield_blobs()
+ }
+ for future in concurrent.futures.as_completed(futures):
+ yield from future.result()
+
+ @classmethod
+ def from_filesystem(
+ cls,
+ path: _PathLike,
+ *,
+ glob: str = "**/[!.]*",
+ exclude: Sequence[str] = (),
+ suffixes: Optional[Sequence[str]] = None,
+ show_progress: bool = False,
+ parser: Union[DEFAULT, BaseBlobParser] = "default",
+ num_workers: int = 4,
+ parser_kwargs: Optional[dict] = None,
+ ) -> ConcurrentLoader:
+ """Create a concurrent generic document loader using a filesystem blob loader.
+
+ Args:
+ path: The path to the directory to load documents from.
+ glob: The glob pattern to use to find documents.
+ suffixes: The suffixes to use to filter documents. If None, all files
+ matching the glob will be loaded.
+ exclude: A list of patterns to exclude from the loader.
+ show_progress: Whether to show a progress bar or not (requires tqdm).
+ Proxies to the file system loader.
+ parser: A blob parser which knows how to parse blobs into documents
+ num_workers: Max number of concurrent workers to use.
+ parser_kwargs: Keyword arguments to pass to the parser.
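+
+ Example (a minimal sketch; the directory path is a placeholder and the
+ default parser is assumed to handle the matched files):
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.concurrent import ConcurrentLoader
+
+ loader = ConcurrentLoader.from_filesystem(
+ "/path/to/directory",  # placeholder
+ glob="**/*.txt",
+ num_workers=4,
+ )
+ docs = list(loader.lazy_load())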
+ """
+ blob_loader = FileSystemBlobLoader(
+ path,
+ glob=glob,
+ exclude=exclude,
+ suffixes=suffixes,
+ show_progress=show_progress,
+ )
+ if isinstance(parser, str):
+ if parser == "default" and cls.get_parser != GenericLoader.get_parser:
+ # There is an implementation of get_parser on the class, use it.
+ blob_parser = cls.get_parser(**(parser_kwargs or {}))
+ else:
+ blob_parser = get_parser(parser)
+ else:
+ blob_parser = parser
+ return cls(blob_loader, blob_parser, num_workers=num_workers)
diff --git a/libs/community/langchain_community/document_loaders/confluence.py b/libs/community/langchain_community/document_loaders/confluence.py
new file mode 100644
index 00000000000..6e3293e00a6
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/confluence.py
@@ -0,0 +1,743 @@
+import logging
+from enum import Enum
+from io import BytesIO
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import requests
+from langchain_core.documents import Document
+from tenacity import (
+ before_sleep_log,
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class ContentFormat(str, Enum):
+ """Enumerator of the content formats of Confluence page."""
+
+ EDITOR = "body.editor"
+ EXPORT_VIEW = "body.export_view"
+ ANONYMOUS_EXPORT_VIEW = "body.anonymous_export_view"
+ STORAGE = "body.storage"
+ VIEW = "body.view"
+
+ def get_content(self, page: dict) -> str:
+ return page["body"][self.name.lower()]["value"]
+
+
+class ConfluenceLoader(BaseLoader):
+ """Load `Confluence` pages.
+
+ Port of https://llamahub.ai/l/confluence
+ This currently supports username/api_key, Oauth2 login or personal access token
+ authentication.
+
+ Specify a list of page_ids and/or a space_key to load the corresponding pages into
+ Document objects; if both are specified, the union of both sets will be returned.
+
+ You can also specify a boolean `include_attachments` to include attachments. This
+ is set to False by default; if set to True, all attachments will be downloaded and
+ the loader will extract the text from the attachments and add it to the
+ Document object. Currently supported attachment types are: PDF, PNG, JPEG/JPG,
+ SVG, Word and Excel.
+
+ The Confluence API supports different formats of page content. The storage format
+ is the raw XML representation used for storage. The view format is the HTML
+ representation for viewing, with macros rendered as a user would see them. You can
+ pass an enum `content_format` argument to `load()` to specify the content format;
+ this is set to `ContentFormat.STORAGE` by default. The supported values are:
+ `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
+ `ContentFormat.ANONYMOUS_EXPORT_VIEW`, `ContentFormat.STORAGE`,
+ and `ContentFormat.VIEW`.
+
+ Hint: space_key and page_id can both be found in the URL of a page in Confluence
+ - https://yoursite.atlassian.com/wiki/spaces/<space_key>/pages/<page_id>
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import ConfluenceLoader
+
+ loader = ConfluenceLoader(
+ url="https://yoursite.atlassian.com/wiki",
+ username="me",
+ api_key="12345"
+ )
+ documents = loader.load(space_key="SPACE",limit=50)
+
+ # Server on prem
+ loader = ConfluenceLoader(
+ url="https://confluence.yoursite.com/",
+ username="me",
+ api_key="your_password",
+ cloud=False
+ )
+ documents = loader.load(space_key="SPACE",limit=50)
+
+ :param url: Base URL of the Confluence instance
+ :type url: str
+ :param api_key: Confluence API key, defaults to None
+ :type api_key: str, optional
+ :param username: Confluence username, defaults to None
+ :type username: str, optional
+ :param session: A requests.Session to use for authentication, defaults to None
+ :type session: requests.Session, optional
+ :param oauth2: OAuth2 credentials dictionary, defaults to {}
+ :type oauth2: dict, optional
+ :param token: Personal access token, defaults to None
+ :type token: str, optional
+ :param cloud: Whether the Confluence instance is hosted on Atlassian Cloud,
+ defaults to True
+ :type cloud: bool, optional
+ :param number_of_retries: How many times to retry, defaults to 3
+ :type number_of_retries: Optional[int], optional
+ :param min_retry_seconds: defaults to 2
+ :type min_retry_seconds: Optional[int], optional
+ :param max_retry_seconds: defaults to 10
+ :type max_retry_seconds: Optional[int], optional
+ :param confluence_kwargs: additional kwargs to initialize confluence with
+ :type confluence_kwargs: dict, optional
+ :raises ValueError: Errors while validating input
+ :raises ImportError: Required dependencies not installed.
+ """
+
+ def __init__(
+ self,
+ url: str,
+ api_key: Optional[str] = None,
+ username: Optional[str] = None,
+ session: Optional[requests.Session] = None,
+ oauth2: Optional[dict] = None,
+ token: Optional[str] = None,
+ cloud: Optional[bool] = True,
+ number_of_retries: Optional[int] = 3,
+ min_retry_seconds: Optional[int] = 2,
+ max_retry_seconds: Optional[int] = 10,
+ confluence_kwargs: Optional[dict] = None,
+ ):
+ confluence_kwargs = confluence_kwargs or {}
+ errors = ConfluenceLoader.validate_init_args(
+ url=url,
+ api_key=api_key,
+ username=username,
+ session=session,
+ oauth2=oauth2,
+ token=token,
+ )
+ if errors:
+ raise ValueError(f"Error(s) while validating input: {errors}")
+ try:
+ from atlassian import Confluence # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`atlassian` package not found, please run "
+ "`pip install atlassian-python-api`"
+ )
+
+ self.base_url = url
+ self.number_of_retries = number_of_retries
+ self.min_retry_seconds = min_retry_seconds
+ self.max_retry_seconds = max_retry_seconds
+
+ if session:
+ self.confluence = Confluence(url=url, session=session, **confluence_kwargs)
+ elif oauth2:
+ self.confluence = Confluence(
+ url=url, oauth2=oauth2, cloud=cloud, **confluence_kwargs
+ )
+ elif token:
+ self.confluence = Confluence(
+ url=url, token=token, cloud=cloud, **confluence_kwargs
+ )
+ else:
+ self.confluence = Confluence(
+ url=url,
+ username=username,
+ password=api_key,
+ cloud=cloud,
+ **confluence_kwargs,
+ )
+
+ @staticmethod
+ def validate_init_args(
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ username: Optional[str] = None,
+ session: Optional[requests.Session] = None,
+ oauth2: Optional[dict] = None,
+ token: Optional[str] = None,
+ ) -> Union[List, None]:
+ """Validates proper combinations of init arguments"""
+
+ errors = []
+ if url is None:
+ errors.append("Must provide `base_url`")
+
+ if (api_key and not username) or (username and not api_key):
+ errors.append(
+ "If one of `api_key` or `username` is provided, "
+ "the other must be as well."
+ )
+
+ non_null_creds = list(
+ x is not None for x in ((api_key or username), session, oauth2, token)
+ )
+ if sum(non_null_creds) > 1:
+ all_names = ("(api_key, username)", "session", "oath2", "token")
+ provided = tuple(n for x, n in zip(non_null_creds, all_names) if x)
+ errors.append(
+ f"Cannot provide a value for more than one of: {all_names}. Received "
+ f"values for: {provided}"
+ )
+ if oauth2 and set(oauth2.keys()) != {
+ "access_token",
+ "access_token_secret",
+ "consumer_key",
+ "key_cert",
+ }:
+ errors.append(
+ "You have either omitted require keys or added extra "
+ "keys to the oauth2 dictionary. key values should be "
+ "`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`"
+ )
+ return errors or None
+
+ def load(
+ self,
+ space_key: Optional[str] = None,
+ page_ids: Optional[List[str]] = None,
+ label: Optional[str] = None,
+ cql: Optional[str] = None,
+ include_restricted_content: bool = False,
+ include_archived_content: bool = False,
+ include_attachments: bool = False,
+ include_comments: bool = False,
+ content_format: ContentFormat = ContentFormat.STORAGE,
+ limit: Optional[int] = 50,
+ max_pages: Optional[int] = 1000,
+ ocr_languages: Optional[str] = None,
+ keep_markdown_format: bool = False,
+ keep_newlines: bool = False,
+ ) -> List[Document]:
+ """
+ :param space_key: Space key retrieved from a confluence URL, defaults to None
+ :type space_key: Optional[str], optional
+ :param page_ids: List of specific page IDs to load, defaults to None
+ :type page_ids: Optional[List[str]], optional
+ :param label: Get all pages with this label, defaults to None
+ :type label: Optional[str], optional
+ :param cql: CQL Expression, defaults to None
+ :type cql: Optional[str], optional
+ :param include_restricted_content: defaults to False
+ :type include_restricted_content: bool, optional
+ :param include_archived_content: Whether to include archived content,
+ defaults to False
+ :type include_archived_content: bool, optional
+ :param include_attachments: defaults to False
+ :type include_attachments: bool, optional
+ :param include_comments: defaults to False
+ :type include_comments: bool, optional
+ :param content_format: Specify content format, defaults to
+ ContentFormat.STORAGE, the supported values are:
+ `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
+ `ContentFormat.ANONYMOUS_EXPORT_VIEW`,
+ `ContentFormat.STORAGE`, and `ContentFormat.VIEW`.
+ :type content_format: ContentFormat
+ :param limit: Maximum number of pages to retrieve per request, defaults to 50
+ :type limit: int, optional
+ :param max_pages: Maximum number of pages to retrieve in total, defaults 1000
+ :type max_pages: int, optional
+ :param ocr_languages: The languages to use for the Tesseract agent. To use a
+ language, you'll first need to install the appropriate
+ Tesseract language pack.
+ :type ocr_languages: str, optional
+ :param keep_markdown_format: Whether to keep the markdown format, defaults to
+ False
+ :type keep_markdown_format: bool
+ :param keep_newlines: Whether to keep the newlines format, defaults to
+ False
+ :type keep_newlines: bool
+ :raises ValueError: If none of `space_key`, `page_ids`, `label`, or `cql`
+ is provided
+ :raises ImportError: If a required optional dependency is not installed
+ :return: A list of loaded Documents
+ :rtype: List[Document]
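+
+ Example (a minimal sketch; assumes a `loader` constructed as in the class
+ example, and the space key and CQL expression below are placeholders):
+
+ .. code-block:: python
+
+ # Load by space key.
+ documents = loader.load(space_key="SPACE", limit=50)
+
+ # Or load by a CQL expression (placeholder query).
+ documents = loader.load(cql="type=page and space=SPACE", max_pages=100)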
+ """
+ if not space_key and not page_ids and not label and not cql:
+ raise ValueError(
+ "Must specify at least one among `space_key`, `page_ids`, "
+ "`label`, `cql` parameters."
+ )
+
+ docs = []
+
+ if space_key:
+ pages = self.paginate_request(
+ self.confluence.get_all_pages_from_space,
+ space=space_key,
+ limit=limit,
+ max_pages=max_pages,
+ status="any" if include_archived_content else "current",
+ expand=content_format.value,
+ )
+ docs += self.process_pages(
+ pages,
+ include_restricted_content,
+ include_attachments,
+ include_comments,
+ content_format,
+ ocr_languages=ocr_languages,
+ keep_markdown_format=keep_markdown_format,
+ keep_newlines=keep_newlines,
+ )
+
+ if label:
+ pages = self.paginate_request(
+ self.confluence.get_all_pages_by_label,
+ label=label,
+ limit=limit,
+ max_pages=max_pages,
+ )
+ ids_by_label = [page["id"] for page in pages]
+ if page_ids:
+ page_ids = list(set(page_ids + ids_by_label))
+ else:
+ page_ids = list(set(ids_by_label))
+
+ if cql:
+ pages = self.paginate_request(
+ self._search_content_by_cql,
+ cql=cql,
+ limit=limit,
+ max_pages=max_pages,
+ include_archived_spaces=include_archived_content,
+ expand=content_format.value,
+ )
+ docs += self.process_pages(
+ pages,
+ include_restricted_content,
+ include_attachments,
+ include_comments,
+ content_format,
+ ocr_languages,
+ keep_markdown_format,
+ )
+
+ if page_ids:
+ for page_id in page_ids:
+ get_page = retry(
+ reraise=True,
+ stop=stop_after_attempt(
+ self.number_of_retries # type: ignore[arg-type]
+ ),
+ wait=wait_exponential(
+ multiplier=1, # type: ignore[arg-type]
+ min=self.min_retry_seconds, # type: ignore[arg-type]
+ max=self.max_retry_seconds, # type: ignore[arg-type]
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )(self.confluence.get_page_by_id)
+ page = get_page(
+ page_id=page_id, expand=f"{content_format.value},version"
+ )
+ if not include_restricted_content and not self.is_public_page(page):
+ continue
+ doc = self.process_page(
+ page,
+ include_attachments,
+ include_comments,
+ content_format,
+ ocr_languages,
+ keep_markdown_format,
+ )
+ docs.append(doc)
+
+ return docs
+
+ def _search_content_by_cql(
+ self, cql: str, include_archived_spaces: Optional[bool] = None, **kwargs: Any
+ ) -> List[dict]:
+ url = "rest/api/content/search"
+
+ params: Dict[str, Any] = {"cql": cql}
+ params.update(kwargs)
+ if include_archived_spaces is not None:
+ params["includeArchivedSpaces"] = include_archived_spaces
+
+ response = self.confluence.get(url, params=params)
+ return response.get("results", [])
+
+ def paginate_request(self, retrieval_method: Callable, **kwargs: Any) -> List:
+ """Paginate the various methods to retrieve groups of pages.
+
+ Unfortunately, due to page size, sometimes the Confluence API
+ doesn't match the limit value. If `limit` is >100 confluence
+ seems to cap the response to 100. Also, due to the Atlassian Python
+ package, we don't get the "next" values from the "_links" key because
+ they only return the value from the result key. So here, the pagination
+ starts from 0 and goes until the max_pages, getting the `limit` number
+ of pages with each request. We have to manually check if there
+ are more docs based on the length of the returned list of pages, rather than
+ just checking for the presence of a `next` key in the response like this page
+ would have you do:
+ https://developer.atlassian.com/server/confluence/pagination-in-the-rest-api/
+
+ :param retrieval_method: Function used to retrieve docs
+ :type retrieval_method: callable
+ :return: List of documents
+ :rtype: List
+ """
+
+ max_pages = kwargs.pop("max_pages")
+ docs: List[dict] = []
+ while len(docs) < max_pages:
+ get_pages = retry(
+ reraise=True,
+ stop=stop_after_attempt(
+ self.number_of_retries # type: ignore[arg-type]
+ ),
+ wait=wait_exponential(
+ multiplier=1,
+ min=self.min_retry_seconds, # type: ignore[arg-type]
+ max=self.max_retry_seconds, # type: ignore[arg-type]
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )(retrieval_method)
+ batch = get_pages(**kwargs, start=len(docs))
+ if not batch:
+ break
+ docs.extend(batch)
+ return docs[:max_pages]
+
+ def is_public_page(self, page: dict) -> bool:
+ """Check if a page is publicly accessible."""
+ restrictions = self.confluence.get_all_restrictions_for_content(page["id"])
+
+ return (
+ page["status"] == "current"
+ and not restrictions["read"]["restrictions"]["user"]["results"]
+ and not restrictions["read"]["restrictions"]["group"]["results"]
+ )
+
+ def process_pages(
+ self,
+ pages: List[dict],
+ include_restricted_content: bool,
+ include_attachments: bool,
+ include_comments: bool,
+ content_format: ContentFormat,
+ ocr_languages: Optional[str] = None,
+ keep_markdown_format: Optional[bool] = False,
+ keep_newlines: bool = False,
+ ) -> List[Document]:
+ """Process a list of pages into a list of documents."""
+ docs = []
+ for page in pages:
+ if not include_restricted_content and not self.is_public_page(page):
+ continue
+ doc = self.process_page(
+ page,
+ include_attachments,
+ include_comments,
+ content_format,
+ ocr_languages=ocr_languages,
+ keep_markdown_format=keep_markdown_format,
+ keep_newlines=keep_newlines,
+ )
+ docs.append(doc)
+
+ return docs
+
+ def process_page(
+ self,
+ page: dict,
+ include_attachments: bool,
+ include_comments: bool,
+ content_format: ContentFormat,
+ ocr_languages: Optional[str] = None,
+ keep_markdown_format: Optional[bool] = False,
+ keep_newlines: bool = False,
+ ) -> Document:
+ if keep_markdown_format:
+ try:
+ from markdownify import markdownify
+ except ImportError:
+ raise ImportError(
+ "`markdownify` package not found, please run "
+ "`pip install markdownify`"
+ )
+ if include_comments or not keep_markdown_format:
+ try:
+ from bs4 import BeautifulSoup # type: ignore
+ except ImportError:
+ raise ImportError(
+ "`beautifulsoup4` package not found, please run "
+ "`pip install beautifulsoup4`"
+ )
+ if include_attachments:
+ attachment_texts = self.process_attachment(page["id"], ocr_languages)
+ else:
+ attachment_texts = []
+
+ content = content_format.get_content(page)
+ if keep_markdown_format:
+ # Use markdownify to keep the page Markdown style
+ text = markdownify(content, heading_style="ATX") + "".join(attachment_texts)
+
+ else:
+ if keep_newlines:
+ text = BeautifulSoup(
+ content.replace("
", "\n").replace(" ", "\n"), "lxml"
+ ).get_text(" ") + "".join(attachment_texts)
+ else:
+ text = BeautifulSoup(content, "lxml").get_text(
+ " ", strip=True
+ ) + "".join(attachment_texts)
+
+ if include_comments:
+ comments = self.confluence.get_page_comments(
+ page["id"], expand="body.view.value", depth="all"
+ )["results"]
+ comment_texts = [
+ BeautifulSoup(comment["body"]["view"]["value"], "lxml").get_text(
+ " ", strip=True
+ )
+ for comment in comments
+ ]
+ text = text + "".join(comment_texts)
+
+ metadata = {
+ "title": page["title"],
+ "id": page["id"],
+ "source": self.base_url.strip("/") + page["_links"]["webui"],
+ }
+
+ if "version" in page and "when" in page["version"]:
+ metadata["when"] = page["version"]["when"]
+
+ return Document(
+ page_content=text,
+ metadata=metadata,
+ )
+
+ def process_attachment(
+ self,
+ page_id: str,
+ ocr_languages: Optional[str] = None,
+ ) -> List[str]:
+ try:
+ from PIL import Image # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`Pillow` package not found, " "please run `pip install Pillow`"
+ )
+
+ # depending on setup you may also need to set the correct path for
+ # poppler and tesseract
+ attachments = self.confluence.get_attachments_from_content(page_id)["results"]
+ texts = []
+ for attachment in attachments:
+ media_type = attachment["metadata"]["mediaType"]
+ absolute_url = self.base_url + attachment["_links"]["download"]
+ title = attachment["title"]
+ try:
+ if media_type == "application/pdf":
+ text = title + self.process_pdf(absolute_url, ocr_languages)
+ elif (
+ media_type == "image/png"
+ or media_type == "image/jpg"
+ or media_type == "image/jpeg"
+ ):
+ text = title + self.process_image(absolute_url, ocr_languages)
+ elif (
+ media_type == "application/vnd.openxmlformats-officedocument"
+ ".wordprocessingml.document"
+ ):
+ text = title + self.process_doc(absolute_url)
+ elif media_type == "application/vnd.ms-excel":
+ text = title + self.process_xls(absolute_url)
+ elif media_type == "image/svg+xml":
+ text = title + self.process_svg(absolute_url, ocr_languages)
+ else:
+ continue
+ texts.append(text)
+ except requests.HTTPError as e:
+ if e.response.status_code == 404:
+ print(f"Attachment not found at {absolute_url}")
+ continue
+ else:
+ raise
+
+ return texts
+
+ def process_pdf(
+ self,
+ link: str,
+ ocr_languages: Optional[str] = None,
+ ) -> str:
+ try:
+ import pytesseract # noqa: F401
+ from pdf2image import convert_from_bytes # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`pytesseract` or `pdf2image` package not found, "
+ "please run `pip install pytesseract pdf2image`"
+ )
+
+ response = self.confluence.request(path=link, absolute=True)
+ text = ""
+
+ if (
+ response.status_code != 200
+ or response.content == b""
+ or response.content is None
+ ):
+ return text
+ try:
+ images = convert_from_bytes(response.content)
+ except ValueError:
+ return text
+
+ for i, image in enumerate(images):
+ image_text = pytesseract.image_to_string(image, lang=ocr_languages)
+ text += f"Page {i + 1}:\n{image_text}\n\n"
+
+ return text
+
+ def process_image(
+ self,
+ link: str,
+ ocr_languages: Optional[str] = None,
+ ) -> str:
+ try:
+ import pytesseract # noqa: F401
+ from PIL import Image # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`pytesseract` or `Pillow` package not found, "
+ "please run `pip install pytesseract Pillow`"
+ )
+
+ response = self.confluence.request(path=link, absolute=True)
+ text = ""
+
+ if (
+ response.status_code != 200
+ or response.content == b""
+ or response.content is None
+ ):
+ return text
+ try:
+ image = Image.open(BytesIO(response.content))
+ except OSError:
+ return text
+
+ return pytesseract.image_to_string(image, lang=ocr_languages)
+
+ def process_doc(self, link: str) -> str:
+ try:
+ import docx2txt # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`docx2txt` package not found, please run `pip install docx2txt`"
+ )
+
+ response = self.confluence.request(path=link, absolute=True)
+ text = ""
+
+ if (
+ response.status_code != 200
+ or response.content == b""
+ or response.content is None
+ ):
+ return text
+ file_data = BytesIO(response.content)
+
+ return docx2txt.process(file_data)
+
+ def process_xls(self, link: str) -> str:
+ import io
+ import os
+
+ try:
+ import xlrd # noqa: F401
+
+ except ImportError:
+ raise ImportError("`xlrd` package not found, please run `pip install xlrd`")
+
+ try:
+ import pandas as pd
+
+ except ImportError:
+ raise ImportError(
+ "`pandas` package not found, please run `pip install pandas`"
+ )
+
+ response = self.confluence.request(path=link, absolute=True)
+ text = ""
+
+ if (
+ response.status_code != 200
+ or response.content == b""
+ or response.content is None
+ ):
+ return text
+
+ filename = os.path.basename(link)
+ # Getting the whole content of the url after filename,
+ # Example: ".csv?version=2&modificationDate=1631800010678&cacheVersion=1&api=v2"
+ file_extension = os.path.splitext(filename)[1]
+
+ if file_extension.startswith(
+ ".csv"
+ ): # if the extension found in the url is ".csv"
+ content_string = response.content.decode("utf-8")
+ df = pd.read_csv(io.StringIO(content_string))
+ text += df.to_string(index=False, header=False) + "\n\n"
+ else:
+ workbook = xlrd.open_workbook(file_contents=response.content)
+ for sheet in workbook.sheets():
+ text += f"{sheet.name}:\n"
+ for row in range(sheet.nrows):
+ for col in range(sheet.ncols):
+ text += f"{sheet.cell_value(row, col)}\t"
+ text += "\n"
+ text += "\n"
+
+ return text
+
+ def process_svg(
+ self,
+ link: str,
+ ocr_languages: Optional[str] = None,
+ ) -> str:
+ try:
+ import pytesseract # noqa: F401
+ from PIL import Image # noqa: F401
+ from reportlab.graphics import renderPM # noqa: F401
+ from svglib.svglib import svg2rlg # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`pytesseract`, `Pillow`, `reportlab` or `svglib` package not found, "
+ "please run `pip install pytesseract Pillow reportlab svglib`"
+ )
+
+ response = self.confluence.request(path=link, absolute=True)
+ text = ""
+
+ if (
+ response.status_code != 200
+ or response.content == b""
+ or response.content is None
+ ):
+ return text
+
+ drawing = svg2rlg(BytesIO(response.content))
+
+ img_data = BytesIO()
+ renderPM.drawToFile(drawing, img_data, fmt="PNG")
+ img_data.seek(0)
+ image = Image.open(img_data)
+
+ return pytesseract.image_to_string(image, lang=ocr_languages)
diff --git a/libs/community/langchain_community/document_loaders/conllu.py b/libs/community/langchain_community/document_loaders/conllu.py
new file mode 100644
index 00000000000..989eec61a5e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/conllu.py
@@ -0,0 +1,33 @@
+import csv
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class CoNLLULoader(BaseLoader):
+ """Load `CoNLL-U` files."""
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path."""
+ self.file_path = file_path
+
+ def load(self) -> List[Document]:
+ """Load from a file path."""
+ with open(self.file_path, encoding="utf8") as f:
+ tsv = list(csv.reader(f, delimiter="\t"))
+
+ # If len(line) > 1, the line is not a comment
+ lines = [line for line in tsv if len(line) > 1]
+
+ text = ""
+ for i, line in enumerate(lines):
+ # Do not add a space after a punctuation mark or at the end of the sentence
+ if line[9] == "SpaceAfter=No" or i == len(lines) - 1:
+ text += line[1]
+ else:
+ text += line[1] + " "
+
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/couchbase.py b/libs/community/langchain_community/document_loaders/couchbase.py
new file mode 100644
index 00000000000..fabc0a73987
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/couchbase.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class CouchbaseLoader(BaseLoader):
+ """Load documents from `Couchbase`.
+
+ Each document represents one row of the result. The `page_content_fields` are
+    written into the `page_content` of the document. The `metadata_fields` are written
+ into the `metadata` of the document. By default, all columns are written into
+ the `page_content` and none into the `metadata`.
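+
+    Example (usage sketch; the connection string, credentials, and query below are
+    placeholders):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.couchbase import CouchbaseLoader
+
+            loader = CouchbaseLoader(
+                connection_string="couchbase://localhost",  # placeholder
+                db_username="<username>",
+                db_password="<password>",
+                query="SELECT name FROM `travel-sample`.inventory.hotel LIMIT 10",
+            )
+            docs = loader.load()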
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ db_username: str,
+ db_password: str,
+ query: str,
+ *,
+ page_content_fields: Optional[List[str]] = None,
+ metadata_fields: Optional[List[str]] = None,
+ ) -> None:
+ """Initialize Couchbase document loader.
+
+ Args:
+ connection_string (str): The connection string to the Couchbase cluster.
+ db_username (str): The username to connect to the Couchbase cluster.
+ db_password (str): The password to connect to the Couchbase cluster.
+ query (str): The SQL++ query to execute.
+ page_content_fields (Optional[List[str]]): The columns to write into the
+ `page_content` field of the document. By default, all columns are
+ written.
+ metadata_fields (Optional[List[str]]): The columns to write into the
+ `metadata` field of the document. By default, no columns are written.
+ """
+ try:
+ from couchbase.auth import PasswordAuthenticator
+ from couchbase.cluster import Cluster
+ from couchbase.options import ClusterOptions
+ except ImportError as e:
+ raise ImportError(
+                "Could not import couchbase package. "
+                "Please install couchbase SDK with `pip install couchbase`."
+ ) from e
+ if not connection_string:
+ raise ValueError("connection_string must be provided.")
+
+ if not db_username:
+ raise ValueError("db_username must be provided.")
+
+ if not db_password:
+ raise ValueError("db_password must be provided.")
+
+ auth = PasswordAuthenticator(
+ db_username,
+ db_password,
+ )
+
+ self.cluster: Cluster = Cluster(connection_string, ClusterOptions(auth))
+ self.query = query
+ self.page_content_fields = page_content_fields
+ self.metadata_fields = metadata_fields
+
+ def load(self) -> List[Document]:
+ """Load Couchbase data into Document objects."""
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load Couchbase data into Document objects lazily."""
+ from datetime import timedelta
+
+ # Ensure connection to Couchbase cluster
+ self.cluster.wait_until_ready(timedelta(seconds=5))
+
+ # Run SQL++ Query
+ result = self.cluster.query(self.query)
+ for row in result:
+ metadata_fields = self.metadata_fields
+ page_content_fields = self.page_content_fields
+
+ if not page_content_fields:
+ page_content_fields = list(row.keys())
+
+ if not metadata_fields:
+ metadata_fields = []
+
+ metadata = {field: row[field] for field in metadata_fields}
+
+ document = "\n".join(
+ f"{k}: {v}" for k, v in row.items() if k in page_content_fields
+ )
+
+            yield Document(page_content=document, metadata=metadata)
diff --git a/libs/community/langchain_community/document_loaders/csv_loader.py b/libs/community/langchain_community/document_loaders/csv_loader.py
new file mode 100644
index 00000000000..92198ac5bac
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/csv_loader.py
@@ -0,0 +1,158 @@
+import csv
+from io import TextIOWrapper
+from typing import Any, Dict, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.helpers import detect_file_encodings
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class CSVLoader(BaseLoader):
+ """Load a `CSV` file into a list of Documents.
+
+    Each document represents one row of the CSV file. Each column is converted into a
+    key/value pair that is written on its own line in the document's page_content.
+
+ The source for each document loaded from csv is set to the value of the
+ `file_path` argument for all documents by default.
+ You can override this by setting the `source_column` argument to the
+ name of a column in the CSV file.
+ The source of each document will then be set to the value of the column
+ with the name specified in `source_column`.
+
+ Output Example:
+ .. code-block:: txt
+
+ column1: value1
+ column2: value2
+ column3: value3
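+
+    Example (usage sketch; the file path and delimiter are placeholders):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.csv_loader import CSVLoader
+
+            loader = CSVLoader("data.csv", csv_args={"delimiter": ","})  # placeholder
+            docs = loader.load()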
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ source_column: Optional[str] = None,
+ metadata_columns: Sequence[str] = (),
+ csv_args: Optional[Dict] = None,
+ encoding: Optional[str] = None,
+ autodetect_encoding: bool = False,
+ ):
+ """
+
+ Args:
+ file_path: The path to the CSV file.
+ source_column: The name of the column in the CSV file to use as the source.
+ Optional. Defaults to None.
+ metadata_columns: A sequence of column names to use as metadata. Optional.
+ csv_args: A dictionary of arguments to pass to the csv.DictReader.
+ Optional. Defaults to None.
+ encoding: The encoding of the CSV file. Optional. Defaults to None.
+ autodetect_encoding: Whether to try to autodetect the file encoding.
+ """
+ self.file_path = file_path
+ self.source_column = source_column
+ self.metadata_columns = metadata_columns
+ self.encoding = encoding
+ self.csv_args = csv_args or {}
+ self.autodetect_encoding = autodetect_encoding
+
+ def load(self) -> List[Document]:
+ """Load data into document objects."""
+
+ docs = []
+ try:
+ with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+ docs = self.__read_file(csvfile)
+ except UnicodeDecodeError as e:
+ if self.autodetect_encoding:
+ detected_encodings = detect_file_encodings(self.file_path)
+ for encoding in detected_encodings:
+ try:
+ with open(
+ self.file_path, newline="", encoding=encoding.encoding
+ ) as csvfile:
+ docs = self.__read_file(csvfile)
+ break
+ except UnicodeDecodeError:
+ continue
+ else:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+ except Exception as e:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+
+ return docs
+
+ def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
+ docs = []
+
+ csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
+ for i, row in enumerate(csv_reader):
+ try:
+ source = (
+ row[self.source_column]
+ if self.source_column is not None
+ else self.file_path
+ )
+ except KeyError:
+ raise ValueError(
+ f"Source column '{self.source_column}' not found in CSV file."
+ )
+ content = "\n".join(
+ f"{k.strip()}: {v.strip() if v is not None else v}"
+ for k, v in row.items()
+ if k not in self.metadata_columns
+ )
+ metadata = {"source": source, "row": i}
+ for col in self.metadata_columns:
+ try:
+ metadata[col] = row[col]
+ except KeyError:
+ raise ValueError(f"Metadata column '{col}' not found in CSV file.")
+ doc = Document(page_content=content, metadata=metadata)
+ docs.append(doc)
+
+ return docs
+
+
+class UnstructuredCSVLoader(UnstructuredFileLoader):
+ """Load `CSV` files using `Unstructured`.
+
+ Like other
+ Unstructured loaders, UnstructuredCSVLoader can be used in both
+    "single" and "elements" mode. If you use the loader in "elements"
+    mode, the CSV file will be a single Unstructured Table element,
+    and an HTML representation of the table will be available in the
+    "text_as_html" key in the document metadata.
+
+ Examples
+ --------
+ from langchain_community.document_loaders.csv_loader import UnstructuredCSVLoader
+
+ loader = UnstructuredCSVLoader("stanley-cups.csv", mode="elements")
+ docs = loader.load()
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+
+ Args:
+ file_path: The path to the CSV file.
+ mode: The mode to use when loading the CSV file.
+ Optional. Defaults to "single".
+ **unstructured_kwargs: Keyword arguments to pass to unstructured.
+ """
+ validate_unstructured_version(min_unstructured_version="0.6.8")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.csv import partition_csv
+
+ return partition_csv(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/cube_semantic.py b/libs/community/langchain_community/document_loaders/cube_semantic.py
new file mode 100644
index 00000000000..26ebc5e2637
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/cube_semantic.py
@@ -0,0 +1,178 @@
+import json
+import logging
+import time
+from typing import List
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class CubeSemanticLoader(BaseLoader):
+ """Load `Cube semantic layer` metadata.
+
+ Args:
+ cube_api_url: REST API endpoint.
+ Use the REST API of your Cube's deployment.
+ Please find out more information here:
+ https://cube.dev/docs/http-api/rest#configuration-base-path
+ cube_api_token: Cube API token.
+ Authentication tokens are generated based on your Cube's API secret.
+ Please find out more information here:
+ https://cube.dev/docs/security#generating-json-web-tokens-jwt
+ load_dimension_values: Whether to load dimension values for every string
+ dimension or not.
+ dimension_values_limit: Maximum number of dimension values to load.
+ dimension_values_max_retries: Maximum number of retries to load dimension
+ values.
+ dimension_values_retry_delay: Delay between retries to load dimension values.
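+
+    Example (usage sketch; the deployment URL and token are placeholders):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.cube_semantic import (
+                CubeSemanticLoader,
+            )
+
+            # Placeholder URL and token; point these at your own deployment.
+            loader = CubeSemanticLoader(
+                cube_api_url="https://example.cubecloud.dev/cubejs-api/v1",
+                cube_api_token="<jwt-token>",
+            )
+            docs = loader.load()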
+ """
+
+ def __init__(
+ self,
+ cube_api_url: str,
+ cube_api_token: str,
+ load_dimension_values: bool = True,
+ dimension_values_limit: int = 10_000,
+ dimension_values_max_retries: int = 10,
+ dimension_values_retry_delay: int = 3,
+ ):
+ self.cube_api_url = cube_api_url
+ self.cube_api_token = cube_api_token
+ self.load_dimension_values = load_dimension_values
+ self.dimension_values_limit = dimension_values_limit
+ self.dimension_values_max_retries = dimension_values_max_retries
+ self.dimension_values_retry_delay = dimension_values_retry_delay
+
+ def _get_dimension_values(self, dimension_name: str) -> List[str]:
+ """Makes a call to Cube's REST API load endpoint to retrieve
+ values for dimensions.
+
+ These values can be used to achieve a more accurate filtering.
+ """
+        logger.info(f"Loading dimension values for: {dimension_name}...")
+
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": self.cube_api_token,
+ }
+
+ query = {
+ "query": {
+ "dimensions": [dimension_name],
+ "limit": self.dimension_values_limit,
+ }
+ }
+
+ retries = 0
+ while retries < self.dimension_values_max_retries:
+ response = requests.request(
+ "POST",
+ f"{self.cube_api_url}/load",
+ headers=headers,
+ data=json.dumps(query),
+ )
+
+ if response.status_code == 200:
+ response_data = response.json()
+ if (
+ "error" in response_data
+ and response_data["error"] == "Continue wait"
+ ):
+ logger.info("Retrying...")
+ retries += 1
+ time.sleep(self.dimension_values_retry_delay)
+ continue
+ else:
+ dimension_values = [
+ item[dimension_name] for item in response_data["data"]
+ ]
+ return dimension_values
+ else:
+                logger.error(
+                    f"Request failed with status code: {response.status_code}"
+                )
+ break
+
+ if retries == self.dimension_values_max_retries:
+ logger.info("Maximum retries reached.")
+ return []
+
+ def load(self) -> List[Document]:
+ """Makes a call to Cube's REST API metadata endpoint.
+
+ Returns:
+ A list of documents with attributes:
+ - page_content=column_title + column_description
+ - metadata
+ - table_name
+ - column_name
+ - column_data_type
+ - column_member_type
+ - column_title
+ - column_description
+ - column_values
+ - cube_data_obj_type
+ """
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": self.cube_api_token,
+ }
+
+ logger.info(f"Loading metadata from {self.cube_api_url}...")
+ response = requests.get(f"{self.cube_api_url}/meta", headers=headers)
+ response.raise_for_status()
+ raw_meta_json = response.json()
+ cube_data_objects = raw_meta_json.get("cubes", [])
+
+ logger.info(f"Found {len(cube_data_objects)} cube data objects in metadata.")
+
+ if not cube_data_objects:
+ raise ValueError("No cubes found in metadata.")
+
+ docs = []
+
+ for cube_data_obj in cube_data_objects:
+ cube_data_obj_name = cube_data_obj.get("name")
+ cube_data_obj_type = cube_data_obj.get("type")
+ cube_data_obj_is_public = cube_data_obj.get("public")
+ measures = cube_data_obj.get("measures", [])
+ dimensions = cube_data_obj.get("dimensions", [])
+
+ logger.info(f"Processing {cube_data_obj_name}...")
+
+ if not cube_data_obj_is_public:
+ logger.info(f"Skipping {cube_data_obj_name} because it is not public.")
+ continue
+
+ for item in measures + dimensions:
+ column_member_type = "measure" if item in measures else "dimension"
+ dimension_values = []
+ item_name = str(item.get("name"))
+ item_type = str(item.get("type"))
+
+ if (
+ self.load_dimension_values
+ and column_member_type == "dimension"
+ and item_type == "string"
+ ):
+ dimension_values = self._get_dimension_values(item_name)
+
+ metadata = dict(
+ table_name=str(cube_data_obj_name),
+ column_name=item_name,
+ column_data_type=item_type,
+ column_title=str(item.get("title")),
+ column_description=str(item.get("description")),
+ column_member_type=column_member_type,
+ column_values=dimension_values,
+ cube_data_obj_type=cube_data_obj_type,
+ )
+
+ page_content = f"{str(item.get('title'))}, "
+ page_content += f"{str(item.get('description'))}"
+
+ docs.append(Document(page_content=page_content, metadata=metadata))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/datadog_logs.py b/libs/community/langchain_community/document_loaders/datadog_logs.py
new file mode 100644
index 00000000000..38e774c6c3d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/datadog_logs.py
@@ -0,0 +1,137 @@
+from datetime import datetime, timedelta
+from typing import List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class DatadogLogsLoader(BaseLoader):
+ """Load `Datadog` logs.
+
+ Logs are written into the `page_content` and into the `metadata`.
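+
+    Example (usage sketch; the query and keys are placeholders):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.datadog_logs import (
+                DatadogLogsLoader,
+            )
+
+            # Placeholder query and keys.
+            loader = DatadogLogsLoader(
+                query="service:agent status:error",
+                api_key="<datadog-api-key>",
+                app_key="<datadog-app-key>",
+            )
+            docs = loader.load()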
+ """
+
+ def __init__(
+ self,
+ query: str,
+ api_key: str,
+ app_key: str,
+ from_time: Optional[int] = None,
+ to_time: Optional[int] = None,
+ limit: int = 100,
+ ) -> None:
+ """Initialize Datadog document loader.
+
+ Requirements:
+ - Must have datadog_api_client installed. Install with `pip install datadog_api_client`.
+
+ Args:
+ query: The query to run in Datadog.
+ api_key: The Datadog API key.
+ app_key: The Datadog APP key.
+ from_time: Optional. The start of the time range to query.
+ Supports date math and regular timestamps (milliseconds) like '1688732708951'
+ Defaults to 20 minutes ago.
+ to_time: Optional. The end of the time range to query.
+ Supports date math and regular timestamps (milliseconds) like '1688732708951'
+ Defaults to now.
+ limit: The maximum number of logs to return.
+ Defaults to 100.
+ """ # noqa: E501
+ try:
+ from datadog_api_client import Configuration
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import datadog_api_client python package. "
+ "Please install it with `pip install datadog_api_client`."
+ ) from ex
+
+ self.query = query
+ configuration = Configuration()
+ configuration.api_key["apiKeyAuth"] = api_key
+ configuration.api_key["appKeyAuth"] = app_key
+ self.configuration = configuration
+ self.from_time = from_time
+ self.to_time = to_time
+ self.limit = limit
+
+ def parse_log(self, log: dict) -> Document:
+ """
+ Create Document objects from Datadog log items.
+ """
+ attributes = log.get("attributes", {})
+ metadata = {
+ "id": log.get("id", ""),
+ "status": attributes.get("status"),
+ "service": attributes.get("service", ""),
+ "tags": attributes.get("tags", []),
+ "timestamp": attributes.get("timestamp", ""),
+ }
+
+ message = attributes.get("message", "")
+ inside_attributes = attributes.get("attributes", {})
+ content_dict = {**inside_attributes, "message": message}
+ content = ", ".join(f"{k}: {v}" for k, v in content_dict.items())
+ return Document(page_content=content, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """
+ Get logs from Datadog.
+
+ Returns:
+ A list of Document objects.
+ - page_content
+ - metadata
+ - id
+ - service
+ - status
+ - tags
+ - timestamp
+ """
+ try:
+ from datadog_api_client import ApiClient
+ from datadog_api_client.v2.api.logs_api import LogsApi
+ from datadog_api_client.v2.model.logs_list_request import LogsListRequest
+ from datadog_api_client.v2.model.logs_list_request_page import (
+ LogsListRequestPage,
+ )
+ from datadog_api_client.v2.model.logs_query_filter import LogsQueryFilter
+ from datadog_api_client.v2.model.logs_sort import LogsSort
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import datadog_api_client python package. "
+ "Please install it with `pip install datadog_api_client`."
+ ) from ex
+
+ now = datetime.now()
+ twenty_minutes_before = now - timedelta(minutes=20)
+ now_timestamp = int(now.timestamp() * 1000)
+ twenty_minutes_before_timestamp = int(twenty_minutes_before.timestamp() * 1000)
+ _from = (
+ self.from_time
+ if self.from_time is not None
+ else twenty_minutes_before_timestamp
+ )
+
+ body = LogsListRequest(
+ filter=LogsQueryFilter(
+ query=self.query,
+ _from=_from,
+ to=f"{self.to_time if self.to_time is not None else now_timestamp}",
+ ),
+ sort=LogsSort.TIMESTAMP_ASCENDING,
+ page=LogsListRequestPage(
+ limit=self.limit,
+ ),
+ )
+
+ with ApiClient(configuration=self.configuration) as api_client:
+ api_instance = LogsApi(api_client)
+ response = api_instance.list_logs(body=body).to_dict()
+
+ docs: List[Document] = []
+ for row in response["data"]:
+ docs.append(self.parse_log(row))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/dataframe.py b/libs/community/langchain_community/document_loaders/dataframe.py
new file mode 100644
index 00000000000..848f4d9d075
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/dataframe.py
@@ -0,0 +1,56 @@
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BaseDataFrameLoader(BaseLoader):
+ def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Args:
+ data_frame: DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
+ """
+ self.data_frame = data_frame
+ self.page_content_column = page_content_column
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records from dataframe."""
+
+ for _, row in self.data_frame.iterrows():
+ text = row[self.page_content_column]
+ metadata = row.to_dict()
+ metadata.pop(self.page_content_column)
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load full dataframe."""
+ return list(self.lazy_load())
+
+
+class DataFrameLoader(BaseDataFrameLoader):
+ """Load `Pandas` DataFrame."""
+
+ def __init__(self, data_frame: Any, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Args:
+ data_frame: Pandas DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
+ """
+ try:
+ import pandas as pd
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import pandas, please install with `pip install pandas`."
+ ) from e
+
+ if not isinstance(data_frame, pd.DataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
+ )
+ super().__init__(data_frame, page_content_column=page_content_column)
diff --git a/libs/community/langchain_community/document_loaders/diffbot.py b/libs/community/langchain_community/document_loaders/diffbot.py
new file mode 100644
index 00000000000..5014ecf780c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/diffbot.py
@@ -0,0 +1,61 @@
+import logging
+from typing import Any, List
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class DiffbotLoader(BaseLoader):
+    """Load documents from a list of URLs using the `Diffbot` API."""
+
+ def __init__(
+ self, api_token: str, urls: List[str], continue_on_failure: bool = True
+ ):
+        """Initialize with API token and a list of URLs to load.
+
+ Args:
+ api_token: Diffbot API token.
+ urls: List of URLs to load.
+ continue_on_failure: Whether to continue loading other URLs if one fails.
+ Defaults to True.
+ """
+ self.api_token = api_token
+ self.urls = urls
+ self.continue_on_failure = continue_on_failure
+
+ def _diffbot_api_url(self, diffbot_api: str) -> str:
+ return f"https://api.diffbot.com/v3/{diffbot_api}"
+
+ def _get_diffbot_data(self, url: str) -> Any:
+ """Get Diffbot file from Diffbot REST API."""
+ # TODO: Add support for other Diffbot APIs
+ diffbot_url = self._diffbot_api_url("article")
+ params = {
+ "token": self.api_token,
+ "url": url,
+ }
+ response = requests.get(diffbot_url, params=params, timeout=10)
+
+ # TODO: handle non-ok errors
+ return response.json() if response.ok else {}
+
+ def load(self) -> List[Document]:
+ """Extract text from Diffbot on all the URLs and return Documents"""
+        docs: List[Document] = []
+
+ for url in self.urls:
+ try:
+ data = self._get_diffbot_data(url)
+ text = data["objects"][0]["text"] if "objects" in data else ""
+ metadata = {"source": url}
+ docs.append(Document(page_content=text, metadata=metadata))
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error fetching or processing {url}, exception: {e}")
+ else:
+ raise e
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/directory.py b/libs/community/langchain_community/document_loaders/directory.py
new file mode 100644
index 00000000000..da51770d097
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/directory.py
@@ -0,0 +1,162 @@
+import concurrent
+import logging
+import random
+from pathlib import Path
+from typing import Any, List, Optional, Type, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.html_bs import BSHTMLLoader
+from langchain_community.document_loaders.text import TextLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+FILE_LOADER_TYPE = Union[
+ Type[UnstructuredFileLoader], Type[TextLoader], Type[BSHTMLLoader]
+]
+logger = logging.getLogger(__name__)
+
+
+def _is_visible(p: Path) -> bool:
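+    """Return True if no part of the path is hidden (starts with a ".")."""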
+ parts = p.parts
+ for _p in parts:
+ if _p.startswith("."):
+ return False
+ return True
+
+
+class DirectoryLoader(BaseLoader):
+ """Load from a directory."""
+
+ def __init__(
+ self,
+ path: str,
+ glob: str = "**/[!.]*",
+ silent_errors: bool = False,
+ load_hidden: bool = False,
+ loader_cls: FILE_LOADER_TYPE = UnstructuredFileLoader,
+ loader_kwargs: Union[dict, None] = None,
+ recursive: bool = False,
+ show_progress: bool = False,
+ use_multithreading: bool = False,
+ max_concurrency: int = 4,
+ *,
+ sample_size: int = 0,
+ randomize_sample: bool = False,
+ sample_seed: Union[int, None] = None,
+ ):
+ """Initialize with a path to directory and how to glob over it.
+
+ Args:
+ path: Path to directory.
+ glob: Glob pattern to use to find files. Defaults to "**/[!.]*"
+ (all files except hidden).
+ silent_errors: Whether to silently ignore errors. Defaults to False.
+ load_hidden: Whether to load hidden files. Defaults to False.
+ loader_cls: Loader class to use for loading files.
+ Defaults to UnstructuredFileLoader.
+ loader_kwargs: Keyword arguments to pass to loader_cls. Defaults to None.
+ recursive: Whether to recursively search for files. Defaults to False.
+ show_progress: Whether to show a progress bar. Defaults to False.
+ use_multithreading: Whether to use multithreading. Defaults to False.
+ max_concurrency: The maximum number of threads to use. Defaults to 4.
+ sample_size: The maximum number of files you would like to load from the
+ directory.
+            randomize_sample: Shuffle the files to get a random sample.
+            sample_seed: Set the seed of the random shuffle for reproducibility.
+ """
+ if loader_kwargs is None:
+ loader_kwargs = {}
+ self.path = path
+ self.glob = glob
+ self.load_hidden = load_hidden
+ self.loader_cls = loader_cls
+ self.loader_kwargs = loader_kwargs
+ self.silent_errors = silent_errors
+ self.recursive = recursive
+ self.show_progress = show_progress
+ self.use_multithreading = use_multithreading
+ self.max_concurrency = max_concurrency
+ self.sample_size = sample_size
+ self.randomize_sample = randomize_sample
+ self.sample_seed = sample_seed
+
+ def load_file(
+ self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
+ ) -> None:
+ """Load a file.
+
+ Args:
+ item: File path.
+ path: Directory path.
+ docs: List of documents to append to.
+ pbar: Progress bar. Defaults to None.
+
+ """
+ if item.is_file():
+ if _is_visible(item.relative_to(path)) or self.load_hidden:
+ try:
+ logger.debug(f"Processing file: {str(item)}")
+ sub_docs = self.loader_cls(str(item), **self.loader_kwargs).load()
+ docs.extend(sub_docs)
+ except Exception as e:
+ if self.silent_errors:
+ logger.warning(f"Error loading file {str(item)}: {e}")
+ else:
+ raise e
+ finally:
+ if pbar:
+ pbar.update(1)
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ p = Path(self.path)
+ if not p.exists():
+ raise FileNotFoundError(f"Directory not found: '{self.path}'")
+ if not p.is_dir():
+ raise ValueError(f"Expected directory, got file: '{self.path}'")
+
+ docs: List[Document] = []
+ items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
+
+ if self.sample_size > 0:
+ if self.randomize_sample:
+ randomizer = (
+ random.Random(self.sample_seed) if self.sample_seed else random
+ )
+ randomizer.shuffle(items) # type: ignore
+ items = items[: min(len(items), self.sample_size)]
+
+ pbar = None
+ if self.show_progress:
+ try:
+ from tqdm import tqdm
+
+ pbar = tqdm(total=len(items))
+ except ImportError as e:
+ logger.warning(
+ "To log the progress of DirectoryLoader you need to install tqdm, "
+ "`pip install tqdm`"
+ )
+ if self.silent_errors:
+ logger.warning(e)
+ else:
+ raise ImportError(
+ "To log the progress of DirectoryLoader "
+ "you need to install tqdm, "
+ "`pip install tqdm`"
+ )
+
+ if self.use_multithreading:
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=self.max_concurrency
+ ) as executor:
+ executor.map(lambda i: self.load_file(i, p, docs, pbar), items)
+ else:
+ for i in items:
+ self.load_file(i, p, docs, pbar)
+
+ if pbar:
+ pbar.close()
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/discord.py b/libs/community/langchain_community/document_loaders/discord.py
new file mode 100644
index 00000000000..0c5308e6ae0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/discord.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+
+class DiscordChatLoader(BaseLoader):
+ """Load `Discord` chat logs."""
+
+ def __init__(self, chat_log: pd.DataFrame, user_id_col: str = "ID"):
+ """Initialize with a Pandas DataFrame containing chat logs.
+
+ Args:
+ chat_log: Pandas DataFrame containing chat logs.
+ user_id_col: Name of the column containing the user ID. Defaults to "ID".
+ """
+ if not isinstance(chat_log, pd.DataFrame):
+ raise ValueError(
+ f"Expected chat_log to be a pd.DataFrame, got {type(chat_log)}"
+ )
+ self.chat_log = chat_log
+ self.user_id_col = user_id_col
+
+ def load(self) -> List[Document]:
+ """Load all chat messages."""
+ result = []
+ for _, row in self.chat_log.iterrows():
+ user_id = row[self.user_id_col]
+ metadata = row.to_dict()
+ metadata.pop(self.user_id_col)
+ result.append(Document(page_content=user_id, metadata=metadata))
+ return result
diff --git a/libs/community/langchain_community/document_loaders/docugami.py b/libs/community/langchain_community/document_loaders/docugami.py
new file mode 100644
index 00000000000..9e3acd73249
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/docugami.py
@@ -0,0 +1,362 @@
+import hashlib
+import io
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+from langchain_community.document_loaders.base import BaseLoader
+
+TABLE_NAME = "{http://www.w3.org/1999/xhtml}table"
+
+XPATH_KEY = "xpath"
+ID_KEY = "id"
+DOCUMENT_SOURCE_KEY = "source"
+DOCUMENT_NAME_KEY = "name"
+STRUCTURE_KEY = "structure"
+TAG_KEY = "tag"
+PROJECTS_KEY = "projects"
+
+DEFAULT_API_ENDPOINT = "https://api.docugami.com/v1preview1"
+
+logger = logging.getLogger(__name__)
+
+
+class DocugamiLoader(BaseLoader, BaseModel):
+ """Load from `Docugami`.
+
+ To use, you should have the ``dgml-utils`` python package installed.
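+
+    Example (usage sketch; the docset ID is a placeholder and the API key is read
+    from the DOCUGAMI_API_KEY environment variable):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.docugami import DocugamiLoader
+
+            # Assumes DOCUGAMI_API_KEY is set in the environment.
+            loader = DocugamiLoader(docset_id="<docset-id>")
+            docs = loader.load()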
+ """
+
+ api: str = DEFAULT_API_ENDPOINT
+ """The Docugami API endpoint to use."""
+
+ access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
+ """The Docugami API access token to use."""
+
+    max_text_length: int = 4096
+ """Max length of chunk text returned."""
+
+ min_text_length: int = 32
+ """Threshold under which chunks are appended to next to avoid over-chunking."""
+
+    max_metadata_length: int = 512
+ """Max length of metadata text returned."""
+
+ include_xml_tags: bool = False
+ """Set to true for XML tags in chunk output text."""
+
+ parent_hierarchy_levels: int = 0
+ """Set appropriately to get parent chunks using the chunk hierarchy."""
+
+ parent_id_key: str = "doc_id"
+ """Metadata key for parent doc ID."""
+
+ sub_chunk_tables: bool = False
+ """Set to True to return sub-chunks within tables."""
+
+ whitespace_normalize_text: bool = True
+    """Set to False if you want to keep the full whitespace formatting in the original
+ XML doc, including indentation."""
+
+ docset_id: Optional[str]
+ """The Docugami API docset ID to use."""
+
+ document_ids: Optional[Sequence[str]]
+ """The Docugami API document IDs to use."""
+
+ file_paths: Optional[Sequence[Union[Path, str]]]
+ """The local file paths to use."""
+
+ include_project_metadata_in_doc_metadata: bool = True
+ """Set to True if you want to include the project metadata in the doc metadata."""
+
+ @root_validator
+ def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate that either local file paths are given, or remote API docset ID.
+
+ Args:
+ values: The values to validate.
+
+ Returns:
+ The validated values.
+ """
+ if values.get("file_paths") and values.get("docset_id"):
+ raise ValueError("Cannot specify both file_paths and remote API docset_id")
+
+ if not values.get("file_paths") and not values.get("docset_id"):
+ raise ValueError("Must specify either file_paths or remote API docset_id")
+
+ if values.get("docset_id") and not values.get("access_token"):
+ raise ValueError("Must specify access token if using remote API docset_id")
+
+ return values
+
+ def _parse_dgml(
+ self,
+ content: bytes,
+ document_name: Optional[str] = None,
+ additional_doc_metadata: Optional[Mapping] = None,
+ ) -> List[Document]:
+ """Parse a single DGML document into a list of Documents."""
+ try:
+ from lxml import etree
+ except ImportError:
+ raise ImportError(
+ "Could not import lxml python package. "
+ "Please install it with `pip install lxml`."
+ )
+
+ try:
+ from dgml_utils.models import Chunk
+ from dgml_utils.segmentation import get_chunks
+ except ImportError:
+ raise ImportError(
+ "Could not import from dgml-utils python package. "
+ "Please install it with `pip install dgml-utils`."
+ )
+
+ def _build_framework_chunk(dg_chunk: Chunk) -> Document:
+ # Stable IDs for chunks with the same text.
+ _hashed_id = hashlib.md5(dg_chunk.text.encode()).hexdigest()
+ metadata = {
+ XPATH_KEY: dg_chunk.xpath,
+ ID_KEY: _hashed_id,
+ DOCUMENT_NAME_KEY: document_name,
+ DOCUMENT_SOURCE_KEY: document_name,
+ STRUCTURE_KEY: dg_chunk.structure,
+ TAG_KEY: dg_chunk.tag,
+ }
+
+ text = dg_chunk.text
+ if additional_doc_metadata:
+ if self.include_project_metadata_in_doc_metadata:
+ metadata.update(additional_doc_metadata)
+
+ return Document(
+ page_content=text[: self.max_text_length],
+ metadata=metadata,
+ )
+
+ # Parse the tree and return chunks
+ tree = etree.parse(io.BytesIO(content))
+ root = tree.getroot()
+
+ dg_chunks = get_chunks(
+ root,
+ min_text_length=self.min_text_length,
+ max_text_length=self.max_text_length,
+ whitespace_normalize_text=self.whitespace_normalize_text,
+ sub_chunk_tables=self.sub_chunk_tables,
+ include_xml_tags=self.include_xml_tags,
+ parent_hierarchy_levels=self.parent_hierarchy_levels,
+ )
+
+ framework_chunks: Dict[str, Document] = {}
+ for dg_chunk in dg_chunks:
+ framework_chunk = _build_framework_chunk(dg_chunk)
+ chunk_id = framework_chunk.metadata.get(ID_KEY)
+ if chunk_id:
+ framework_chunks[chunk_id] = framework_chunk
+ if dg_chunk.parent:
+ framework_parent_chunk = _build_framework_chunk(dg_chunk.parent)
+ parent_id = framework_parent_chunk.metadata.get(ID_KEY)
+ if parent_id and framework_parent_chunk.page_content:
+ framework_chunk.metadata[self.parent_id_key] = parent_id
+ framework_chunks[parent_id] = framework_parent_chunk
+
+ return list(framework_chunks.values())
+
+ def _document_details_for_docset_id(self, docset_id: str) -> List[Dict]:
+ """Gets all document details for the given docset ID"""
+ url = f"{self.api}/docsets/{docset_id}/documents"
+ all_documents = []
+
+ while url:
+ response = requests.get(
+ url,
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ )
+ if response.ok:
+ data = response.json()
+ all_documents.extend(data["documents"])
+ url = data.get("next", None)
+ else:
+ raise Exception(
+ f"Failed to download {url} (status: {response.status_code})"
+ )
+
+ return all_documents
+
+ def _project_details_for_docset_id(self, docset_id: str) -> List[Dict]:
+ """Gets all project details for the given docset ID"""
+ url = f"{self.api}/projects?docset.id={docset_id}"
+ all_projects = []
+
+ while url:
+ response = requests.request(
+ "GET",
+ url,
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ data={},
+ )
+ if response.ok:
+ data = response.json()
+ all_projects.extend(data["projects"])
+ url = data.get("next", None)
+ else:
+ raise Exception(
+ f"Failed to download {url} (status: {response.status_code})"
+ )
+
+ return all_projects
+
+ def _metadata_for_project(self, project: Dict) -> Dict:
+ """Gets project metadata for all files"""
+ project_id = project.get(ID_KEY)
+
+ url = f"{self.api}/projects/{project_id}/artifacts/latest"
+ all_artifacts = []
+
+ per_file_metadata: Dict = {}
+ while url:
+ response = requests.request(
+ "GET",
+ url,
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ data={},
+ )
+ if response.ok:
+ data = response.json()
+ all_artifacts.extend(data["artifacts"])
+ url = data.get("next", None)
+ elif response.status_code == 404:
+ # Not found is ok, just means no published projects
+ return per_file_metadata
+ else:
+ raise Exception(
+ f"Failed to download {url} (status: {response.status_code})"
+ )
+
+ for artifact in all_artifacts:
+ artifact_name = artifact.get("name")
+ artifact_url = artifact.get("url")
+ artifact_doc = artifact.get("document")
+
+ if artifact_name == "report-values.xml" and artifact_url and artifact_doc:
+ doc_id = artifact_doc[ID_KEY]
+ metadata: Dict = {}
+
+ # The evaluated XML for each document is named after the project
+ response = requests.request(
+ "GET",
+ f"{artifact_url}/content",
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ data={},
+ )
+
+ if response.ok:
+ try:
+ from lxml import etree
+ except ImportError:
+ raise ImportError(
+ "Could not import lxml python package. "
+ "Please install it with `pip install lxml`."
+ )
+ artifact_tree = etree.parse(io.BytesIO(response.content))
+ artifact_root = artifact_tree.getroot()
+ ns = artifact_root.nsmap
+ entries = artifact_root.xpath("//pr:Entry", namespaces=ns)
+ for entry in entries:
+ heading = entry.xpath("./pr:Heading", namespaces=ns)[0].text
+ value = " ".join(
+ entry.xpath("./pr:Value", namespaces=ns)[0].itertext()
+ ).strip()
+ metadata[heading] = value[: self.max_metadata_length]
+ per_file_metadata[doc_id] = metadata
+ else:
+ raise Exception(
+ f"Failed to download {artifact_url}/content "
+                        f"(status: {response.status_code})"
+ )
+
+ return per_file_metadata
+
+ def _load_chunks_for_document(
+ self,
+ document_id: str,
+ docset_id: str,
+ document_name: Optional[str] = None,
+ additional_metadata: Optional[Mapping] = None,
+ ) -> List[Document]:
+ """Load chunks for a document."""
+ url = f"{self.api}/docsets/{docset_id}/documents/{document_id}/dgml"
+
+ response = requests.request(
+ "GET",
+ url,
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ data={},
+ )
+
+ if response.ok:
+ return self._parse_dgml(
+ content=response.content,
+ document_name=document_name,
+ additional_doc_metadata=additional_metadata,
+ )
+ else:
+ raise Exception(
+ f"Failed to download {url} (status: {response.status_code})"
+ )
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ chunks: List[Document] = []
+
+ if self.access_token and self.docset_id:
+ # Remote mode
+ _document_details = self._document_details_for_docset_id(self.docset_id)
+ if self.document_ids:
+ _document_details = [
+ d for d in _document_details if d[ID_KEY] in self.document_ids
+ ]
+
+ _project_details = self._project_details_for_docset_id(self.docset_id)
+ combined_project_metadata: Dict[str, Dict] = {}
+ if _project_details and self.include_project_metadata_in_doc_metadata:
+ # If there are any projects for this docset and the caller requested
+ # project metadata, load it.
+ for project in _project_details:
+ metadata = self._metadata_for_project(project)
+ for file_id in metadata:
+ if file_id not in combined_project_metadata:
+ combined_project_metadata[file_id] = metadata[file_id]
+ else:
+ combined_project_metadata[file_id].update(metadata[file_id])
+
+ for doc in _document_details:
+ doc_id = doc[ID_KEY]
+ doc_name = doc.get(DOCUMENT_NAME_KEY)
+ doc_metadata = combined_project_metadata.get(doc_id)
+ chunks += self._load_chunks_for_document(
+ document_id=doc_id,
+ docset_id=self.docset_id,
+ document_name=doc_name,
+ additional_metadata=doc_metadata,
+ )
+ elif self.file_paths:
+ # Local mode (for integration testing, or pre-downloaded XML)
+ for path in self.file_paths:
+ path = Path(path)
+ with open(path, "rb") as file:
+ chunks += self._parse_dgml(
+ content=file.read(),
+ document_name=path.name,
+ )
+
+ return chunks
diff --git a/libs/community/langchain_community/document_loaders/docusaurus.py b/libs/community/langchain_community/document_loaders/docusaurus.py
new file mode 100644
index 00000000000..efdb499795a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/docusaurus.py
@@ -0,0 +1,49 @@
+"""Load Documents from Docusaurus Documentation"""
+from typing import Any, List, Optional
+
+from langchain_community.document_loaders.sitemap import SitemapLoader
+
+
+class DocusaurusLoader(SitemapLoader):
+ """
+ Loader that leverages the SitemapLoader to loop through the generated pages of a
+ Docusaurus Documentation website and extracts the content by looking for specific
+ HTML tags. By default, the parser searches for the main content of the Docusaurus
+    page, which is normally the `<article>` element inside `<main>`. You also have
+    the option to define your own custom HTML tags by providing them as a list, for
+    example: ["div", ".main", "a"].
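+
+    Example (usage sketch; the URL is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.docusaurus import (
+                DocusaurusLoader,
+            )
+
+            loader = DocusaurusLoader("https://docs.example.com")  # placeholder URL
+            docs = loader.load()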
+ """
+
+ def __init__(
+ self,
+ url: str,
+ custom_html_tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ):
+ """
+ Initialize DocusaurusLoader
+ Args:
+ url: The base URL of the Docusaurus website.
+ custom_html_tags: Optional custom html tags to extract content from pages.
+ kwargs: Additional args to extend the underlying SitemapLoader, for example:
+ filter_urls, blocksize, meta_function, is_local, continue_on_failure
+ """
+ if not kwargs.get("is_local"):
+ url = f"{url}/sitemap.xml"
+
+ self.custom_html_tags = custom_html_tags or ["main article"]
+
+ super().__init__(
+ url,
+ parsing_function=kwargs.get("parsing_function") or self._parsing_function,
+ **kwargs,
+ )
+
+ def _parsing_function(self, content: Any) -> str:
+        """Parses specific elements from a Docusaurus page."""
+        relevant_elements = content.select(",".join(self.custom_html_tags))
+
+        # Keep only the text of the elements matching the configured selectors.
+        return str(" ".join(element.get_text() for element in relevant_elements))
diff --git a/libs/community/langchain_community/document_loaders/dropbox.py b/libs/community/langchain_community/document_loaders/dropbox.py
new file mode 100644
index 00000000000..fcf9d6503fd
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/dropbox.py
@@ -0,0 +1,173 @@
+# Prerequisites:
+# 1. Create a Dropbox app.
+# 2. Give the app these scope permissions: `files.metadata.read`
+# and `files.content.read`.
+# 3. Generate access token: https://www.dropbox.com/developers/apps/create.
+# 4. `pip install dropbox` (requires `pip install unstructured[pdf]` for PDF filetype).
+
+
+import os
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class DropboxLoader(BaseLoader, BaseModel):
+ """Load files from `Dropbox`.
+
+ In addition to common files such as text and PDF files, it also supports
+ *Dropbox Paper* files.
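+
+    Example (usage sketch; the access token is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.dropbox import DropboxLoader
+
+            loader = DropboxLoader(
+                dropbox_access_token="<access-token>",  # placeholder token
+                dropbox_folder_path="",  # "" is the root folder of the app
+            )
+            docs = loader.load()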
+ """
+
+ dropbox_access_token: str
+ """Dropbox access token."""
+ dropbox_folder_path: Optional[str] = None
+ """The folder path to load from."""
+ dropbox_file_paths: Optional[List[str]] = None
+ """The file paths to load from."""
+ recursive: bool = False
+ """Flag to indicate whether to load files recursively from subfolders."""
+
+ @root_validator
+ def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate that either folder_path or file_paths is set, but not both."""
+ if (
+ values.get("dropbox_folder_path") is not None
+ and values.get("dropbox_file_paths") is not None
+ ):
+ raise ValueError("Cannot specify both folder_path and file_paths")
+ if values.get("dropbox_folder_path") is None and not values.get(
+ "dropbox_file_paths"
+ ):
+ raise ValueError("Must specify either folder_path or file_paths")
+
+ return values
+
+ def _create_dropbox_client(self) -> Any:
+ """Create a Dropbox client."""
+ try:
+ from dropbox import Dropbox, exceptions
+ except ImportError:
+            raise ImportError("You must run `pip install dropbox`")
+
+ try:
+ dbx = Dropbox(self.dropbox_access_token)
+ dbx.users_get_current_account()
+ except exceptions.AuthError as ex:
+ raise ValueError(
+ "Invalid Dropbox access token. Please verify your token and try again."
+ ) from ex
+ return dbx
+
+ def _load_documents_from_folder(self, folder_path: str) -> List[Document]:
+ """Load documents from a Dropbox folder."""
+ dbx = self._create_dropbox_client()
+
+ try:
+ from dropbox import exceptions
+ from dropbox.files import FileMetadata
+ except ImportError:
+            raise ImportError("You must run `pip install dropbox`")
+
+ try:
+ results = dbx.files_list_folder(folder_path, recursive=self.recursive)
+ except exceptions.ApiError as ex:
+ raise ValueError(
+ f"Could not list files in the folder: {folder_path}. "
+ "Please verify the folder path and try again."
+ ) from ex
+
+ files = [entry for entry in results.entries if isinstance(entry, FileMetadata)]
+ documents = [
+ doc
+ for doc in (self._load_file_from_path(file.path_display) for file in files)
+ if doc is not None
+ ]
+ return documents
+
+ def _load_file_from_path(self, file_path: str) -> Optional[Document]:
+ """Load a file from a Dropbox path."""
+ dbx = self._create_dropbox_client()
+
+ try:
+ from dropbox import exceptions
+ except ImportError:
+            raise ImportError("You must run `pip install dropbox`")
+
+ try:
+ file_metadata = dbx.files_get_metadata(file_path)
+
+ if file_metadata.is_downloadable:
+ _, response = dbx.files_download(file_path)
+
+            # Some file types, such as Dropbox Paper, need to be exported.
+ elif file_metadata.export_info:
+ _, response = dbx.files_export(file_path, "markdown")
+
+ except exceptions.ApiError as ex:
+ raise ValueError(
+                f"Could not load file: {file_path}. Please verify the file path "
+ "and try again."
+ ) from ex
+
+ try:
+ text = response.content.decode("utf-8")
+ except UnicodeDecodeError:
+ file_extension = os.path.splitext(file_path)[1].lower()
+
+ if file_extension == ".pdf":
+ print(f"File {file_path} type detected as .pdf")
+ from langchain_community.document_loaders import UnstructuredPDFLoader
+
+ # Download it to a temporary file.
+ temp_dir = tempfile.TemporaryDirectory()
+ temp_pdf = Path(temp_dir.name) / "tmp.pdf"
+ with open(temp_pdf, mode="wb") as f:
+ f.write(response.content)
+
+ try:
+ loader = UnstructuredPDFLoader(str(temp_pdf))
+ docs = loader.load()
+ if docs:
+ return docs[0]
+ except Exception as pdf_ex:
+ print(f"Error while trying to parse PDF {file_path}: {pdf_ex}")
+ return None
+ else:
+ print(
+ f"File {file_path} could not be decoded as pdf or text. Skipping."
+ )
+
+ return None
+
+ metadata = {
+ "source": f"dropbox://{file_path}",
+ "title": os.path.basename(file_path),
+ }
+ return Document(page_content=text, metadata=metadata)
+
+ def _load_documents_from_paths(self) -> List[Document]:
+ """Load documents from a list of Dropbox file paths."""
+ if not self.dropbox_file_paths:
+ raise ValueError("file_paths must be set")
+
+ return [
+ doc
+ for doc in (
+ self._load_file_from_path(file_path)
+ for file_path in self.dropbox_file_paths
+ )
+ if doc is not None
+ ]
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ if self.dropbox_folder_path is not None:
+ return self._load_documents_from_folder(self.dropbox_folder_path)
+ else:
+ return self._load_documents_from_paths()
diff --git a/libs/community/langchain_community/document_loaders/duckdb_loader.py b/libs/community/langchain_community/document_loaders/duckdb_loader.py
new file mode 100644
index 00000000000..1e2c3022d54
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/duckdb_loader.py
@@ -0,0 +1,89 @@
+from typing import Dict, List, Optional, cast
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class DuckDBLoader(BaseLoader):
+ """Load from `DuckDB`.
+
+ Each document represents one row of the result. The `page_content_columns`
+ are written into the `page_content` of the document. The `metadata_columns`
+ are written into the `metadata` of the document. By default, all columns
+ are written into the `page_content` and none into the `metadata`.
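+
+    Example (usage sketch; the query is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.duckdb_loader import (
+                DuckDBLoader,
+            )
+
+            loader = DuckDBLoader("SELECT 42 AS answer")  # placeholder query
+            docs = loader.load()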
+ """
+
+ def __init__(
+ self,
+ query: str,
+ database: str = ":memory:",
+ read_only: bool = False,
+ config: Optional[Dict[str, str]] = None,
+ page_content_columns: Optional[List[str]] = None,
+ metadata_columns: Optional[List[str]] = None,
+ ):
+ """
+
+ Args:
+ query: The query to execute.
+ database: The database to connect to. Defaults to ":memory:".
+ read_only: Whether to open the database in read-only mode.
+ Defaults to False.
+ config: A dictionary of configuration options to pass to the database.
+ Optional.
+ page_content_columns: The columns to write into the `page_content`
+ of the document. Optional.
+ metadata_columns: The columns to write into the `metadata` of the document.
+ Optional.
+ """
+ self.query = query
+ self.database = database
+ self.read_only = read_only
+ self.config = config or {}
+ self.page_content_columns = page_content_columns
+ self.metadata_columns = metadata_columns
+
+ def load(self) -> List[Document]:
+ try:
+ import duckdb
+ except ImportError:
+ raise ImportError(
+ "Could not import duckdb python package. "
+ "Please install it with `pip install duckdb`."
+ )
+
+ docs = []
+ with duckdb.connect(
+ database=self.database, read_only=self.read_only, config=self.config
+ ) as con:
+ query_result = con.execute(self.query)
+ results = query_result.fetchall()
+ description = cast(list, query_result.description)
+ field_names = [c[0] for c in description]
+
+ if self.page_content_columns is None:
+ page_content_columns = field_names
+ else:
+ page_content_columns = self.page_content_columns
+
+ if self.metadata_columns is None:
+ metadata_columns = []
+ else:
+ metadata_columns = self.metadata_columns
+
+ for result in results:
+ page_content = "\n".join(
+ f"{column}: {result[field_names.index(column)]}"
+ for column in page_content_columns
+ )
+
+ metadata = {
+ column: result[field_names.index(column)]
+ for column in metadata_columns
+ }
+
+ doc = Document(page_content=page_content, metadata=metadata)
+ docs.append(doc)
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/email.py b/libs/community/langchain_community/document_loaders/email.py
new file mode 100644
index 00000000000..277fdcc640c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/email.py
@@ -0,0 +1,117 @@
+import os
+from typing import Any, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ satisfies_min_unstructured_version,
+)
+
+
+class UnstructuredEmailLoader(UnstructuredFileLoader):
+ """Load email files using `Unstructured`.
+
+ Works with both
+ .eml and .msg files. You can process attachments in addition to the
+ e-mail message itself by passing process_attachments=True into the
+ constructor for the loader. By default, attachments will be processed
+ with the unstructured partition function. If you already know the document
+ types of the attachments, you can specify another partitioning function
+ with the attachment partitioner kwarg.
+
+ Example
+ -------
+ from langchain_community.document_loaders import UnstructuredEmailLoader
+
+ loader = UnstructuredEmailLoader("example_data/fake-email.eml", mode="elements")
+ loader.load()
+
+ Example
+ -------
+ from langchain_community.document_loaders import UnstructuredEmailLoader
+
+ loader = UnstructuredEmailLoader(
+ "example_data/fake-email-attachment.eml",
+ mode="elements",
+ process_attachments=True,
+ )
+ loader.load()
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ process_attachments = unstructured_kwargs.get("process_attachments")
+ attachment_partitioner = unstructured_kwargs.get("attachment_partitioner")
+
+ if process_attachments and attachment_partitioner is None:
+ from unstructured.partition.auto import partition
+
+ unstructured_kwargs["attachment_partitioner"] = partition
+
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.file_utils.filetype import FileType, detect_filetype
+
+ filetype = detect_filetype(self.file_path)
+
+ if filetype == FileType.EML:
+ from unstructured.partition.email import partition_email
+
+ return partition_email(filename=self.file_path, **self.unstructured_kwargs)
+ elif satisfies_min_unstructured_version("0.5.8") and filetype == FileType.MSG:
+ from unstructured.partition.msg import partition_msg
+
+ return partition_msg(filename=self.file_path, **self.unstructured_kwargs)
+ else:
+ raise ValueError(
+ f"Filetype {filetype} is not supported in UnstructuredEmailLoader."
+ )
+
+
+class OutlookMessageLoader(BaseLoader):
+ """
+ Loads Outlook Message files using extract_msg.
+
+ https://github.com/TeamMsgExtractor/msg-extractor
+ """
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path.
+
+ Args:
+ file_path: The path to the Outlook Message file.
+ """
+
+ self.file_path = file_path
+
+ if not os.path.isfile(self.file_path):
+            raise ValueError(f"File path {self.file_path} is not a valid file")
+
+ try:
+ import extract_msg # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "extract_msg is not installed. Please install it with "
+ "`pip install extract_msg`"
+ )
+
+ def load(self) -> List[Document]:
+ """Load data into document objects."""
+ import extract_msg
+
+ msg = extract_msg.Message(self.file_path)
+ return [
+ Document(
+ page_content=msg.body,
+ metadata={
+ "source": self.file_path,
+ "subject": msg.subject,
+ "sender": msg.sender,
+ "date": msg.date,
+ },
+ )
+ ]
diff --git a/libs/community/langchain_community/document_loaders/epub.py b/libs/community/langchain_community/document_loaders/epub.py
new file mode 100644
index 00000000000..fbef21b9b5a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/epub.py
@@ -0,0 +1,42 @@
+from typing import List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ satisfies_min_unstructured_version,
+)
+
+
+class UnstructuredEPubLoader(UnstructuredFileLoader):
+ """Load `EPub` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredEPubLoader
+
+ loader = UnstructuredEPubLoader(
+ "example.epub", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-epub
+ """
+
+ def _get_elements(self) -> List:
+ min_unstructured_version = "0.5.4"
+ if not satisfies_min_unstructured_version(min_unstructured_version):
+ raise ValueError(
+ "Partitioning epub files is only supported in "
+ f"unstructured>={min_unstructured_version}."
+ )
+ from unstructured.partition.epub import partition_epub
+
+ return partition_epub(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/etherscan.py b/libs/community/langchain_community/document_loaders/etherscan.py
new file mode 100644
index 00000000000..862b63843a1
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/etherscan.py
@@ -0,0 +1,203 @@
+import os
+import re
+from typing import Iterator, List
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class EtherscanLoader(BaseLoader):
+ """Load transactions from `Ethereum` mainnet.
+
+    The loader uses the Etherscan API to interact with the Ethereum mainnet.
+
+    The ETHERSCAN_API_KEY environment variable must be set to use this loader.
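+
+    Examples
+    --------
+    A minimal usage sketch (the account address below is a placeholder):
+
+    from langchain_community.document_loaders.etherscan import EtherscanLoader
+
+    loader = EtherscanLoader(
+        account_address="0x0000000000000000000000000000000000000000",  # placeholder
+        filter="normal_transaction",
+    )
+    docs = loader.load()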
+ """
+
+ def __init__(
+ self,
+ account_address: str,
+ api_key: str = "docs-demo",
+ filter: str = "normal_transaction",
+ page: int = 1,
+ offset: int = 10,
+ start_block: int = 0,
+ end_block: int = 99999999,
+ sort: str = "desc",
+ ):
+ self.account_address = account_address
+ self.api_key = os.environ.get("ETHERSCAN_API_KEY") or api_key
+ self.filter = filter
+ self.page = page
+ self.offset = offset
+ self.start_block = start_block
+ self.end_block = end_block
+ self.sort = sort
+
+ if not self.api_key:
+ raise ValueError("Etherscan API key not provided")
+
+ if not re.match(r"^0x[a-fA-F0-9]{40}$", self.account_address):
+ raise ValueError(f"Invalid contract address {self.account_address}")
+ if filter not in [
+ "normal_transaction",
+ "internal_transaction",
+ "erc20_transaction",
+ "eth_balance",
+ "erc721_transaction",
+ "erc1155_transaction",
+ ]:
+ raise ValueError(f"Invalid filter {filter}")
+
+ def lazy_load(self) -> Iterator[Document]:
+        """Lazily load transaction Documents from Etherscan."""
+ result = []
+ if self.filter == "normal_transaction":
+ result = self.getNormTx()
+ elif self.filter == "internal_transaction":
+ result = self.getInternalTx()
+ elif self.filter == "erc20_transaction":
+ result = self.getERC20Tx()
+ elif self.filter == "eth_balance":
+ result = self.getEthBalance()
+ elif self.filter == "erc721_transaction":
+ result = self.getERC721Tx()
+ elif self.filter == "erc1155_transaction":
+ result = self.getERC1155Tx()
+ else:
+            raise ValueError(f"Invalid filter {self.filter}")
+ for doc in result:
+ yield doc
+
+ def load(self) -> List[Document]:
+        """Load transactions from a specific account via Etherscan."""
+ return list(self.lazy_load())
+
+ def getNormTx(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=txlist&address={self.account_address}"
+ f"&startblock={self.start_block}&endblock={self.end_block}&page={self.page}"
+ f"&offset={self.offset}&sort={self.sort}&apikey={self.api_key}"
+ )
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ items = response.json()["result"]
+ result = []
+ if len(items) == 0:
+ return [Document(page_content="")]
+ for item in items:
+ content = str(item)
+ metadata = {"from": item["from"], "tx_hash": item["hash"], "to": item["to"]}
+ result.append(Document(page_content=content, metadata=metadata))
+ return result
+
+ def getEthBalance(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=balance"
+ f"&address={self.account_address}&tag=latest&apikey={self.api_key}"
+ )
+
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ return [Document(page_content=response.json()["result"])]
+
+ def getInternalTx(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=txlistinternal"
+ f"&address={self.account_address}&startblock={self.start_block}"
+ f"&endblock={self.end_block}&page={self.page}&offset={self.offset}"
+ f"&sort={self.sort}&apikey={self.api_key}"
+ )
+
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ items = response.json()["result"]
+ result = []
+ if len(items) == 0:
+ return [Document(page_content="")]
+ for item in items:
+ content = str(item)
+ metadata = {"from": item["from"], "tx_hash": item["hash"], "to": item["to"]}
+ result.append(Document(page_content=content, metadata=metadata))
+ return result
+
+ def getERC20Tx(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=tokentx"
+ f"&address={self.account_address}&startblock={self.start_block}"
+ f"&endblock={self.end_block}&page={self.page}&offset={self.offset}"
+ f"&sort={self.sort}&apikey={self.api_key}"
+ )
+
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ items = response.json()["result"]
+ result = []
+ if len(items) == 0:
+ return [Document(page_content="")]
+ for item in items:
+ content = str(item)
+ metadata = {"from": item["from"], "tx_hash": item["hash"], "to": item["to"]}
+ result.append(Document(page_content=content, metadata=metadata))
+ return result
+
+ def getERC721Tx(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=tokennfttx"
+ f"&address={self.account_address}&startblock={self.start_block}"
+ f"&endblock={self.end_block}&page={self.page}&offset={self.offset}"
+ f"&sort={self.sort}&apikey={self.api_key}"
+ )
+
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ items = response.json()["result"]
+ result = []
+ if len(items) == 0:
+ return [Document(page_content="")]
+ for item in items:
+ content = str(item)
+ metadata = {"from": item["from"], "tx_hash": item["hash"], "to": item["to"]}
+ result.append(Document(page_content=content, metadata=metadata))
+ return result
+
+ def getERC1155Tx(self) -> List[Document]:
+ url = (
+ f"https://api.etherscan.io/api?module=account&action=token1155tx"
+ f"&address={self.account_address}&startblock={self.start_block}"
+ f"&endblock={self.end_block}&page={self.page}&offset={self.offset}"
+ f"&sort={self.sort}&apikey={self.api_key}"
+ )
+
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            print("Error occurred while making the request:", e)
+            raise
+ items = response.json()["result"]
+ result = []
+ if len(items) == 0:
+ return [Document(page_content="")]
+ for item in items:
+ content = str(item)
+ metadata = {"from": item["from"], "tx_hash": item["hash"], "to": item["to"]}
+ result.append(Document(page_content=content, metadata=metadata))
+ return result
diff --git a/libs/community/langchain_community/document_loaders/evernote.py b/libs/community/langchain_community/document_loaders/evernote.py
new file mode 100644
index 00000000000..836f393d784
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/evernote.py
@@ -0,0 +1,151 @@
+"""Load documents from Evernote.
+
+https://gist.github.com/foxmask/7b29c43a161e001ff04afdb2f181e31c
+"""
+import hashlib
+import logging
+from base64 import b64decode
+from time import strptime
+from typing import Any, Dict, Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class EverNoteLoader(BaseLoader):
+ """Load from `EverNote`.
+
+ Loads an EverNote notebook export file e.g. my_notebook.enex into Documents.
+ Instructions on producing this file can be found at
+ https://help.evernote.com/hc/en-us/articles/209005557-Export-notes-and-notebooks-as-ENEX-or-HTML
+
+    Currently only the plain text in the note is extracted and stored as the contents
+    of the Document. Any non-content metadata (e.g. 'author', 'created', 'updated',
+    but not 'content-raw' or 'resource') tags on the note will be extracted and stored
+    as metadata on the Document.
+
+ Args:
+ file_path (str): The path to the notebook export with a .enex extension
+ load_single_document (bool): Whether or not to concatenate the content of all
+ notes into a single long Document.
+ If this is set to True (default) then the only metadata on the document will be
+ the 'source' which contains the file name of the export.
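+
+    Examples
+    --------
+    A minimal usage sketch (the export path below is a placeholder):
+
+    from langchain_community.document_loaders.evernote import EverNoteLoader
+
+    loader = EverNoteLoader("my_notebook.enex", load_single_document=True)
+    docs = loader.load()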
+ """ # noqa: E501
+
+ def __init__(self, file_path: str, load_single_document: bool = True):
+ """Initialize with file path."""
+ self.file_path = file_path
+ self.load_single_document = load_single_document
+
+ def load(self) -> List[Document]:
+ """Load documents from EverNote export file."""
+ documents = [
+ Document(
+ page_content=note["content"],
+ metadata={
+ **{
+ key: value
+ for key, value in note.items()
+ if key not in ["content", "content-raw", "resource"]
+ },
+ **{"source": self.file_path},
+ },
+ )
+ for note in self._parse_note_xml(self.file_path)
+ if note.get("content") is not None
+ ]
+
+ if not self.load_single_document:
+ return documents
+
+ return [
+ Document(
+ page_content="".join([document.page_content for document in documents]),
+ metadata={"source": self.file_path},
+ )
+ ]
+
+ @staticmethod
+ def _parse_content(content: str) -> str:
+ try:
+ import html2text
+
+ return html2text.html2text(content).strip()
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `html2text`. Although it is not a required package "
+ "to use Langchain, using the EverNote loader requires `html2text`. "
+ "Please install `html2text` via `pip install html2text` and try again."
+ ) from e
+
+ @staticmethod
+ def _parse_resource(resource: list) -> dict:
+ rsc_dict: Dict[str, Any] = {}
+ for elem in resource:
+ if elem.tag == "data":
+ # Sometimes elem.text is None
+ rsc_dict[elem.tag] = b64decode(elem.text) if elem.text else b""
+ rsc_dict["hash"] = hashlib.md5(rsc_dict[elem.tag]).hexdigest()
+ else:
+ rsc_dict[elem.tag] = elem.text
+
+ return rsc_dict
+
+ @staticmethod
+ def _parse_note(note: List, prefix: Optional[str] = None) -> dict:
+ note_dict: Dict[str, Any] = {}
+ resources = []
+
+ def add_prefix(element_tag: str) -> str:
+ if prefix is None:
+ return element_tag
+ return f"{prefix}.{element_tag}"
+
+ for elem in note:
+ if elem.tag == "content":
+ note_dict[elem.tag] = EverNoteLoader._parse_content(elem.text)
+ # A copy of original content
+ note_dict["content-raw"] = elem.text
+ elif elem.tag == "resource":
+ resources.append(EverNoteLoader._parse_resource(elem))
+ elif elem.tag == "created" or elem.tag == "updated":
+ note_dict[elem.tag] = strptime(elem.text, "%Y%m%dT%H%M%SZ")
+ elif elem.tag == "note-attributes":
+ additional_attributes = EverNoteLoader._parse_note(
+ elem, elem.tag
+ ) # Recursively enter the note-attributes tag
+ note_dict.update(additional_attributes)
+ else:
+ note_dict[elem.tag] = elem.text
+
+ if len(resources) > 0:
+ note_dict["resource"] = resources
+
+ return {add_prefix(key): value for key, value in note_dict.items()}
+
+ @staticmethod
+ def _parse_note_xml(xml_file: str) -> Iterator[Dict[str, Any]]:
+ """Parse Evernote xml."""
+ # Without huge_tree set to True, parser may complain about huge text node
+        # Try to recover, because there may be "&nbsp;", which will cause
+ # "XMLSyntaxError: Entity 'nbsp' not defined"
+ try:
+ from lxml import etree
+ except ImportError as e:
+ logger.error(
+ "Could not import `lxml`. Although it is not a required package to use "
+ "Langchain, using the EverNote loader requires `lxml`. Please install "
+ "`lxml` via `pip install lxml` and try again."
+ )
+ raise e
+
+ context = etree.iterparse(
+ xml_file, encoding="utf-8", strip_cdata=False, huge_tree=True, recover=True
+ )
+
+ for action, elem in context:
+ if elem.tag == "note":
+ yield EverNoteLoader._parse_note(elem)
diff --git a/libs/community/langchain_community/document_loaders/excel.py b/libs/community/langchain_community/document_loaders/excel.py
new file mode 100644
index 00000000000..6660cfd71de
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/excel.py
@@ -0,0 +1,46 @@
+"""Loads Microsoft Excel files."""
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredExcelLoader(UnstructuredFileLoader):
+ """Load Microsoft Excel files using `Unstructured`.
+
+    Like other
+    Unstructured loaders, UnstructuredExcelLoader can be used in both
+    "single" and "elements" mode. If you use the loader in "elements"
+    mode, each sheet in the Excel file will be an Unstructured Table
+    element, and an HTML representation of the table will be available
+    in the "text_as_html" key in the document metadata.
+
+ Examples
+ --------
+ from langchain_community.document_loaders.excel import UnstructuredExcelLoader
+
+    loader = UnstructuredExcelLoader("stanley-cups.xlsx", mode="elements")
+ docs = loader.load()
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+
+ Args:
+ file_path: The path to the Microsoft Excel file.
+ mode: The mode to use when partitioning the file. See unstructured docs
+ for more info. Optional. Defaults to "single".
+ **unstructured_kwargs: Keyword arguments to pass to unstructured.
+ """
+ validate_unstructured_version(min_unstructured_version="0.6.7")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.xlsx import partition_xlsx
+
+ return partition_xlsx(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/facebook_chat.py b/libs/community/langchain_community/document_loaders/facebook_chat.py
new file mode 100644
index 00000000000..759499b25e8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/facebook_chat.py
@@ -0,0 +1,46 @@
+import datetime
+import json
+from pathlib import Path
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def concatenate_rows(row: dict) -> str:
+ """Combine message information in a readable format ready to be used.
+
+ Args:
+ row: dictionary containing message information.
+ """
+ sender = row["sender_name"]
+ text = row["content"]
+ date = datetime.datetime.fromtimestamp(row["timestamp_ms"] / 1000).strftime(
+ "%Y-%m-%d %H:%M:%S"
+ )
+ return f"{sender} on {date}: {text}\n\n"
+
+
+class FacebookChatLoader(BaseLoader):
+ """Load `Facebook Chat` messages directory dump."""
+
+ def __init__(self, path: str):
+ """Initialize with a path."""
+ self.file_path = path
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ p = Path(self.file_path)
+
+ with open(p, encoding="utf8") as f:
+ d = json.load(f)
+
+ text = "".join(
+ concatenate_rows(message)
+ for message in d["messages"]
+ if message.get("content") and isinstance(message["content"], str)
+ )
+ metadata = {"source": str(p)}
+
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/fauna.py b/libs/community/langchain_community/document_loaders/fauna.py
new file mode 100644
index 00000000000..5b17216dc00
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/fauna.py
@@ -0,0 +1,65 @@
+from typing import Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class FaunaLoader(BaseLoader):
+ """Load from `FaunaDB`.
+
+ Attributes:
+ query (str): The FQL query string to execute.
+ page_content_field (str): The field that contains the content of each page.
+ secret (str): The secret key for authenticating to FaunaDB.
+ metadata_fields (Optional[Sequence[str]]):
+ Optional list of field names to include in metadata.
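+
+    Examples
+    --------
+    A minimal usage sketch (the query, field name, and secret are placeholders):
+
+    from langchain_community.document_loaders.fauna import FaunaLoader
+
+    loader = FaunaLoader(
+        query="Item.all()",
+        page_content_field="text",
+        secret="<fauna-secret-key>",
+    )
+    docs = loader.load()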
+ """
+
+ def __init__(
+ self,
+ query: str,
+ page_content_field: str,
+ secret: str,
+ metadata_fields: Optional[Sequence[str]] = None,
+ ):
+ self.query = query
+ self.page_content_field = page_content_field
+ self.secret = secret
+ self.metadata_fields = metadata_fields
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ try:
+ from fauna import Page, fql
+ from fauna.client import Client
+ from fauna.encoding import QuerySuccess
+ except ImportError:
+ raise ImportError(
+ "Could not import fauna python package. "
+ "Please install it with `pip install fauna`."
+ )
+ # Create Fauna Client
+ client = Client(secret=self.secret)
+ # Run FQL Query
+ response: QuerySuccess = client.query(fql(self.query))
+ page: Page = response.data
+ for result in page:
+ if result is not None:
+ document_dict = dict(result.items())
+ page_content = ""
+ for key, value in document_dict.items():
+ if key == self.page_content_field:
+ page_content = value
+ document: Document = Document(
+ page_content=page_content,
+ metadata={"id": result.id, "ts": result.ts},
+ )
+ yield document
+ if page.after is not None:
+ yield Document(
+ page_content="Next Page Exists",
+ metadata={"after": page.after},
+ )
diff --git a/libs/community/langchain_community/document_loaders/figma.py b/libs/community/langchain_community/document_loaders/figma.py
new file mode 100644
index 00000000000..d147aaf9f20
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/figma.py
@@ -0,0 +1,48 @@
+import json
+import urllib.request
+from typing import Any, List
+
+from langchain_core.documents import Document
+from langchain_core.utils import stringify_dict
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class FigmaFileLoader(BaseLoader):
+ """Load `Figma` file."""
+
+ def __init__(self, access_token: str, ids: str, key: str):
+ """Initialize with access token, ids, and key.
+
+ Args:
+ access_token: The access token for the Figma REST API.
+ ids: The ids of the Figma file.
+            key: The key for the Figma file.
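+
+        Examples
+        --------
+        A minimal usage sketch (the token, ids, and key below are placeholders):
+
+        from langchain_community.document_loaders.figma import FigmaFileLoader
+
+        loader = FigmaFileLoader(
+            access_token="<figma-access-token>",
+            ids="<node-ids>",
+            key="<file-key>",
+        )
+        docs = loader.load()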
+ """
+ self.access_token = access_token
+ self.ids = ids
+ self.key = key
+
+ def _construct_figma_api_url(self) -> str:
+ api_url = "https://api.figma.com/v1/files/%s/nodes?ids=%s" % (
+ self.key,
+ self.ids,
+ )
+ return api_url
+
+ def _get_figma_file(self) -> Any:
+ """Get Figma file from Figma REST API."""
+ headers = {"X-Figma-Token": self.access_token}
+ request = urllib.request.Request(
+ self._construct_figma_api_url(), headers=headers
+ )
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ return json_data
+
+ def load(self) -> List[Document]:
+ """Load file"""
+ data = self._get_figma_file()
+ text = stringify_dict(data)
+ metadata = {"source": self._construct_figma_api_url()}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/gcs_directory.py b/libs/community/langchain_community/document_loaders/gcs_directory.py
new file mode 100644
index 00000000000..d6f5b61c0db
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/gcs_directory.py
@@ -0,0 +1,58 @@
+from typing import Callable, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.gcs_file import GCSFileLoader
+from langchain_community.utilities.vertexai import get_client_info
+
+
+class GCSDirectoryLoader(BaseLoader):
+ """Load from GCS directory."""
+
+ def __init__(
+ self,
+ project_name: str,
+ bucket: str,
+ prefix: str = "",
+ loader_func: Optional[Callable[[str], BaseLoader]] = None,
+ ):
+        """Initialize with the project name, bucket, and prefix.
+
+ Args:
+ project_name: The ID of the project for the GCS bucket.
+ bucket: The name of the GCS bucket.
+ prefix: The prefix of the GCS bucket.
+ loader_func: A loader function that instantiates a loader based on a
+ file_path argument. If nothing is provided, the GCSFileLoader
+ would use its default loader.
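+
+        Examples
+        --------
+        A minimal usage sketch (the project and bucket names are placeholders):
+
+        from langchain_community.document_loaders.gcs_directory import (
+            GCSDirectoryLoader,
+        )
+
+        loader = GCSDirectoryLoader(
+            project_name="<project-name>",
+            bucket="<bucket-name>",
+            prefix="docs/",
+        )
+        docs = loader.load()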
+ """
+ self.project_name = project_name
+ self.bucket = bucket
+ self.prefix = prefix
+ self._loader_func = loader_func
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ from google.cloud import storage
+ except ImportError:
+ raise ImportError(
+ "Could not import google-cloud-storage python package. "
+ "Please install it with `pip install google-cloud-storage`."
+ )
+ client = storage.Client(
+ project=self.project_name,
+ client_info=get_client_info(module="google-cloud-storage"),
+ )
+ docs = []
+ for blob in client.list_blobs(self.bucket, prefix=self.prefix):
+ # we shall just skip directories since GCSFileLoader creates
+ # intermediate directories on the fly
+ if blob.name.endswith("/"):
+ continue
+ loader = GCSFileLoader(
+ self.project_name, self.bucket, blob.name, loader_func=self._loader_func
+ )
+ docs.extend(loader.load())
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/gcs_file.py b/libs/community/langchain_community/document_loaders/gcs_file.py
new file mode 100644
index 00000000000..d43e81ceba8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/gcs_file.py
@@ -0,0 +1,83 @@
+import os
+import tempfile
+from typing import Callable, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+from langchain_community.utilities.vertexai import get_client_info
+
+
+class GCSFileLoader(BaseLoader):
+ """Load from GCS file."""
+
+ def __init__(
+ self,
+ project_name: str,
+ bucket: str,
+ blob: str,
+ loader_func: Optional[Callable[[str], BaseLoader]] = None,
+ ):
+        """Initialize with the project name, bucket, and blob name.
+
+ Args:
+ project_name: The name of the project to load
+ bucket: The name of the GCS bucket.
+ blob: The name of the GCS blob to load.
+ loader_func: A loader function that instantiates a loader based on a
+ file_path argument. If nothing is provided, the
+ UnstructuredFileLoader is used.
+
+ Examples:
+ To use an alternative PDF loader:
+            >> from langchain_community.document_loaders import PyPDFLoader
+ >> loader = GCSFileLoader(..., loader_func=PyPDFLoader)
+
+ To use UnstructuredFileLoader with additional arguments:
+ >> loader = GCSFileLoader(...,
+ >> loader_func=lambda x: UnstructuredFileLoader(x, mode="elements"))
+
+ """
+ self.bucket = bucket
+ self.blob = blob
+ self.project_name = project_name
+
+ def default_loader_func(file_path: str) -> BaseLoader:
+ return UnstructuredFileLoader(file_path)
+
+ self._loader_func = loader_func if loader_func else default_loader_func
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ from google.cloud import storage
+ except ImportError:
+ raise ImportError(
+ "Could not import google-cloud-storage python package. "
+ "Please install it with `pip install google-cloud-storage`."
+ )
+
+ # Initialise a client
+ storage_client = storage.Client(
+ self.project_name, client_info=get_client_info("google-cloud-storage")
+ )
+ # Create a bucket object for our bucket
+ bucket = storage_client.get_bucket(self.bucket)
+ # Create a blob object from the filepath
+ blob = bucket.blob(self.blob)
+ # retrieve custom metadata associated with the blob
+ metadata = bucket.get_blob(self.blob).metadata
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.blob}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ # Download the file to a destination
+ blob.download_to_filename(file_path)
+ loader = self._loader_func(file_path)
+ docs = loader.load()
+ for doc in docs:
+ if "source" in doc.metadata:
+ doc.metadata["source"] = f"gs://{self.bucket}/{self.blob}"
+ if metadata:
+ doc.metadata.update(metadata)
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/generic.py b/libs/community/langchain_community/document_loaders/generic.py
new file mode 100644
index 00000000000..0ec6ca60bdf
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/generic.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterator,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Union,
+)
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
+from langchain_community.document_loaders.blob_loaders import (
+ BlobLoader,
+ FileSystemBlobLoader,
+)
+from langchain_community.document_loaders.parsers.registry import get_parser
+
+if TYPE_CHECKING:
+ from langchain.text_splitter import TextSplitter
+
+_PathLike = Union[str, Path]
+
+DEFAULT = Literal["default"]
+
+
+class GenericLoader(BaseLoader):
+ """Generic Document Loader.
+
+ A generic document loader that allows combining an arbitrary blob loader with
+ a blob parser.
+
+ Examples:
+
+ Parse a specific PDF file:
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders import GenericLoader
+ from langchain_community.document_loaders.parsers.pdf import PyPDFParser
+
+            # Parse a single PDF file.
+ loader = GenericLoader.from_filesystem(
+ "my_lovely_pdf.pdf",
+ parser=PyPDFParser()
+ )
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders import GenericLoader
+ from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader
+
+
+ loader = GenericLoader.from_filesystem(
+ path="path/to/directory",
+ glob="**/[!.]*",
+ suffixes=[".pdf"],
+ show_progress=True,
+ )
+
+ docs = loader.lazy_load()
+ next(docs)
+
+ Example instantiations to change which files are loaded:
+
+ .. code-block:: python
+
+ # Recursively load all text files in a directory.
+ loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/*.txt")
+
+ # Recursively load all non-hidden files in a directory.
+ loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/[!.]*")
+
+ # Load all files in a directory without recursion.
+ loader = GenericLoader.from_filesystem("/path/to/dir", glob="*")
+
+ Example instantiations to change which parser is used:
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.parsers.pdf import PyPDFParser
+
+            # Recursively load all PDF files in a directory.
+ loader = GenericLoader.from_filesystem(
+ "/path/to/dir",
+ glob="**/*.pdf",
+ parser=PyPDFParser()
+ )
+
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ blob_loader: BlobLoader,
+ blob_parser: BaseBlobParser,
+ ) -> None:
+ """A generic document loader.
+
+ Args:
+ blob_loader: A blob loader which knows how to yield blobs
+ blob_parser: A blob parser which knows how to parse blobs into documents
+ """
+ self.blob_loader = blob_loader
+ self.blob_parser = blob_parser
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Load documents lazily. Use this when working at a large scale."""
+ for blob in self.blob_loader.yield_blobs():
+ yield from self.blob_parser.lazy_parse(blob)
+
+ def load(self) -> List[Document]:
+ """Load all documents."""
+ return list(self.lazy_load())
+
+ def load_and_split(
+ self, text_splitter: Optional[TextSplitter] = None
+ ) -> List[Document]:
+        """Load all documents and split them into chunks."""
+ raise NotImplementedError(
+ "Loading and splitting is not yet implemented for generic loaders. "
+            "When implemented, splitting will be configured via the initializer. "
+ "This method should not be used going forward."
+ )
+
+ @classmethod
+ def from_filesystem(
+ cls,
+ path: _PathLike,
+ *,
+ glob: str = "**/[!.]*",
+ exclude: Sequence[str] = (),
+ suffixes: Optional[Sequence[str]] = None,
+ show_progress: bool = False,
+ parser: Union[DEFAULT, BaseBlobParser] = "default",
+ parser_kwargs: Optional[dict] = None,
+ ) -> GenericLoader:
+ """Create a generic document loader using a filesystem blob loader.
+
+ Args:
+ path: The path to the directory to load documents from OR the path to a
+ single file to load. If this is a file, glob, exclude, suffixes
+ will be ignored.
+ glob: The glob pattern to use to find documents.
+ suffixes: The suffixes to use to filter documents. If None, all files
+ matching the glob will be loaded.
+ exclude: A list of patterns to exclude from the loader.
+ show_progress: Whether to show a progress bar or not (requires tqdm).
+ Proxies to the file system loader.
+            parser: A blob parser which knows how to parse blobs into documents.
+                A default parser is instantiated if not provided.
+                The default can be overridden by either passing a parser or
+                overriding the `get_parser` method (the latter
+                should be used with inheritance).
+ parser_kwargs: Keyword arguments to pass to the parser.
+
+ Returns:
+ A generic document loader.
+ """
+ blob_loader = FileSystemBlobLoader(
+ path,
+ glob=glob,
+ exclude=exclude,
+ suffixes=suffixes,
+ show_progress=show_progress,
+ )
+ if isinstance(parser, str):
+ if parser == "default":
+ try:
+ # If there is an implementation of get_parser on the class, use it.
+ blob_parser = cls.get_parser(**(parser_kwargs or {}))
+ except NotImplementedError:
+ # if not then use the global registry.
+ blob_parser = get_parser(parser)
+ else:
+ blob_parser = get_parser(parser)
+ else:
+ blob_parser = parser
+ return cls(blob_loader, blob_parser)
+
+ @staticmethod
+ def get_parser(**kwargs: Any) -> BaseBlobParser:
+ """Override this method to associate a default parser with the class."""
+ raise NotImplementedError()
diff --git a/libs/community/langchain_community/document_loaders/geodataframe.py b/libs/community/langchain_community/document_loaders/geodataframe.py
new file mode 100644
index 00000000000..09a4c5ae9f9
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/geodataframe.py
@@ -0,0 +1,73 @@
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class GeoDataFrameLoader(BaseLoader):
+    """Load a `geopandas` DataFrame."""
+
+ def __init__(self, data_frame: Any, page_content_column: str = "geometry"):
+        """Initialize with a geopandas DataFrame.
+
+ Args:
+ data_frame: geopandas DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "geometry".
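+
+        Examples
+        --------
+        A minimal usage sketch (the file path is a placeholder; needs geopandas):
+
+        import geopandas as gpd
+
+        from langchain_community.document_loaders.geodataframe import (
+            GeoDataFrameLoader,
+        )
+
+        gdf = gpd.read_file("example.shp")  # placeholder path
+        loader = GeoDataFrameLoader(gdf, page_content_column="geometry")
+        docs = loader.load()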
+ """
+
+ try:
+ import geopandas as gpd
+ except ImportError:
+ raise ImportError(
+ "geopandas package not found, please install it with "
+ "`pip install geopandas`"
+ )
+
+ if not isinstance(data_frame, gpd.GeoDataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a gpd.GeoDataFrame, got {type(data_frame)}"
+ )
+
+ if page_content_column not in data_frame.columns:
+ raise ValueError(
+ f"Expected data_frame to have a column named {page_content_column}"
+ )
+
+ if not isinstance(data_frame[page_content_column], gpd.GeoSeries):
+ raise ValueError(
+ f"Expected data_frame[{page_content_column}] to be a GeoSeries"
+ )
+
+ self.data_frame = data_frame
+ self.page_content_column = page_content_column
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records from dataframe."""
+
+ # assumes all geometries in GeoSeries are same CRS and Geom Type
+ crs_str = self.data_frame.crs.to_string() if self.data_frame.crs else None
+ geometry_type = self.data_frame.geometry.geom_type.iloc[0]
+
+ for _, row in self.data_frame.iterrows():
+ geom = row[self.page_content_column]
+
+ xmin, ymin, xmax, ymax = geom.bounds
+
+ metadata = row.to_dict()
+ metadata["crs"] = crs_str
+ metadata["geometry_type"] = geometry_type
+ metadata["xmin"] = xmin
+ metadata["ymin"] = ymin
+ metadata["xmax"] = xmax
+ metadata["ymax"] = ymax
+
+ metadata.pop(self.page_content_column)
+
+ # using WKT instead of str() to help GIS system interoperability
+ yield Document(page_content=geom.wkt, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load full dataframe."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/git.py b/libs/community/langchain_community/document_loaders/git.py
new file mode 100644
index 00000000000..5bd5341a3d8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/git.py
@@ -0,0 +1,110 @@
+import os
+from typing import Callable, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class GitLoader(BaseLoader):
+ """Load `Git` repository files.
+
+    The repository can be local on disk available at `repo_path`,
+    or remote at `clone_url` that will be cloned to `repo_path`.
+    Currently, only text files are supported.
+
+    Each document represents one file in the repository. The `repo_path` points to
+    the local Git repository, and the `branch` specifies the branch to load
+    files from. By default, it loads from the `main` branch.
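+
+    Examples
+    --------
+    A minimal usage sketch (the repository path below is a placeholder):
+
+    from langchain_community.document_loaders.git import GitLoader
+
+    loader = GitLoader(
+        repo_path="./example_data/test_repo",  # placeholder path
+        branch="main",
+        file_filter=lambda file_path: file_path.endswith(".py"),
+    )
+    docs = loader.load()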
+ """
+
+ def __init__(
+ self,
+ repo_path: str,
+ clone_url: Optional[str] = None,
+ branch: Optional[str] = "main",
+ file_filter: Optional[Callable[[str], bool]] = None,
+ ):
+ """
+
+ Args:
+ repo_path: The path to the Git repository.
+ clone_url: Optional. The URL to clone the repository from.
+ branch: Optional. The branch to load files from. Defaults to `main`.
+ file_filter: Optional. A function that takes a file path and returns
+ a boolean indicating whether to load the file. Defaults to None.
+ """
+ self.repo_path = repo_path
+ self.clone_url = clone_url
+ self.branch = branch
+ self.file_filter = file_filter
+
+ def load(self) -> List[Document]:
+ try:
+ from git import Blob, Repo # type: ignore
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import git python package. "
+ "Please install it with `pip install GitPython`."
+ ) from ex
+
+ if not os.path.exists(self.repo_path) and self.clone_url is None:
+ raise ValueError(f"Path {self.repo_path} does not exist")
+ elif self.clone_url:
+ # If the repo_path already contains a git repository, verify that it's the
+ # same repository as the one we're trying to clone.
+ if os.path.isdir(os.path.join(self.repo_path, ".git")):
+ repo = Repo(self.repo_path)
+ # If the existing repository is not the same as the one we're trying to
+ # clone, raise an error.
+ if repo.remotes.origin.url != self.clone_url:
+ raise ValueError(
+ "A different repository is already cloned at this path."
+ )
+ else:
+ repo = Repo.clone_from(self.clone_url, self.repo_path)
+ repo.git.checkout(self.branch)
+ else:
+ repo = Repo(self.repo_path)
+ repo.git.checkout(self.branch)
+
+ docs: List[Document] = []
+
+ for item in repo.tree().traverse():
+ if not isinstance(item, Blob):
+ continue
+
+ file_path = os.path.join(self.repo_path, item.path)
+
+ ignored_files = repo.ignored([file_path]) # type: ignore
+ if len(ignored_files):
+ continue
+
+ # uses filter to skip files
+ if self.file_filter and not self.file_filter(file_path):
+ continue
+
+ rel_file_path = os.path.relpath(file_path, self.repo_path)
+ try:
+ with open(file_path, "rb") as f:
+ content = f.read()
+ file_type = os.path.splitext(item.name)[1]
+
+ # loads only text files
+ try:
+ text_content = content.decode("utf-8")
+ except UnicodeDecodeError:
+ continue
+
+ metadata = {
+ "source": rel_file_path,
+ "file_path": rel_file_path,
+ "file_name": item.name,
+ "file_type": file_type,
+ }
+ doc = Document(page_content=text_content, metadata=metadata)
+ docs.append(doc)
+ except Exception as e:
+ print(f"Error reading file {file_path}: {e}")
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/gitbook.py b/libs/community/langchain_community/document_loaders/gitbook.py
new file mode 100644
index 00000000000..d2a32b367a0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/gitbook.py
@@ -0,0 +1,83 @@
+from typing import Any, List, Optional
+from urllib.parse import urljoin, urlparse
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class GitbookLoader(WebBaseLoader):
+ """Load `GitBook` data.
+
+ 1. load from either a single page, or
+ 2. load all (relative) paths in the navbar.
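+
+    Examples
+    --------
+    A minimal usage sketch (the docs URL below is a placeholder):
+
+    from langchain_community.document_loaders.gitbook import GitbookLoader
+
+    loader = GitbookLoader("https://docs.gitbook.com", load_all_paths=True)
+    docs = loader.load()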
+ """
+
+ def __init__(
+ self,
+ web_page: str,
+ load_all_paths: bool = False,
+ base_url: Optional[str] = None,
+ content_selector: str = "main",
+ continue_on_failure: bool = False,
+ ):
+ """Initialize with web page and whether to load all paths.
+
+ Args:
+ web_page: The web page to load or the starting point from where
+ relative paths are discovered.
+ load_all_paths: If set to True, all relative paths in the navbar
+ are loaded instead of only `web_page`.
+ base_url: If `load_all_paths` is True, the relative paths are
+ appended to this base url. Defaults to `web_page`.
+ content_selector: The CSS selector for the content to load.
+ Defaults to "main".
+ continue_on_failure: whether to continue loading the sitemap if an error
+ occurs loading a url, emitting a warning instead of raising an
+ exception. Setting this to True makes the loader more robust, but also
+ may result in missing data. Default: False
+ """
+ self.base_url = base_url or web_page
+ if self.base_url.endswith("/"):
+ self.base_url = self.base_url[:-1]
+ if load_all_paths:
+ # set web_path to the sitemap if we want to crawl all paths
+ web_page = f"{self.base_url}/sitemap.xml"
+ super().__init__(web_paths=(web_page,), continue_on_failure=continue_on_failure)
+ self.load_all_paths = load_all_paths
+ self.content_selector = content_selector
+
+ def load(self) -> List[Document]:
+        """Fetch text from one page, or all pages if `load_all_paths` is True."""
+ if self.load_all_paths:
+ soup_info = self.scrape()
+ relative_paths = self._get_paths(soup_info)
+ urls = [urljoin(self.base_url, path) for path in relative_paths]
+ soup_infos = self.scrape_all(urls)
+ _documents = [
+ self._get_document(soup_info, url)
+ for soup_info, url in zip(soup_infos, urls)
+ ]
+ else:
+ soup_info = self.scrape()
+ _documents = [self._get_document(soup_info, self.web_path)]
+ documents = [d for d in _documents if d]
+
+ return documents
+
+ def _get_document(
+ self, soup: Any, custom_url: Optional[str] = None
+ ) -> Optional[Document]:
+ """Fetch content from page and return Document."""
+ page_content_raw = soup.find(self.content_selector)
+ if not page_content_raw:
+ return None
+ content = page_content_raw.get_text(separator="\n").strip()
+ title_if_exists = page_content_raw.find("h1")
+ title = title_if_exists.text if title_if_exists else ""
+ metadata = {"source": custom_url or self.web_path, "title": title}
+ return Document(page_content=content, metadata=metadata)
+
+ def _get_paths(self, soup: Any) -> List[str]:
+ """Fetch all relative paths in the navbar."""
+ return [urlparse(loc.text).path for loc in soup.find_all("loc")]
diff --git a/libs/community/langchain_community/document_loaders/github.py b/libs/community/langchain_community/document_loaders/github.py
new file mode 100644
index 00000000000..77bdf7da6c7
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/github.py
@@ -0,0 +1,188 @@
+from abc import ABC
+from datetime import datetime
+from typing import Dict, Iterator, List, Literal, Optional, Union
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator, validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class BaseGitHubLoader(BaseLoader, BaseModel, ABC):
+ """Load `GitHub` repository Issues."""
+
+ repo: str
+ """Name of repository"""
+ access_token: str
+ """Personal access token - see https://github.com/settings/tokens?type=beta"""
+ github_api_url: str = "https://api.github.com"
+ """URL of GitHub API"""
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that access token exists in environment."""
+ values["access_token"] = get_from_dict_or_env(
+ values, "access_token", "GITHUB_PERSONAL_ACCESS_TOKEN"
+ )
+ return values
+
+ @property
+ def headers(self) -> Dict[str, str]:
+ return {
+ "Accept": "application/vnd.github+json",
+ "Authorization": f"Bearer {self.access_token}",
+ }
+
+
+class GitHubIssuesLoader(BaseGitHubLoader):
+    """Load issues of a GitHub repository.
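+
+    Examples
+    --------
+    A minimal usage sketch (the repository and token below are placeholders):
+
+    from langchain_community.document_loaders.github import GitHubIssuesLoader
+
+    loader = GitHubIssuesLoader(
+        repo="<owner>/<repo>",
+        access_token="<github-personal-access-token>",
+        include_prs=False,
+        state="open",
+    )
+    docs = loader.load()
+    """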
+
+ include_prs: bool = True
+ """If True include Pull Requests in results, otherwise ignore them."""
+ milestone: Union[int, Literal["*", "none"], None] = None
+ """If integer is passed, it should be a milestone's number field.
+ If the string '*' is passed, issues with any milestone are accepted.
+ If the string 'none' is passed, issues without milestones are returned.
+ """
+ state: Optional[Literal["open", "closed", "all"]] = None
+ """Filter on issue state. Can be one of: 'open', 'closed', 'all'."""
+ assignee: Optional[str] = None
+ """Filter on assigned user. Pass 'none' for no user and '*' for any user."""
+ creator: Optional[str] = None
+ """Filter on the user that created the issue."""
+ mentioned: Optional[str] = None
+ """Filter on a user that's mentioned in the issue."""
+ labels: Optional[List[str]] = None
+    """Label names to filter on. Example: bug,ui,@high."""
+ sort: Optional[Literal["created", "updated", "comments"]] = None
+ """What to sort results by. Can be one of: 'created', 'updated', 'comments'.
+ Default is 'created'."""
+ direction: Optional[Literal["asc", "desc"]] = None
+ """The direction to sort the results by. Can be one of: 'asc', 'desc'."""
+ since: Optional[str] = None
+ """Only show notifications updated after the given time.
+ This is a timestamp in ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ."""
+
+ @validator("since")
+ def validate_since(cls, v: Optional[str]) -> Optional[str]:
+ if v:
+ try:
+ datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
+ except ValueError:
+ raise ValueError(
+ "Invalid value for 'since'. Expected a date string in "
+ f"YYYY-MM-DDTHH:MM:SSZ format. Received: {v}"
+ )
+ return v
+
+ def lazy_load(self) -> Iterator[Document]:
+ """
+ Get issues of a GitHub repository.
+
+        Yields:
+            Documents with attributes:
+ - page_content
+ - metadata
+ - url
+ - title
+ - creator
+ - created_at
+ - last_update_time
+ - closed_time
+ - number of comments
+ - state
+ - labels
+ - assignee
+ - assignees
+ - milestone
+ - locked
+ - number
+ - is_pull_request
+ """
+ url: Optional[str] = self.url
+ while url:
+ response = requests.get(url, headers=self.headers)
+ response.raise_for_status()
+ issues = response.json()
+ for issue in issues:
+ doc = self.parse_issue(issue)
+ if not self.include_prs and doc.metadata["is_pull_request"]:
+ continue
+ yield doc
+ if response.links and response.links.get("next"):
+ url = response.links["next"]["url"]
+ else:
+ url = None
+
+ def load(self) -> List[Document]:
+ """
+ Get issues of a GitHub repository.
+
+ Returns:
+ A list of Documents with attributes:
+ - page_content
+ - metadata
+ - url
+ - title
+ - creator
+ - created_at
+ - last_update_time
+ - closed_time
+ - number of comments
+ - state
+ - labels
+ - assignee
+ - assignees
+ - milestone
+ - locked
+ - number
+ - is_pull_request
+ """
+ return list(self.lazy_load())
+
+ def parse_issue(self, issue: dict) -> Document:
+        """Create a Document object from a single GitHub issue."""
+ metadata = {
+ "url": issue["html_url"],
+ "title": issue["title"],
+ "creator": issue["user"]["login"],
+ "created_at": issue["created_at"],
+ "comments": issue["comments"],
+ "state": issue["state"],
+ "labels": [label["name"] for label in issue["labels"]],
+ "assignee": issue["assignee"]["login"] if issue["assignee"] else None,
+ "milestone": issue["milestone"]["title"] if issue["milestone"] else None,
+ "locked": issue["locked"],
+ "number": issue["number"],
+ "is_pull_request": "pull_request" in issue,
+ }
+ content = issue["body"] if issue["body"] is not None else ""
+ return Document(page_content=content, metadata=metadata)
+
+ @property
+ def query_params(self) -> str:
+ """Create query parameters for GitHub API."""
+ labels = ",".join(self.labels) if self.labels else self.labels
+ query_params_dict = {
+ "milestone": self.milestone,
+ "state": self.state,
+ "assignee": self.assignee,
+ "creator": self.creator,
+ "mentioned": self.mentioned,
+ "labels": labels,
+ "sort": self.sort,
+ "direction": self.direction,
+ "since": self.since,
+ }
+ query_params_list = [
+ f"{k}={v}" for k, v in query_params_dict.items() if v is not None
+ ]
+ query_params = "&".join(query_params_list)
+ return query_params
+
+ @property
+ def url(self) -> str:
+ """Create URL for GitHub API."""
+ return f"{self.github_api_url}/repos/{self.repo}/issues?{self.query_params}"
diff --git a/libs/community/langchain_community/document_loaders/google_speech_to_text.py b/libs/community/langchain_community/document_loaders/google_speech_to_text.py
new file mode 100644
index 00000000000..da3c743dea3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/google_speech_to_text.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.cloud.speech_v2 import RecognitionConfig
+ from google.protobuf.field_mask_pb2 import FieldMask
+
+
+class GoogleSpeechToTextLoader(BaseLoader):
+ """
+ Loader for Google Cloud Speech-to-Text audio transcripts.
+
+ It uses the Google Cloud Speech-to-Text API to transcribe audio files
+ and loads the transcribed text into one or more Documents,
+ depending on the specified format.
+
+ To use, you should have the ``google-cloud-speech`` python package installed.
+
+ Audio files can be specified via a Google Cloud Storage uri or a local file path.
+
+ For a detailed explanation of Google Cloud Speech-to-Text, refer to the product
+ documentation.
+ https://cloud.google.com/speech-to-text
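+
+    Examples
+    --------
+    A minimal usage sketch (the project id and audio URI are placeholders):
+
+    from langchain_community.document_loaders.google_speech_to_text import (
+        GoogleSpeechToTextLoader,
+    )
+
+    loader = GoogleSpeechToTextLoader(
+        project_id="<project-id>",
+        file_path="gs://<bucket>/audio.wav",
+    )
+    docs = loader.load()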
+ """
+
+ def __init__(
+ self,
+ project_id: str,
+ file_path: str,
+ location: str = "us-central1",
+ recognizer_id: str = "_",
+ config: Optional[RecognitionConfig] = None,
+ config_mask: Optional[FieldMask] = None,
+ ):
+ """
+ Initializes the GoogleSpeechToTextLoader.
+
+ Args:
+ project_id: Google Cloud Project ID.
+ file_path: A Google Cloud Storage URI or a local file path.
+ location: Speech-to-Text recognizer location.
+ recognizer_id: Speech-to-Text recognizer id.
+ config: Recognition options and features.
+ For more information:
+ https://cloud.google.com/python/docs/reference/speech/latest/google.cloud.speech_v2.types.RecognitionConfig
+ config_mask: The list of fields in config that override the values in the
+ ``default_recognition_config`` of the recognizer during this
+ recognition request.
+ For more information:
+ https://cloud.google.com/python/docs/reference/speech/latest/google.cloud.speech_v2.types.RecognizeRequest
+ """
+ try:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud.speech_v2 import (
+ AutoDetectDecodingConfig,
+ RecognitionConfig,
+ RecognitionFeatures,
+ SpeechClient,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import google-cloud-speech python package. "
+ "Please install it with `pip install google-cloud-speech`."
+ ) from exc
+
+ self.project_id = project_id
+ self.file_path = file_path
+ self.location = location
+ self.recognizer_id = recognizer_id
+ # Config must be set in speech recognition request.
+ self.config = config or RecognitionConfig(
+ auto_decoding_config=AutoDetectDecodingConfig(),
+ language_codes=["en-US"],
+ model="chirp",
+ features=RecognitionFeatures(
+ # Automatic punctuation could be useful for language applications
+ enable_automatic_punctuation=True,
+ ),
+ )
+ self.config_mask = config_mask
+
+ self._client = SpeechClient(
+ client_info=get_client_info(module="speech-to-text"),
+ client_options=(
+ ClientOptions(api_endpoint=f"{location}-speech.googleapis.com")
+ if location != "global"
+ else None
+ ),
+ )
+ self._recognizer_path = self._client.recognizer_path(
+ project_id, location, recognizer_id
+ )
+
+ def load(self) -> List[Document]:
+ """Transcribes the audio file and loads the transcript into documents.
+
+ It uses the Google Cloud Speech-to-Text API to transcribe the audio file
+ and blocks until the transcription is finished.
+ """
+ try:
+ from google.cloud.speech_v2 import RecognizeRequest
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import google-cloud-speech python package. "
+ "Please install it with `pip install google-cloud-speech`."
+ ) from exc
+
+ request = RecognizeRequest(
+ recognizer=self._recognizer_path,
+ config=self.config,
+ config_mask=self.config_mask,
+ )
+
+ if "gs://" in self.file_path:
+ request.uri = self.file_path
+ else:
+ with open(self.file_path, "rb") as f:
+ request.content = f.read()
+
+ response = self._client.recognize(request=request)
+
+ return [
+ Document(
+ page_content=result.alternatives[0].transcript,
+ metadata={
+ "language_code": result.language_code,
+ "result_end_offset": result.result_end_offset,
+ },
+ )
+ for result in response.results
+ ]
diff --git a/libs/community/langchain_community/document_loaders/googledrive.py b/libs/community/langchain_community/document_loaders/googledrive.py
new file mode 100644
index 00000000000..7f5124a4d9b
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/googledrive.py
@@ -0,0 +1,360 @@
+# Prerequisites:
+# 1. Create a Google Cloud project
+# 2. Enable the Google Drive API:
+# https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com
+# 3. Authorize credentials for desktop app:
+# https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application # noqa: E501
+# 4. For service accounts visit
+# https://cloud.google.com/iam/docs/service-accounts-create
+
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator, validator
+
+from langchain_community.document_loaders.base import BaseLoader
+
+SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
+
+
+class GoogleDriveLoader(BaseLoader, BaseModel):
+    """Load Google Docs, Sheets, and other files from `Google Drive`.
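+
+    Examples
+    --------
+    A minimal usage sketch (the folder id below is a placeholder):
+
+    from langchain_community.document_loaders.googledrive import GoogleDriveLoader
+
+    loader = GoogleDriveLoader(
+        folder_id="<folder-id>",
+        recursive=False,
+    )
+    docs = loader.load()
+    """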
+
+ service_account_key: Path = Path.home() / ".credentials" / "keys.json"
+ """Path to the service account key file."""
+ credentials_path: Path = Path.home() / ".credentials" / "credentials.json"
+ """Path to the credentials file."""
+ token_path: Path = Path.home() / ".credentials" / "token.json"
+ """Path to the token file."""
+ folder_id: Optional[str] = None
+ """The folder id to load from."""
+ document_ids: Optional[List[str]] = None
+ """The document ids to load from."""
+ file_ids: Optional[List[str]] = None
+ """The file ids to load from."""
+ recursive: bool = False
+ """Whether to load recursively. Only applies when folder_id is given."""
+ file_types: Optional[Sequence[str]] = None
+ """The file types to load. Only applies when folder_id is given."""
+ load_trashed_files: bool = False
+ """Whether to load trashed files. Only applies when folder_id is given."""
+ # NOTE(MthwRobinson) - changing the file_loader_cls to type here currently
+ # results in pydantic validation errors
+ file_loader_cls: Any = None
+ """The file loader class to use."""
+    file_loader_kwargs: Dict[str, Any] = {}
+ """The file loader kwargs to use."""
+
+ @root_validator
+ def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate that either folder_id or document_ids is set, but not both."""
+ if values.get("folder_id") and (
+ values.get("document_ids") or values.get("file_ids")
+ ):
+ raise ValueError(
+ "Cannot specify both folder_id and document_ids nor "
+ "folder_id and file_ids"
+ )
+ if (
+ not values.get("folder_id")
+ and not values.get("document_ids")
+ and not values.get("file_ids")
+ ):
+ raise ValueError("Must specify either folder_id, document_ids, or file_ids")
+
+ file_types = values.get("file_types")
+ if file_types:
+ if values.get("document_ids") or values.get("file_ids"):
+ raise ValueError(
+ "file_types can only be given when folder_id is given,"
+ " (not when document_ids or file_ids are given)."
+ )
+ type_mapping = {
+ "document": "application/vnd.google-apps.document",
+ "sheet": "application/vnd.google-apps.spreadsheet",
+ "pdf": "application/pdf",
+ }
+ allowed_types = list(type_mapping.keys()) + list(type_mapping.values())
+ short_names = ", ".join([f"'{x}'" for x in type_mapping.keys()])
+ full_names = ", ".join([f"'{x}'" for x in type_mapping.values()])
+ for file_type in file_types:
+ if file_type not in allowed_types:
+ raise ValueError(
+ f"Given file type {file_type} is not supported. "
+ f"Supported values are: {short_names}; and "
+ f"their full-form names: {full_names}"
+ )
+
+ # replace short-form file types by full-form file types
+ def full_form(x: str) -> str:
+ return type_mapping[x] if x in type_mapping else x
+
+ values["file_types"] = [full_form(file_type) for file_type in file_types]
+ return values
+
+ @validator("credentials_path")
+ def validate_credentials_path(cls, v: Any, **kwargs: Any) -> Any:
+ """Validate that credentials_path exists."""
+ if not v.exists():
+ raise ValueError(f"credentials_path {v} does not exist")
+ return v
+
+ def _load_credentials(self) -> Any:
+ """Load credentials."""
+ # Adapted from https://developers.google.com/drive/api/v3/quickstart/python
+ try:
+ from google.auth import default
+ from google.auth.transport.requests import Request
+ from google.oauth2 import service_account
+ from google.oauth2.credentials import Credentials
+ from google_auth_oauthlib.flow import InstalledAppFlow
+ except ImportError:
+ raise ImportError(
+ "You must run "
+ "`pip install --upgrade "
+ "google-api-python-client google-auth-httplib2 "
+ "google-auth-oauthlib` "
+ "to use the Google Drive loader."
+ )
+
+ creds = None
+ if self.service_account_key.exists():
+ return service_account.Credentials.from_service_account_file(
+ str(self.service_account_key), scopes=SCOPES
+ )
+
+ if self.token_path.exists():
+ creds = Credentials.from_authorized_user_file(str(self.token_path), SCOPES)
+
+ if not creds or not creds.valid:
+ if creds and creds.expired and creds.refresh_token:
+ creds.refresh(Request())
+ elif "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ:
+ creds, project = default()
+ creds = creds.with_scopes(SCOPES)
+ # no need to write to file
+ if creds:
+ return creds
+ else:
+ flow = InstalledAppFlow.from_client_secrets_file(
+ str(self.credentials_path), SCOPES
+ )
+ creds = flow.run_local_server(port=0)
+ with open(self.token_path, "w") as token:
+ token.write(creds.to_json())
+
+ return creds
+
+ def _load_sheet_from_id(self, id: str) -> List[Document]:
+ """Load a sheet and all tabs from an ID."""
+
+ from googleapiclient.discovery import build
+
+ creds = self._load_credentials()
+ sheets_service = build("sheets", "v4", credentials=creds)
+ spreadsheet = sheets_service.spreadsheets().get(spreadsheetId=id).execute()
+ sheets = spreadsheet.get("sheets", [])
+
+ documents = []
+ for sheet in sheets:
+ sheet_name = sheet["properties"]["title"]
+ result = (
+ sheets_service.spreadsheets()
+ .values()
+ .get(spreadsheetId=id, range=sheet_name)
+ .execute()
+ )
+ values = result.get("values", [])
+ if not values:
+ continue # empty sheet
+
+ header = values[0]
+ for i, row in enumerate(values[1:], start=1):
+ metadata = {
+ "source": (
+ f"https://docs.google.com/spreadsheets/d/{id}/"
+ f"edit?gid={sheet['properties']['sheetId']}"
+ ),
+ "title": f"{spreadsheet['properties']['title']} - {sheet_name}",
+ "row": i,
+ }
+ content = []
+ for j, v in enumerate(row):
+ title = header[j].strip() if len(header) > j else ""
+ content.append(f"{title}: {v.strip()}")
+
+ page_content = "\n".join(content)
+ documents.append(Document(page_content=page_content, metadata=metadata))
+
+ return documents
+
+ def _load_document_from_id(self, id: str) -> Document:
+ """Load a document from an ID."""
+ from io import BytesIO
+
+ from googleapiclient.discovery import build
+ from googleapiclient.errors import HttpError
+ from googleapiclient.http import MediaIoBaseDownload
+
+ creds = self._load_credentials()
+ service = build("drive", "v3", credentials=creds)
+
+ file = (
+ service.files()
+ .get(fileId=id, supportsAllDrives=True, fields="modifiedTime,name")
+ .execute()
+ )
+ request = service.files().export_media(fileId=id, mimeType="text/plain")
+ fh = BytesIO()
+ downloader = MediaIoBaseDownload(fh, request)
+ done = False
+ try:
+ while done is False:
+ status, done = downloader.next_chunk()
+
+ except HttpError as e:
+ if e.resp.status == 404:
+ print("File not found: {}".format(id))
+ else:
+ print("An error occurred: {}".format(e))
+
+ text = fh.getvalue().decode("utf-8")
+ metadata = {
+ "source": f"https://docs.google.com/document/d/{id}/edit",
+ "title": f"{file.get('name')}",
+ "when": f"{file.get('modifiedTime')}",
+ }
+ return Document(page_content=text, metadata=metadata)
+
+ def _load_documents_from_folder(
+ self, folder_id: str, *, file_types: Optional[Sequence[str]] = None
+ ) -> List[Document]:
+ """Load documents from a folder."""
+ from googleapiclient.discovery import build
+
+ creds = self._load_credentials()
+ service = build("drive", "v3", credentials=creds)
+ files = self._fetch_files_recursive(service, folder_id)
+ # If file types filter is provided, we'll filter by the file type.
+ if file_types:
+ _files = [f for f in files if f["mimeType"] in file_types] # type: ignore
+ else:
+ _files = files
+
+ returns = []
+ for file in _files:
+ if file["trashed"] and not self.load_trashed_files:
+ continue
+ elif file["mimeType"] == "application/vnd.google-apps.document":
+ returns.append(self._load_document_from_id(file["id"])) # type: ignore
+ elif file["mimeType"] == "application/vnd.google-apps.spreadsheet":
+ returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore
+ elif (
+ file["mimeType"] == "application/pdf"
+ or self.file_loader_cls is not None
+ ):
+ returns.extend(self._load_file_from_id(file["id"])) # type: ignore
+ else:
+ pass
+ return returns
+
+ def _fetch_files_recursive(
+ self, service: Any, folder_id: str
+ ) -> List[Dict[str, Union[str, List[str]]]]:
+ """Fetch all files and subfolders recursively."""
+ results = (
+ service.files()
+ .list(
+ q=f"'{folder_id}' in parents",
+ pageSize=1000,
+ includeItemsFromAllDrives=True,
+ supportsAllDrives=True,
+ fields="nextPageToken, files(id, name, mimeType, parents, trashed)",
+ )
+ .execute()
+ )
+ files = results.get("files", [])
+ returns = []
+ for file in files:
+ if file["mimeType"] == "application/vnd.google-apps.folder":
+ if self.recursive:
+ returns.extend(self._fetch_files_recursive(service, file["id"]))
+ else:
+ returns.append(file)
+
+ return returns
+
+ def _load_documents_from_ids(self) -> List[Document]:
+ """Load documents from a list of IDs."""
+ if not self.document_ids:
+ raise ValueError("document_ids must be set")
+
+ return [self._load_document_from_id(doc_id) for doc_id in self.document_ids]
+
+ def _load_file_from_id(self, id: str) -> List[Document]:
+ """Load a file from an ID."""
+ from io import BytesIO
+
+ from googleapiclient.discovery import build
+ from googleapiclient.http import MediaIoBaseDownload
+
+ creds = self._load_credentials()
+ service = build("drive", "v3", credentials=creds)
+
+ file = service.files().get(fileId=id, supportsAllDrives=True).execute()
+ request = service.files().get_media(fileId=id)
+ fh = BytesIO()
+ downloader = MediaIoBaseDownload(fh, request)
+ done = False
+ while done is False:
+ status, done = downloader.next_chunk()
+
+ if self.file_loader_cls is not None:
+ fh.seek(0)
+ loader = self.file_loader_cls(file=fh, **self.file_loader_kwargs)
+ docs = loader.load()
+ for doc in docs:
+ doc.metadata["source"] = f"https://drive.google.com/file/d/{id}/view"
+ if "title" not in doc.metadata:
+ doc.metadata["title"] = f"{file.get('name')}"
+ return docs
+
+ else:
+ from PyPDF2 import PdfReader
+
+ content = fh.getvalue()
+ pdf_reader = PdfReader(BytesIO(content))
+
+ return [
+ Document(
+ page_content=page.extract_text(),
+ metadata={
+ "source": f"https://drive.google.com/file/d/{id}/view",
+ "title": f"{file.get('name')}",
+ "page": i,
+ },
+ )
+ for i, page in enumerate(pdf_reader.pages)
+ ]
+
+ def _load_file_from_ids(self) -> List[Document]:
+ """Load files from a list of IDs."""
+ if not self.file_ids:
+ raise ValueError("file_ids must be set")
+ docs = []
+ for file_id in self.file_ids:
+ docs.extend(self._load_file_from_id(file_id))
+ return docs
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ if self.folder_id:
+ return self._load_documents_from_folder(
+ self.folder_id, file_types=self.file_types
+ )
+ elif self.document_ids:
+ return self._load_documents_from_ids()
+ else:
+ return self._load_file_from_ids()
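+
+
+# Usage sketch (illustrative only; assumes this loader is exported as
+# `GoogleDriveLoader` and that the folder ID below is replaced with a real one):
+#
+#     loader = GoogleDriveLoader(folder_id="<your-folder-id>", recursive=False)
+#     docs = loader.load()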
diff --git a/libs/community/langchain_community/document_loaders/gutenberg.py b/libs/community/langchain_community/document_loaders/gutenberg.py
new file mode 100644
index 00000000000..fe253ae8887
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/gutenberg.py
@@ -0,0 +1,28 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class GutenbergLoader(BaseLoader):
+ """Load from `Gutenberg.org`."""
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path."""
+ if not file_path.startswith("https://www.gutenberg.org"):
+ raise ValueError("file path must start with 'https://www.gutenberg.org'")
+
+ if not file_path.endswith(".txt"):
+ raise ValueError("file path must end with '.txt'")
+
+ self.file_path = file_path
+
+ def load(self) -> List[Document]:
+ """Load file."""
+ from urllib.request import urlopen
+
+ elements = urlopen(self.file_path)
+ text = "\n\n".join([str(el.decode("utf-8-sig")) for el in elements])
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
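+
+
+# Usage sketch (illustrative; any plain-text Project Gutenberg URL works, the
+# one below is just an example):
+#
+#     loader = GutenbergLoader(
+#         "https://www.gutenberg.org/cache/epub/69972/pg69972.txt"
+#     )
+#     docs = loader.load()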
diff --git a/libs/community/langchain_community/document_loaders/helpers.py b/libs/community/langchain_community/document_loaders/helpers.py
new file mode 100644
index 00000000000..6e0f8b9bfb9
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/helpers.py
@@ -0,0 +1,46 @@
+"""Document loader helpers."""
+
+import concurrent.futures
+from typing import List, NamedTuple, Optional, cast
+
+
+class FileEncoding(NamedTuple):
+ """File encoding as the NamedTuple."""
+
+ encoding: Optional[str]
+ """The encoding of the file."""
+ confidence: float
+ """The confidence of the encoding."""
+ language: Optional[str]
+ """The language of the file."""
+
+
+def detect_file_encodings(file_path: str, timeout: int = 5) -> List[FileEncoding]:
+ """Try to detect the file encoding.
+
+ Returns a list of `FileEncoding` tuples with the detected encodings ordered
+ by confidence.
+
+ Args:
+ file_path: The path to the file to detect the encoding for.
+ timeout: The timeout in seconds for the encoding detection.
+ """
+ import chardet
+
+ def read_and_detect(file_path: str) -> List[dict]:
+ with open(file_path, "rb") as f:
+ rawdata = f.read()
+ return cast(List[dict], chardet.detect_all(rawdata))
+
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(read_and_detect, file_path)
+ try:
+ encodings = future.result(timeout=timeout)
+ except concurrent.futures.TimeoutError:
+ raise TimeoutError(
+ f"Timeout reached while detecting encoding for {file_path}"
+ )
+
+ if all(encoding["encoding"] is None for encoding in encodings):
+ raise RuntimeError(f"Could not detect encoding for {file_path}")
+ return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
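+
+
+# Usage sketch (illustrative; "some_file.txt" is a placeholder path):
+#
+#     encodings = detect_file_encodings("some_file.txt", timeout=5)
+#     best_guess = encodings[0].encoding  # highest-confidence result comes first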
diff --git a/libs/community/langchain_community/document_loaders/hn.py b/libs/community/langchain_community/document_loaders/hn.py
new file mode 100644
index 00000000000..ca36ca5f2b9
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/hn.py
@@ -0,0 +1,62 @@
+from typing import Any, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class HNLoader(WebBaseLoader):
+ """Load `Hacker News` data.
+
+ It loads data from either main page results or the comments page."""
+
+ def load(self) -> List[Document]:
+ """Get important HN webpage information.
+
+ HN webpage components are:
+ - title
+ - content
+ - source url,
+ - time of post
+ - author of the post
+ - number of comments
+ - rank of the post
+ """
+ soup_info = self.scrape()
+ if "item" in self.web_path:
+ return self.load_comments(soup_info)
+ else:
+ return self.load_results(soup_info)
+
+ def load_comments(self, soup_info: Any) -> List[Document]:
+ """Load comments from a HN post."""
+ comments = soup_info.select("tr[class='athing comtr']")
+ title = soup_info.select_one("tr[id='pagespace']").get("title")
+ return [
+ Document(
+ page_content=comment.text.strip(),
+ metadata={"source": self.web_path, "title": title},
+ )
+ for comment in comments
+ ]
+
+ def load_results(self, soup: Any) -> List[Document]:
+ """Load items from an HN page."""
+ items = soup.select("tr[class='athing']")
+ documents = []
+ for lineItem in items:
+ ranking = lineItem.select_one("span[class='rank']").text
+ link = lineItem.find("span", {"class": "titleline"}).find("a").get("href")
+ title = lineItem.find("span", {"class": "titleline"}).text.strip()
+ metadata = {
+ "source": self.web_path,
+ "title": title,
+ "link": link,
+ "ranking": ranking,
+ }
+ documents.append(
+                Document(page_content=title, metadata=metadata)
+ )
+ return documents
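+
+
+# Usage sketch (illustrative; the item ID below is a placeholder):
+#
+#     loader = HNLoader("https://news.ycombinator.com/item?id=34422627")
+#     docs = loader.load()  # a main-page URL returns ranked items instead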
diff --git a/libs/community/langchain_community/document_loaders/html.py b/libs/community/langchain_community/document_loaders/html.py
new file mode 100644
index 00000000000..77a2ef3d24d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/html.py
@@ -0,0 +1,33 @@
+from typing import List
+
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class UnstructuredHTMLLoader(UnstructuredFileLoader):
+ """Load `HTML` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredHTMLLoader
+
+ loader = UnstructuredHTMLLoader(
+ "example.html", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-html
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.html import partition_html
+
+ return partition_html(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/html_bs.py b/libs/community/langchain_community/document_loaders/html_bs.py
new file mode 100644
index 00000000000..75823f2e96e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/html_bs.py
@@ -0,0 +1,63 @@
+import logging
+from typing import Dict, List, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class BSHTMLLoader(BaseLoader):
+ """Load `HTML` files and parse them with `beautiful soup`."""
+
+ def __init__(
+ self,
+ file_path: str,
+ open_encoding: Union[str, None] = None,
+ bs_kwargs: Union[dict, None] = None,
+ get_text_separator: str = "",
+ ) -> None:
+ """Initialise with path, and optionally, file encoding to use, and any kwargs
+ to pass to the BeautifulSoup object.
+
+ Args:
+ file_path: The path to the file to load.
+ open_encoding: The encoding to use when opening the file.
+ bs_kwargs: Any kwargs to pass to the BeautifulSoup object.
+ get_text_separator: The separator to use when calling get_text on the soup.
+ """
+ try:
+ import bs4 # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "beautifulsoup4 package not found, please install it with "
+ "`pip install beautifulsoup4`"
+ )
+
+ self.file_path = file_path
+ self.open_encoding = open_encoding
+ if bs_kwargs is None:
+ bs_kwargs = {"features": "lxml"}
+ self.bs_kwargs = bs_kwargs
+ self.get_text_separator = get_text_separator
+
+ def load(self) -> List[Document]:
+ """Load HTML document into document objects."""
+ from bs4 import BeautifulSoup
+
+ with open(self.file_path, "r", encoding=self.open_encoding) as f:
+ soup = BeautifulSoup(f, **self.bs_kwargs)
+
+ text = soup.get_text(self.get_text_separator)
+
+ if soup.title:
+ title = str(soup.title.string)
+ else:
+ title = ""
+
+ metadata: Dict[str, Union[str, None]] = {
+ "source": self.file_path,
+ "title": title,
+ }
+ return [Document(page_content=text, metadata=metadata)]
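+
+
+# Usage sketch (illustrative; "example.html" is a placeholder, and lxml must be
+# installed since it is the default BeautifulSoup parser here):
+#
+#     loader = BSHTMLLoader("example.html")
+#     docs = loader.load()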
diff --git a/libs/community/langchain_community/document_loaders/hugging_face_dataset.py b/libs/community/langchain_community/document_loaders/hugging_face_dataset.py
new file mode 100644
index 00000000000..77841e56a79
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/hugging_face_dataset.py
@@ -0,0 +1,94 @@
+import json
+from typing import Iterator, List, Mapping, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class HuggingFaceDatasetLoader(BaseLoader):
+ """Load from `Hugging Face Hub` datasets."""
+
+ def __init__(
+ self,
+ path: str,
+ page_content_column: str = "text",
+ name: Optional[str] = None,
+ data_dir: Optional[str] = None,
+ data_files: Optional[
+ Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
+ ] = None,
+ cache_dir: Optional[str] = None,
+ keep_in_memory: Optional[bool] = None,
+ save_infos: bool = False,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ num_proc: Optional[int] = None,
+ ):
+ """Initialize the HuggingFaceDatasetLoader.
+
+ Args:
+ path: Path or name of the dataset.
+ page_content_column: Page content column name. Default is "text".
+ name: Name of the dataset configuration.
+ data_dir: Data directory of the dataset configuration.
+ data_files: Path(s) to source data file(s).
+ cache_dir: Directory to read/write data.
+ keep_in_memory: Whether to copy the dataset in-memory.
+ save_infos: Save the dataset information (checksums/size/splits/...).
+ Default is False.
+ use_auth_token: Bearer token for remote files on the Dataset Hub.
+ num_proc: Number of processes.
+ """
+
+ self.path = path
+ self.page_content_column = page_content_column
+ self.name = name
+ self.data_dir = data_dir
+ self.data_files = data_files
+ self.cache_dir = cache_dir
+ self.keep_in_memory = keep_in_memory
+ self.save_infos = save_infos
+ self.use_auth_token = use_auth_token
+ self.num_proc = num_proc
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Load documents lazily."""
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "Could not import datasets python package. "
+ "Please install it with `pip install datasets`."
+ )
+
+ dataset = load_dataset(
+ path=self.path,
+ name=self.name,
+ data_dir=self.data_dir,
+ data_files=self.data_files,
+ cache_dir=self.cache_dir,
+ keep_in_memory=self.keep_in_memory,
+ save_infos=self.save_infos,
+ use_auth_token=self.use_auth_token,
+ num_proc=self.num_proc,
+ )
+
+ yield from (
+ Document(
+ page_content=self.parse_obj(row.pop(self.page_content_column)),
+ metadata=row,
+ )
+ for key in dataset.keys()
+ for row in dataset[key]
+ )
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ return list(self.lazy_load())
+
+    def parse_obj(self, page_content: Union[str, object]) -> str:
+        """Return the page content as a string, JSON-encoding non-string values."""
+        if not isinstance(page_content, str):
+            return json.dumps(page_content)
+        return page_content
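+
+
+# Usage sketch (illustrative; "imdb" is simply a well-known public dataset name):
+#
+#     loader = HuggingFaceDatasetLoader(path="imdb", page_content_column="text")
+#     docs = loader.load()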
diff --git a/libs/community/langchain_community/document_loaders/ifixit.py b/libs/community/langchain_community/document_loaders/ifixit.py
new file mode 100644
index 00000000000..335a4a1816e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/ifixit.py
@@ -0,0 +1,240 @@
+from typing import List, Optional
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+IFIXIT_BASE_URL = "https://www.ifixit.com/api/2.0"
+
+
+class IFixitLoader(BaseLoader):
+ """Load `iFixit` repair guides, device wikis and answers.
+
+ iFixit is the largest, open repair community on the web. The site contains nearly
+ 100k repair manuals, 200k Questions & Answers on 42k devices, and all the data is
+ licensed under CC-BY.
+
+ This loader will allow you to download the text of a repair guide, text of Q&A's
+ and wikis from devices on iFixit using their open APIs and web scraping.
+ """
+
+ def __init__(self, web_path: str):
+ """Initialize with a web path."""
+ if not web_path.startswith("https://www.ifixit.com"):
+ raise ValueError("web path must start with 'https://www.ifixit.com'")
+
+ path = web_path.replace("https://www.ifixit.com", "")
+
+ allowed_paths = ["/Device", "/Guide", "/Answers", "/Teardown"]
+
+ """ TODO: Add /Wiki """
+ if not any(path.startswith(allowed_path) for allowed_path in allowed_paths):
+ raise ValueError(
+ "web path must start with /Device, /Guide, /Teardown or /Answers"
+ )
+
+ pieces = [x for x in path.split("/") if x]
+
+ """Teardowns are just guides by a different name"""
+ self.page_type = pieces[0] if pieces[0] != "Teardown" else "Guide"
+
+ if self.page_type == "Guide" or self.page_type == "Answers":
+ self.id = pieces[2]
+ else:
+ self.id = pieces[1]
+
+ self.web_path = web_path
+
+ def load(self) -> List[Document]:
+ if self.page_type == "Device":
+ return self.load_device()
+ elif self.page_type == "Guide" or self.page_type == "Teardown":
+ return self.load_guide()
+ elif self.page_type == "Answers":
+ return self.load_questions_and_answers()
+ else:
+ raise ValueError("Unknown page type: " + self.page_type)
+
+ @staticmethod
+ def load_suggestions(query: str = "", doc_type: str = "all") -> List[Document]:
+ """Load suggestions.
+
+ Args:
+ query: A query string
+ doc_type: The type of document to search for. Can be one of "all",
+ "device", "guide", "teardown", "answer", "wiki".
+
+        Returns: List[Document]
+
+ """
+ res = requests.get(
+ IFIXIT_BASE_URL + "/suggest/" + query + "?doctypes=" + doc_type
+ )
+
+ if res.status_code != 200:
+ raise ValueError(
+                'Could not load suggestions for "' + query + '"\n' + str(res.json())
+ )
+
+ data = res.json()
+
+ results = data["results"]
+ output = []
+
+ for result in results:
+ try:
+ loader = IFixitLoader(result["url"])
+ if loader.page_type == "Device":
+ output += loader.load_device(include_guides=False)
+ else:
+ output += loader.load()
+ except ValueError:
+ continue
+
+ return output
+
+ def load_questions_and_answers(
+ self, url_override: Optional[str] = None
+ ) -> List[Document]:
+ """Load a list of questions and answers.
+
+ Args:
+ url_override: A URL to override the default URL.
+
+ Returns: List[Document]
+
+ """
+ loader = WebBaseLoader(self.web_path if url_override is None else url_override)
+ soup = loader.scrape()
+
+ output = []
+
+ title = soup.find("h1", "post-title").text
+
+ output.append("# " + title)
+ output.append(soup.select_one(".post-content .post-text").text.strip())
+
+ answersHeader = soup.find("div", "post-answers-header")
+ if answersHeader:
+ output.append("\n## " + answersHeader.text.strip())
+
+ for answer in soup.select(".js-answers-list .post.post-answer"):
+ if answer.has_attr("itemprop") and "acceptedAnswer" in answer["itemprop"]:
+ output.append("\n### Accepted Answer")
+ elif "post-helpful" in answer["class"]:
+ output.append("\n### Most Helpful Answer")
+ else:
+ output.append("\n### Other Answer")
+
+ output += [
+ a.text.strip() for a in answer.select(".post-content .post-text")
+ ]
+ output.append("\n")
+
+ text = "\n".join(output).strip()
+
+ metadata = {"source": self.web_path, "title": title}
+
+ return [Document(page_content=text, metadata=metadata)]
+
+ def load_device(
+ self, url_override: Optional[str] = None, include_guides: bool = True
+ ) -> List[Document]:
+ """Loads a device
+
+ Args:
+ url_override: A URL to override the default URL.
+ include_guides: Whether to include guides linked to from the device.
+ Defaults to True.
+
+        Returns: List[Document]
+
+ """
+ documents = []
+ if url_override is None:
+ url = IFIXIT_BASE_URL + "/wikis/CATEGORY/" + self.id
+ else:
+ url = url_override
+
+ res = requests.get(url)
+ data = res.json()
+ text = "\n".join(
+ [
+ data[key]
+ for key in ["title", "description", "contents_raw"]
+ if key in data
+ ]
+ ).strip()
+
+ metadata = {"source": self.web_path, "title": data["title"]}
+ documents.append(Document(page_content=text, metadata=metadata))
+
+ if include_guides:
+ """Load and return documents for each guide linked to from the device"""
+ guide_urls = [guide["url"] for guide in data["guides"]]
+ for guide_url in guide_urls:
+ documents.append(IFixitLoader(guide_url).load()[0])
+
+ return documents
+
+ def load_guide(self, url_override: Optional[str] = None) -> List[Document]:
+ """Load a guide
+
+ Args:
+ url_override: A URL to override the default URL.
+
+ Returns: List[Document]
+
+ """
+ if url_override is None:
+ url = IFIXIT_BASE_URL + "/guides/" + self.id
+ else:
+ url = url_override
+
+ res = requests.get(url)
+
+ if res.status_code != 200:
+ raise ValueError(
+ "Could not load guide: " + self.web_path + "\n" + res.json()
+ )
+
+ data = res.json()
+
+ doc_parts = ["# " + data["title"], data["introduction_raw"]]
+
+ doc_parts.append("\n\n###Tools Required:")
+ if len(data["tools"]) == 0:
+ doc_parts.append("\n - None")
+ else:
+ for tool in data["tools"]:
+ doc_parts.append("\n - " + tool["text"])
+
+ doc_parts.append("\n\n###Parts Required:")
+ if len(data["parts"]) == 0:
+ doc_parts.append("\n - None")
+ else:
+ for part in data["parts"]:
+ doc_parts.append("\n - " + part["text"])
+
+ for row in data["steps"]:
+ doc_parts.append(
+ "\n\n## "
+ + (
+ row["title"]
+ if row["title"] != ""
+ else "Step {}".format(row["orderby"])
+ )
+ )
+
+ for line in row["lines"]:
+ doc_parts.append(line["text_raw"])
+
+ doc_parts.append(data["conclusion_raw"])
+
+ text = "\n".join(doc_parts)
+
+ metadata = {"source": self.web_path, "title": data["title"]}
+
+ return [Document(page_content=text, metadata=metadata)]
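+
+
+# Usage sketch (illustrative; the device URL and search query are placeholders):
+#
+#     loader = IFixitLoader("https://www.ifixit.com/Device/Standard_iPad")
+#     docs = loader.load()
+#
+#     # Keyword search across devices, guides and answers:
+#     hits = IFixitLoader.load_suggestions("banana", doc_type="device")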
diff --git a/libs/community/langchain_community/document_loaders/image.py b/libs/community/langchain_community/document_loaders/image.py
new file mode 100644
index 00000000000..c7b2e29b4c6
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/image.py
@@ -0,0 +1,33 @@
+from typing import List
+
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class UnstructuredImageLoader(UnstructuredFileLoader):
+ """Load `PNG` and `JPG` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredImageLoader
+
+ loader = UnstructuredImageLoader(
+ "example.png", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-image
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.image import partition_image
+
+ return partition_image(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/image_captions.py b/libs/community/langchain_community/document_loaders/image_captions.py
new file mode 100644
index 00000000000..93dce16432e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/image_captions.py
@@ -0,0 +1,99 @@
+from io import BytesIO
+from typing import Any, List, Tuple, Union
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class ImageCaptionLoader(BaseLoader):
+ """Load image captions.
+
+ By default, the loader utilizes the pre-trained
+ Salesforce BLIP image captioning model.
+ https://huggingface.co/Salesforce/blip-image-captioning-base
+ """
+
+ def __init__(
+ self,
+ images: Union[str, bytes, List[Union[str, bytes]]],
+ blip_processor: str = "Salesforce/blip-image-captioning-base",
+ blip_model: str = "Salesforce/blip-image-captioning-base",
+ ):
+ """Initialize with a list of image data (bytes) or file paths
+
+ Args:
+ images: Either a single image or a list of images. Accepts
+ image data (bytes) or file paths to images.
+ blip_processor: The name of the pre-trained BLIP processor.
+ blip_model: The name of the pre-trained BLIP model.
+ """
+ if isinstance(images, (str, bytes)):
+ self.images = [images]
+ else:
+ self.images = images
+
+ self.blip_processor = blip_processor
+ self.blip_model = blip_model
+
+ def load(self) -> List[Document]:
+ """Load from a list of image data or file paths"""
+ try:
+ from transformers import BlipForConditionalGeneration, BlipProcessor
+ except ImportError:
+ raise ImportError(
+ "`transformers` package not found, please install with "
+ "`pip install transformers`."
+ )
+
+ processor = BlipProcessor.from_pretrained(self.blip_processor)
+ model = BlipForConditionalGeneration.from_pretrained(self.blip_model)
+
+ results = []
+ for image in self.images:
+ caption, metadata = self._get_captions_and_metadata(
+ model=model, processor=processor, image=image
+ )
+ doc = Document(page_content=caption, metadata=metadata)
+ results.append(doc)
+
+ return results
+
+ def _get_captions_and_metadata(
+ self, model: Any, processor: Any, image: Union[str, bytes]
+ ) -> Tuple[str, dict]:
+ """Helper function for getting the captions and metadata of an image."""
+ try:
+ from PIL import Image
+ except ImportError:
+ raise ImportError(
+ "`PIL` package not found, please install with `pip install pillow`"
+ )
+
+ image_source = image # Save the original source for later reference
+
+ try:
+ if isinstance(image, bytes):
+ image = Image.open(BytesIO(image)).convert("RGB")
+ elif image.startswith("http://") or image.startswith("https://"):
+ image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
+ else:
+ image = Image.open(image).convert("RGB")
+ except Exception:
+ if isinstance(image_source, bytes):
+ msg = "Could not get image data from bytes"
+ else:
+ msg = f"Could not get image data for {image_source}"
+ raise ValueError(msg)
+
+ inputs = processor(image, "an image of", return_tensors="pt")
+ output = model.generate(**inputs)
+
+ caption: str = processor.decode(output[0])
+ if isinstance(image_source, bytes):
+ metadata: dict = {"image_source": "Image bytes provided"}
+ else:
+ metadata = {"image_path": image_source}
+
+ return caption, metadata
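+
+
+# Usage sketch (illustrative; the image path is a placeholder and the default
+# Salesforce BLIP checkpoints are downloaded on first use):
+#
+#     loader = ImageCaptionLoader(images=["path/to/photo.png"])
+#     docs = loader.load()  # one Document per image, with the caption as page_content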
diff --git a/libs/community/langchain_community/document_loaders/imsdb.py b/libs/community/langchain_community/document_loaders/imsdb.py
new file mode 100644
index 00000000000..af224020620
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/imsdb.py
@@ -0,0 +1,16 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class IMSDbLoader(WebBaseLoader):
+ """Load `IMSDb` webpages."""
+
+ def load(self) -> List[Document]:
+ """Load webpage."""
+ soup = self.scrape()
+ text = soup.select_one("td[class='scrtext']").text
+ metadata = {"source": self.web_path}
+ return [Document(page_content=text, metadata=metadata)]
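+
+
+# Usage sketch (illustrative; the script URL is a placeholder):
+#
+#     loader = IMSDbLoader("https://imsdb.com/scripts/BlacKkKlansman.html")
+#     docs = loader.load()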
diff --git a/libs/community/langchain_community/document_loaders/iugu.py b/libs/community/langchain_community/document_loaders/iugu.py
new file mode 100644
index 00000000000..31e56bb0c7a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/iugu.py
@@ -0,0 +1,49 @@
+import json
+import urllib.request
+from typing import List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_env, stringify_dict
+
+from langchain_community.document_loaders.base import BaseLoader
+
+IUGU_ENDPOINTS = {
+ "invoices": "https://api.iugu.com/v1/invoices",
+ "customers": "https://api.iugu.com/v1/customers",
+ "charges": "https://api.iugu.com/v1/charges",
+ "subscriptions": "https://api.iugu.com/v1/subscriptions",
+ "plans": "https://api.iugu.com/v1/plans",
+}
+
+
+class IuguLoader(BaseLoader):
+ """Load from `IUGU`."""
+
+ def __init__(self, resource: str, api_token: Optional[str] = None) -> None:
+ """Initialize the IUGU resource.
+
+ Args:
+ resource: The name of the resource to fetch.
+ api_token: The IUGU API token to use.
+ """
+ self.resource = resource
+ api_token = api_token or get_from_env("api_token", "IUGU_API_TOKEN")
+ self.headers = {"Authorization": f"Bearer {api_token}"}
+
+ def _make_request(self, url: str) -> List[Document]:
+ request = urllib.request.Request(url, headers=self.headers)
+
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ text = stringify_dict(json_data)
+ metadata = {"source": url}
+ return [Document(page_content=text, metadata=metadata)]
+
+ def _get_resource(self) -> List[Document]:
+ endpoint = IUGU_ENDPOINTS.get(self.resource)
+ if endpoint is None:
+ return []
+ return self._make_request(endpoint)
+
+ def load(self) -> List[Document]:
+ return self._get_resource()
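+
+
+# Usage sketch (illustrative; expects IUGU_API_TOKEN in the environment or an
+# explicit api_token, and "invoices" is one of the keys in IUGU_ENDPOINTS):
+#
+#     loader = IuguLoader("invoices")
+#     docs = loader.load()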
diff --git a/libs/community/langchain_community/document_loaders/joplin.py b/libs/community/langchain_community/document_loaders/joplin.py
new file mode 100644
index 00000000000..50a19f14059
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/joplin.py
@@ -0,0 +1,96 @@
+import json
+import urllib
+from datetime import datetime
+from typing import Iterator, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_env
+
+from langchain_community.document_loaders.base import BaseLoader
+
+LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}"
+
+
+class JoplinLoader(BaseLoader):
+ """Load notes from `Joplin`.
+
+ In order to use this loader, you need to have Joplin running with the
+ Web Clipper enabled (look for "Web Clipper" in the app settings).
+
+ To get the access token, you need to go to the Web Clipper options and
+ under "Advanced Options" you will find the access token.
+
+ You can find more information about the Web Clipper service here:
+ https://joplinapp.org/clipper/
+ """
+
+ def __init__(
+ self,
+ access_token: Optional[str] = None,
+ port: int = 41184,
+ host: str = "localhost",
+ ) -> None:
+ """
+
+ Args:
+ access_token: The access token to use.
+ port: The port where the Web Clipper service is running. Default is 41184.
+ host: The host where the Web Clipper service is running.
+ Default is localhost.
+ """
+ access_token = access_token or get_from_env(
+ "access_token", "JOPLIN_ACCESS_TOKEN"
+ )
+ base_url = f"http://{host}:{port}"
+ self._get_note_url = (
+ f"{base_url}/notes?token={access_token}"
+ f"&fields=id,parent_id,title,body,created_time,updated_time&page={{page}}"
+ )
+ self._get_folder_url = (
+ f"{base_url}/folders/{{id}}?token={access_token}&fields=title"
+ )
+ self._get_tag_url = (
+ f"{base_url}/notes/{{id}}/tags?token={access_token}&fields=title"
+ )
+
+ def _get_notes(self) -> Iterator[Document]:
+ has_more = True
+ page = 1
+ while has_more:
+ req_note = urllib.request.Request(self._get_note_url.format(page=page))
+ with urllib.request.urlopen(req_note) as response:
+ json_data = json.loads(response.read().decode())
+ for note in json_data["items"]:
+ metadata = {
+ "source": LINK_NOTE_TEMPLATE.format(id=note["id"]),
+ "folder": self._get_folder(note["parent_id"]),
+ "tags": self._get_tags(note["id"]),
+ "title": note["title"],
+ "created_time": self._convert_date(note["created_time"]),
+ "updated_time": self._convert_date(note["updated_time"]),
+ }
+ yield Document(page_content=note["body"], metadata=metadata)
+
+ has_more = json_data["has_more"]
+ page += 1
+
+ def _get_folder(self, folder_id: str) -> str:
+ req_folder = urllib.request.Request(self._get_folder_url.format(id=folder_id))
+ with urllib.request.urlopen(req_folder) as response:
+ json_data = json.loads(response.read().decode())
+ return json_data["title"]
+
+ def _get_tags(self, note_id: str) -> List[str]:
+ req_tag = urllib.request.Request(self._get_tag_url.format(id=note_id))
+ with urllib.request.urlopen(req_tag) as response:
+ json_data = json.loads(response.read().decode())
+ return [tag["title"] for tag in json_data["items"]]
+
+ def _convert_date(self, date: int) -> str:
+ return datetime.fromtimestamp(date / 1000).strftime("%Y-%m-%d %H:%M:%S")
+
+ def lazy_load(self) -> Iterator[Document]:
+ yield from self._get_notes()
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
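+
+
+# Usage sketch (illustrative; requires Joplin running locally with the Web
+# Clipper service enabled; the token falls back to JOPLIN_ACCESS_TOKEN if omitted):
+#
+#     loader = JoplinLoader(access_token="<your-token>")
+#     docs = loader.load()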
diff --git a/libs/community/langchain_community/document_loaders/json_loader.py b/libs/community/langchain_community/document_loaders/json_loader.py
new file mode 100644
index 00000000000..ef7be7caaf8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/json_loader.py
@@ -0,0 +1,151 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class JSONLoader(BaseLoader):
+ """Load a `JSON` file using a `jq` schema.
+
+ Example:
+ [{"text": ...}, {"text": ...}, {"text": ...}] -> schema = .[].text
+ {"key": [{"text": ...}, {"text": ...}, {"text": ...}]} -> schema = .key[].text
+ ["", "", ""] -> schema = .[]
+ """
+
+ def __init__(
+ self,
+ file_path: Union[str, Path],
+ jq_schema: str,
+ content_key: Optional[str] = None,
+ metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+ text_content: bool = True,
+ json_lines: bool = False,
+ ):
+ """Initialize the JSONLoader.
+
+ Args:
+ file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
+ jq_schema (str): The jq schema to use to extract the data or text from
+ the JSON.
+ content_key (str): The key to use to extract the content from the JSON if
+            the jq_schema results in a list of objects (dict).
+ metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
+ object extracted by the jq_schema and the default metadata and returns
+ a dict of the updated metadata.
+ text_content (bool): Boolean flag to indicate whether the content is in
+ string format, default to True.
+ json_lines (bool): Boolean flag to indicate whether the input is in
+ JSON Lines format.
+ """
+ try:
+ import jq # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "jq package not found, please install it with `pip install jq`"
+ )
+
+ self.file_path = Path(file_path).resolve()
+ self._jq_schema = jq.compile(jq_schema)
+ self._content_key = content_key
+ self._metadata_func = metadata_func
+ self._text_content = text_content
+ self._json_lines = json_lines
+
+ def load(self) -> List[Document]:
+ """Load and return documents from the JSON file."""
+ docs: List[Document] = []
+ if self._json_lines:
+ with self.file_path.open(encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ self._parse(line, docs)
+ else:
+ self._parse(self.file_path.read_text(encoding="utf-8"), docs)
+ return docs
+
+ def _parse(self, content: str, docs: List[Document]) -> None:
+ """Convert given content to documents."""
+ data = self._jq_schema.input(json.loads(content))
+
+ # Perform some validation
+ # This is not a perfect validation, but it should catch most cases
+ # and prevent the user from getting a cryptic error later on.
+ if self._content_key is not None:
+ self._validate_content_key(data)
+ if self._metadata_func is not None:
+ self._validate_metadata_func(data)
+
+ for i, sample in enumerate(data, len(docs) + 1):
+ text = self._get_text(sample=sample)
+ metadata = self._get_metadata(
+ sample=sample, source=str(self.file_path), seq_num=i
+ )
+ docs.append(Document(page_content=text, metadata=metadata))
+
+ def _get_text(self, sample: Any) -> str:
+ """Convert sample to string format"""
+ if self._content_key is not None:
+ content = sample.get(self._content_key)
+ else:
+ content = sample
+
+ if self._text_content and not isinstance(content, str):
+ raise ValueError(
+ f"Expected page_content is string, got {type(content)} instead. \
+ Set `text_content=False` if the desired input for \
+ `page_content` is not a string"
+ )
+
+ # In case the text is None, set it to an empty string
+ elif isinstance(content, str):
+ return content
+ elif isinstance(content, dict):
+ return json.dumps(content) if content else ""
+ else:
+ return str(content) if content is not None else ""
+
+ def _get_metadata(
+ self, sample: Dict[str, Any], **additional_fields: Any
+ ) -> Dict[str, Any]:
+ """
+        Return a metadata dictionary based on the existence of metadata_func
+ :param sample: single data payload
+ :param additional_fields: key-word arguments to be added as metadata values
+ :return:
+ """
+ if self._metadata_func is not None:
+ return self._metadata_func(sample, additional_fields)
+ else:
+ return additional_fields
+
+ def _validate_content_key(self, data: Any) -> None:
+ """Check if a content key is valid"""
+ sample = data.first()
+ if not isinstance(sample, dict):
+ raise ValueError(
+ f"Expected the jq schema to result in a list of objects (dict), \
+ so sample must be a dict but got `{type(sample)}`"
+ )
+
+ if sample.get(self._content_key) is None:
+ raise ValueError(
+ f"Expected the jq schema to result in a list of objects (dict) \
+ with the key `{self._content_key}`"
+ )
+
+ def _validate_metadata_func(self, data: Any) -> None:
+ """Check if the metadata_func output is valid"""
+
+ sample = data.first()
+ if self._metadata_func is not None:
+ sample_metadata = self._metadata_func(sample, {})
+ if not isinstance(sample_metadata, dict):
+ raise ValueError(
+ f"Expected the metadata_func to return a dict but got \
+ `{type(sample_metadata)}`"
+ )
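+
+
+# Usage sketch (illustrative; the file name and jq schema are placeholders for a
+# file shaped like {"messages": [{"content": ...}, ...]}):
+#
+#     loader = JSONLoader(
+#         file_path="chat.json",
+#         jq_schema=".messages[].content",
+#     )
+#     docs = loader.load()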
diff --git a/libs/community/langchain_community/document_loaders/lakefs.py b/libs/community/langchain_community/document_loaders/lakefs.py
new file mode 100644
index 00000000000..d2ebec14d36
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/lakefs.py
@@ -0,0 +1,179 @@
+import os
+import tempfile
+import urllib.parse
+from typing import Any, List, Optional
+from urllib.parse import urljoin
+
+import requests
+from langchain_core.documents import Document
+from requests.auth import HTTPBasicAuth
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredBaseLoader
+
+
+class LakeFSClient:
+ def __init__(
+ self,
+ lakefs_access_key: str,
+ lakefs_secret_key: str,
+ lakefs_endpoint: str,
+ ):
+ self.__endpoint = "/".join([lakefs_endpoint, "api", "v1/"])
+ self.__auth = HTTPBasicAuth(lakefs_access_key, lakefs_secret_key)
+ try:
+ health_check = requests.get(
+ urljoin(self.__endpoint, "healthcheck"), auth=self.__auth
+ )
+ health_check.raise_for_status()
+ except Exception:
+ raise ValueError(
+ "lakeFS server isn't accessible. Make sure lakeFS is running."
+ )
+
+ def ls_objects(
+ self, repo: str, ref: str, path: str, presign: Optional[bool]
+ ) -> List:
+ qp = {"prefix": path, "presign": presign}
+ eqp = urllib.parse.urlencode(qp)
+ objects_ls_endpoint = urljoin(
+ self.__endpoint, f"repositories/{repo}/refs/{ref}/objects/ls?{eqp}"
+ )
+ olsr = requests.get(objects_ls_endpoint, auth=self.__auth)
+ olsr.raise_for_status()
+ olsr_json = olsr.json()
+ return list(
+ map(
+ lambda res: (res["path"], res["physical_address"]), olsr_json["results"]
+ )
+ )
+
+ def is_presign_supported(self) -> bool:
+ config_endpoint = self.__endpoint + "config"
+ response = requests.get(config_endpoint, auth=self.__auth)
+ response.raise_for_status()
+ config = response.json()
+ return config["storage_config"]["pre_sign_support"]
+
+
+class LakeFSLoader(BaseLoader):
+ """Load from `lakeFS`."""
+
+ repo: str
+ ref: str
+ path: str
+
+ def __init__(
+ self,
+ lakefs_access_key: str,
+ lakefs_secret_key: str,
+ lakefs_endpoint: str,
+ repo: Optional[str] = None,
+ ref: Optional[str] = "main",
+ path: Optional[str] = "",
+ ):
+ """
+
+ :param lakefs_access_key: [required] lakeFS server's access key
+ :param lakefs_secret_key: [required] lakeFS server's secret key
+ :param lakefs_endpoint: [required] lakeFS server's endpoint address,
+ ex: https://example.my-lakefs.com
+ :param repo: [optional, default = ''] target repository
+ :param ref: [optional, default = 'main'] target ref (branch name,
+ tag, or commit ID)
+ :param path: [optional, default = ''] target path
+ """
+
+ self.__lakefs_client = LakeFSClient(
+ lakefs_access_key, lakefs_secret_key, lakefs_endpoint
+ )
+ self.repo = "" if repo is None or repo == "" else str(repo)
+ self.ref = "main" if ref is None or ref == "" else str(ref)
+ self.path = "" if path is None else str(path)
+
+ def set_path(self, path: str) -> None:
+ self.path = path
+
+ def set_ref(self, ref: str) -> None:
+ self.ref = ref
+
+ def set_repo(self, repo: str) -> None:
+ self.repo = repo
+
+ def load(self) -> List[Document]:
+ self.__validate_instance()
+ presigned = self.__lakefs_client.is_presign_supported()
+ docs: List[Document] = []
+ objs = self.__lakefs_client.ls_objects(
+ repo=self.repo, ref=self.ref, path=self.path, presign=presigned
+ )
+ for obj in objs:
+ lakefs_unstructured_loader = UnstructuredLakeFSLoader(
+ obj[1], self.repo, self.ref, obj[0], presigned
+ )
+ docs.extend(lakefs_unstructured_loader.load())
+ return docs
+
+ def __validate_instance(self) -> None:
+ if self.repo is None or self.repo == "":
+ raise ValueError(
+ "no repository was provided. use `set_repo` to specify a repository"
+ )
+ if self.ref is None or self.ref == "":
+ raise ValueError("no ref was provided. use `set_ref` to specify a ref")
+ if self.path is None:
+ raise ValueError("no path was provided. use `set_path` to specify a path")
+
+
+class UnstructuredLakeFSLoader(UnstructuredBaseLoader):
+ def __init__(
+ self,
+ url: str,
+ repo: str,
+ ref: str = "main",
+ path: str = "",
+ presign: bool = True,
+ **unstructured_kwargs: Any,
+ ):
+ """
+
+ Args:
+
+ :param lakefs_access_key:
+ :param lakefs_secret_key:
+ :param lakefs_endpoint:
+ :param repo:
+ :param ref:
+ """
+
+ super().__init__(**unstructured_kwargs)
+ self.url = url
+ self.repo = repo
+ self.ref = ref
+ self.path = path
+ self.presign = presign
+
+ def _get_metadata(self) -> dict:
+ return {"repo": self.repo, "ref": self.ref, "path": self.path}
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.auto import partition
+
+ local_prefix = "local://"
+
+ if self.presign:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.path.split('/')[-1]}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ response = requests.get(self.url)
+ response.raise_for_status()
+ with open(file_path, mode="wb") as file:
+ file.write(response.content)
+ return partition(filename=file_path)
+ elif not self.url.startswith(local_prefix):
+ raise ValueError(
+ "Non pre-signed URLs are supported only with 'local' blockstore"
+ )
+ else:
+ local_path = self.url[len(local_prefix) :]
+ return partition(filename=local_path)
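+
+
+# Usage sketch (illustrative; endpoint, keys, repo and path are placeholders):
+#
+#     loader = LakeFSLoader(
+#         lakefs_access_key="AKIA...",
+#         lakefs_secret_key="...",
+#         lakefs_endpoint="https://example.my-lakefs.com",
+#         repo="my-repo",
+#         ref="main",
+#         path="docs/",
+#     )
+#     docs = loader.load()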
diff --git a/libs/community/langchain_community/document_loaders/larksuite.py b/libs/community/langchain_community/document_loaders/larksuite.py
new file mode 100644
index 00000000000..5e6e2787355
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/larksuite.py
@@ -0,0 +1,52 @@
+import json
+import urllib.request
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class LarkSuiteDocLoader(BaseLoader):
+ """Load from `LarkSuite` (`FeiShu`)."""
+
+ def __init__(self, domain: str, access_token: str, document_id: str):
+ """Initialize with domain, access_token (tenant / user), and document_id.
+
+ Args:
+ domain: The domain to load the LarkSuite.
+ access_token: The access_token to use.
+ document_id: The document_id to load.
+ """
+ self.domain = domain
+ self.access_token = access_token
+ self.document_id = document_id
+
+ def _get_larksuite_api_json_data(self, api_url: str) -> Any:
+ """Get LarkSuite (FeiShu) API response json data."""
+ headers = {"Authorization": f"Bearer {self.access_token}"}
+ request = urllib.request.Request(api_url, headers=headers)
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ return json_data
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load LarkSuite (FeiShu) document."""
+ api_url_prefix = f"{self.domain}/open-apis/docx/v1/documents"
+ metadata_json = self._get_larksuite_api_json_data(
+ f"{api_url_prefix}/{self.document_id}"
+ )
+ raw_content_json = self._get_larksuite_api_json_data(
+ f"{api_url_prefix}/{self.document_id}/raw_content"
+ )
+ text = raw_content_json["data"]["content"]
+ metadata = {
+ "document_id": self.document_id,
+ "revision_id": metadata_json["data"]["document"]["revision_id"],
+ "title": metadata_json["data"]["document"]["title"],
+ }
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load LarkSuite (FeiShu) document."""
+ return list(self.lazy_load())
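+
+
+# Usage sketch (illustrative; domain, token and document ID are placeholders):
+#
+#     loader = LarkSuiteDocLoader(
+#         domain="https://open.larksuite.com",
+#         access_token="t-...",
+#         document_id="doxcn...",
+#     )
+#     docs = loader.load()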
diff --git a/libs/community/langchain_community/document_loaders/markdown.py b/libs/community/langchain_community/document_loaders/markdown.py
new file mode 100644
index 00000000000..0d0df3752d3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/markdown.py
@@ -0,0 +1,45 @@
+from typing import List
+
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class UnstructuredMarkdownLoader(UnstructuredFileLoader):
+ """Load `Markdown` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredMarkdownLoader
+
+ loader = UnstructuredMarkdownLoader(
+ "example.md", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-md
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.__version__ import __version__ as __unstructured_version__
+ from unstructured.partition.md import partition_md
+
+ # NOTE(MthwRobinson) - enables the loader to work when you're using pre-release
+ # versions of unstructured like 0.4.17-dev1
+ _unstructured_version = __unstructured_version__.split("-")[0]
+ unstructured_version = tuple([int(x) for x in _unstructured_version.split(".")])
+
+ if unstructured_version < (0, 4, 16):
+ raise ValueError(
+ f"You are on unstructured version {__unstructured_version__}. "
+ "Partitioning markdown files is only supported in unstructured>=0.4.16."
+ )
+
+ return partition_md(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/mastodon.py b/libs/community/langchain_community/document_loaders/mastodon.py
new file mode 100644
index 00000000000..9b49cd0fe24
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/mastodon.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import mastodon
+
+
+def _dependable_mastodon_import() -> mastodon:
+ try:
+ import mastodon
+ except ImportError:
+ raise ImportError(
+ "Mastodon.py package not found, "
+ "please install it with `pip install Mastodon.py`"
+ )
+ return mastodon
+
+
+class MastodonTootsLoader(BaseLoader):
+ """Load the `Mastodon` 'toots'."""
+
+ def __init__(
+ self,
+ mastodon_accounts: Sequence[str],
+ number_toots: Optional[int] = 100,
+ exclude_replies: bool = False,
+ access_token: Optional[str] = None,
+ api_base_url: str = "https://mastodon.social",
+ ):
+ """Instantiate Mastodon toots loader.
+
+ Args:
+ mastodon_accounts: The list of Mastodon accounts to query.
+ number_toots: How many toots to pull for each account. Defaults to 100.
+ exclude_replies: Whether to exclude reply toots from the load.
+ Defaults to False.
+ access_token: An access token if toots are loaded as a Mastodon app. Can
+ also be specified via the environment variables "MASTODON_ACCESS_TOKEN".
+ api_base_url: A Mastodon API base URL to talk to, if not using the default.
+ Defaults to "https://mastodon.social".
+ """
+ mastodon = _dependable_mastodon_import()
+ access_token = access_token or os.environ.get("MASTODON_ACCESS_TOKEN")
+ self.api = mastodon.Mastodon(
+ access_token=access_token, api_base_url=api_base_url
+ )
+ self.mastodon_accounts = mastodon_accounts
+ self.number_toots = number_toots
+ self.exclude_replies = exclude_replies
+
+ def load(self) -> List[Document]:
+ """Load toots into documents."""
+ results: List[Document] = []
+ for account in self.mastodon_accounts:
+ user = self.api.account_lookup(account)
+ toots = self.api.account_statuses(
+ user.id,
+ only_media=False,
+ pinned=False,
+ exclude_replies=self.exclude_replies,
+ exclude_reblogs=True,
+ limit=self.number_toots,
+ )
+ docs = self._format_toots(toots, user)
+ results.extend(docs)
+ return results
+
+ def _format_toots(
+ self, toots: List[Dict[str, Any]], user_info: dict
+ ) -> Iterable[Document]:
+ """Format toots into documents.
+
+ Adding user info, and selected toot fields into the metadata.
+ """
+ for toot in toots:
+ metadata = {
+ "created_at": toot["created_at"],
+ "user_info": user_info,
+ "is_reply": toot["in_reply_to_id"] is not None,
+ }
+ yield Document(
+ page_content=toot["content"],
+ metadata=metadata,
+ )
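+
+
+# Usage sketch (illustrative; the account handle is a placeholder, and public
+# toots can be read without an access token):
+#
+#     loader = MastodonTootsLoader(mastodon_accounts=["@Gargron@mastodon.social"])
+#     docs = loader.load()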
diff --git a/libs/community/langchain_community/document_loaders/max_compute.py b/libs/community/langchain_community/document_loaders/max_compute.py
new file mode 100644
index 00000000000..b5cadbbaf92
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/max_compute.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.max_compute import MaxComputeAPIWrapper
+
+
+class MaxComputeLoader(BaseLoader):
+ """Load from `Alibaba Cloud MaxCompute` table."""
+
+ def __init__(
+ self,
+ query: str,
+ api_wrapper: MaxComputeAPIWrapper,
+ *,
+ page_content_columns: Optional[Sequence[str]] = None,
+ metadata_columns: Optional[Sequence[str]] = None,
+ ):
+ """Initialize Alibaba Cloud MaxCompute document loader.
+
+ Args:
+ query: SQL query to execute.
+ api_wrapper: MaxCompute API wrapper.
+ page_content_columns: The columns to write into the `page_content` of the
+ Document. If unspecified, all columns will be written to `page_content`.
+ metadata_columns: The columns to write into the `metadata` of the Document.
+ If unspecified, all columns not added to `page_content` will be written.
+ """
+ self.query = query
+ self.api_wrapper = api_wrapper
+ self.page_content_columns = page_content_columns
+ self.metadata_columns = metadata_columns
+
+ @classmethod
+ def from_params(
+ cls,
+ query: str,
+ endpoint: str,
+ project: str,
+ *,
+ access_id: Optional[str] = None,
+ secret_access_key: Optional[str] = None,
+ **kwargs: Any,
+ ) -> MaxComputeLoader:
+ """Convenience constructor that builds the MaxCompute API wrapper from
+ given parameters.
+
+ Args:
+ query: SQL query to execute.
+ endpoint: MaxCompute endpoint.
+ project: A project is a basic organizational unit of MaxCompute, which is
+ similar to a database.
+ access_id: MaxCompute access ID. Should be passed in directly or set as the
+ environment variable `MAX_COMPUTE_ACCESS_ID`.
+ secret_access_key: MaxCompute secret access key. Should be passed in
+ directly or set as the environment variable
+ `MAX_COMPUTE_SECRET_ACCESS_KEY`.
+ """
+ api_wrapper = MaxComputeAPIWrapper.from_params(
+ endpoint, project, access_id=access_id, secret_access_key=secret_access_key
+ )
+ return cls(query, api_wrapper, **kwargs)
+
+ def lazy_load(self) -> Iterator[Document]:
+ for row in self.api_wrapper.query(self.query):
+ if self.page_content_columns:
+ page_content_data = {
+ k: v for k, v in row.items() if k in self.page_content_columns
+ }
+ else:
+ page_content_data = row
+ page_content = "\n".join(f"{k}: {v}" for k, v in page_content_data.items())
+ if self.metadata_columns:
+ metadata = {k: v for k, v in row.items() if k in self.metadata_columns}
+ else:
+ metadata = {k: v for k, v in row.items() if k not in page_content_data}
+ yield Document(page_content=page_content, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
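+
+
+# Usage sketch (illustrative; endpoint, project and credentials are placeholders
+# and can also come from MAX_COMPUTE_ACCESS_ID / MAX_COMPUTE_SECRET_ACCESS_KEY):
+#
+#     loader = MaxComputeLoader.from_params(
+#         query="SELECT * FROM my_table LIMIT 10",
+#         endpoint="<endpoint>",
+#         project="<project>",
+#     )
+#     docs = loader.load()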
diff --git a/libs/community/langchain_community/document_loaders/mediawikidump.py b/libs/community/langchain_community/document_loaders/mediawikidump.py
new file mode 100644
index 00000000000..dd189974309
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/mediawikidump.py
@@ -0,0 +1,96 @@
+import logging
+from pathlib import Path
+from typing import List, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class MWDumpLoader(BaseLoader):
+ """Load `MediaWiki` dump from an `XML` file.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import MWDumpLoader
+
+ loader = MWDumpLoader(
+ file_path="myWiki.xml",
+ encoding="utf8"
+ )
+ docs = loader.load()
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=1000, chunk_overlap=0
+ )
+ texts = text_splitter.split_documents(docs)
+
+
+ :param file_path: XML local file path
+ :type file_path: str
+ :param encoding: Charset encoding, defaults to "utf8"
+ :type encoding: str, optional
+ :param namespaces: The namespace of pages you want to parse.
+ See https://www.mediawiki.org/wiki/Help:Namespaces#Localisation
+ for a list of all common namespaces
+ :type namespaces: List[int],optional
+    :param skip_redirects: True to skip pages that redirect to other pages,
+ False to keep them. False by default
+ :type skip_redirects: bool, optional
+ :param stop_on_error: False to skip over pages that cause parsing errors,
+ True to stop. True by default
+ :type stop_on_error: bool, optional
+ """
+
+ def __init__(
+ self,
+ file_path: Union[str, Path],
+ encoding: Optional[str] = "utf8",
+ namespaces: Optional[Sequence[int]] = None,
+ skip_redirects: Optional[bool] = False,
+ stop_on_error: Optional[bool] = True,
+ ):
+ self.file_path = file_path if isinstance(file_path, str) else str(file_path)
+ self.encoding = encoding
+ # Namespaces range from -2 to 15, inclusive.
+ self.namespaces = namespaces
+ self.skip_redirects = skip_redirects
+ self.stop_on_error = stop_on_error
+
+ def load(self) -> List[Document]:
+ """Load from a file path."""
+ try:
+ import mwparserfromhell
+ import mwxml
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import 'mwparserfromhell' or 'mwxml'. Please install with"
+ " `pip install mwparserfromhell mwxml`."
+ ) from e
+
+ dump = mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))
+
+ docs = []
+ for page in dump.pages:
+ if self.skip_redirects and page.redirect:
+ continue
+ if self.namespaces and page.namespace not in self.namespaces:
+ continue
+ try:
+ for revision in page:
+ code = mwparserfromhell.parse(revision.text)
+ text = code.strip_code(
+ normalize=True, collapse=True, keep_template_params=False
+ )
+ metadata = {"source": page.title}
+ docs.append(Document(page_content=text, metadata=metadata))
+ except Exception as e:
+ logger.error("Parsing error: {}".format(e))
+ if self.stop_on_error:
+ raise e
+ else:
+ continue
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/merge.py b/libs/community/langchain_community/document_loaders/merge.py
new file mode 100644
index 00000000000..c93963e70ca
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/merge.py
@@ -0,0 +1,28 @@
+from typing import Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class MergedDataLoader(BaseLoader):
+ """Merge documents from a list of loaders"""
+
+ def __init__(self, loaders: List):
+ """Initialize with a list of loaders"""
+ self.loaders = loaders
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load docs from each individual loader."""
+ for loader in self.loaders:
+ # Check if lazy_load is implemented
+ try:
+ data = loader.lazy_load()
+ except NotImplementedError:
+ data = loader.load()
+ for document in data:
+ yield document
+
+ def load(self) -> List[Document]:
+ """Load docs."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/mhtml.py b/libs/community/langchain_community/document_loaders/mhtml.py
new file mode 100644
index 00000000000..6cc73a4c48d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/mhtml.py
@@ -0,0 +1,76 @@
+import email
+import logging
+from typing import Dict, List, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class MHTMLLoader(BaseLoader):
+ """Parse `MHTML` files with `BeautifulSoup`."""
+
+ def __init__(
+ self,
+ file_path: str,
+ open_encoding: Union[str, None] = None,
+ bs_kwargs: Union[dict, None] = None,
+ get_text_separator: str = "",
+ ) -> None:
+ """Initialise with path, and optionally, file encoding to use, and any kwargs
+ to pass to the BeautifulSoup object.
+
+ Args:
+ file_path: Path to file to load.
+ open_encoding: The encoding to use when opening the file.
+ bs_kwargs: Any kwargs to pass to the BeautifulSoup object.
+ get_text_separator: The separator to use when getting the text
+ from the soup.
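+
+        Example (an illustrative sketch; the file name is a placeholder):
+            .. code-block:: python
+
+                from langchain_community.document_loaders import MHTMLLoader
+
+                loader = MHTMLLoader(file_path="example.mht")
+                docs = loader.load()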
+ """
+ try:
+ import bs4 # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "beautifulsoup4 package not found, please install it with "
+ "`pip install beautifulsoup4`"
+ )
+
+ self.file_path = file_path
+ self.open_encoding = open_encoding
+ if bs_kwargs is None:
+ bs_kwargs = {"features": "lxml"}
+ self.bs_kwargs = bs_kwargs
+ self.get_text_separator = get_text_separator
+
+ def load(self) -> List[Document]:
+        """Load MHTML document into document objects."""
+        from bs4 import BeautifulSoup
+
+ with open(self.file_path, "r", encoding=self.open_encoding) as f:
+ message = email.message_from_string(f.read())
+ parts = message.get_payload()
+
+ if not isinstance(parts, list):
+ parts = [message]
+
+ for part in parts:
+ if part.get_content_type() == "text/html":
+ html = part.get_payload(decode=True).decode()
+
+ soup = BeautifulSoup(html, **self.bs_kwargs)
+ text = soup.get_text(self.get_text_separator)
+
+ if soup.title:
+ title = str(soup.title.string)
+ else:
+ title = ""
+
+ metadata: Dict[str, Union[str, None]] = {
+ "source": self.file_path,
+ "title": title,
+ }
+ return [Document(page_content=text, metadata=metadata)]
+ return []
diff --git a/libs/community/langchain_community/document_loaders/modern_treasury.py b/libs/community/langchain_community/document_loaders/modern_treasury.py
new file mode 100644
index 00000000000..045a2e786ed
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/modern_treasury.py
@@ -0,0 +1,73 @@
+import json
+import urllib.request
+from base64 import b64encode
+from typing import List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_env, stringify_value
+
+from langchain_community.document_loaders.base import BaseLoader
+
+MODERN_TREASURY_ENDPOINTS = {
+ "payment_orders": "https://app.moderntreasury.com/api/payment_orders",
+ "expected_payments": "https://app.moderntreasury.com/api/expected_payments",
+ "returns": "https://app.moderntreasury.com/api/returns",
+ "incoming_payment_details": "https://app.moderntreasury.com/api/\
+incoming_payment_details",
+ "counterparties": "https://app.moderntreasury.com/api/counterparties",
+ "internal_accounts": "https://app.moderntreasury.com/api/internal_accounts",
+ "external_accounts": "https://app.moderntreasury.com/api/external_accounts",
+ "transactions": "https://app.moderntreasury.com/api/transactions",
+ "ledgers": "https://app.moderntreasury.com/api/ledgers",
+ "ledger_accounts": "https://app.moderntreasury.com/api/ledger_accounts",
+ "ledger_transactions": "https://app.moderntreasury.com/api/ledger_transactions",
+ "events": "https://app.moderntreasury.com/api/events",
+ "invoices": "https://app.moderntreasury.com/api/invoices",
+}
+
+
+class ModernTreasuryLoader(BaseLoader):
+ """Load from `Modern Treasury`."""
+
+ def __init__(
+ self,
+ resource: str,
+ organization_id: Optional[str] = None,
+ api_key: Optional[str] = None,
+ ) -> None:
+ """
+
+ Args:
+ resource: The Modern Treasury resource to load.
+ organization_id: The Modern Treasury organization ID. It can also be
+ specified via the environment variable
+ "MODERN_TREASURY_ORGANIZATION_ID".
+ api_key: The Modern Treasury API key. It can also be specified via
+ the environment variable "MODERN_TREASURY_API_KEY".
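+
+        Example (a minimal sketch; assumes the two environment variables above
+        are set):
+            .. code-block:: python
+
+                from langchain_community.document_loaders.modern_treasury import (
+                    ModernTreasuryLoader,
+                )
+
+                loader = ModernTreasuryLoader("payment_orders")
+                docs = loader.load()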
+ """
+ self.resource = resource
+ organization_id = organization_id or get_from_env(
+ "organization_id", "MODERN_TREASURY_ORGANIZATION_ID"
+ )
+ api_key = api_key or get_from_env("api_key", "MODERN_TREASURY_API_KEY")
+ credentials = f"{organization_id}:{api_key}".encode("utf-8")
+ basic_auth_token = b64encode(credentials).decode("utf-8")
+ self.headers = {"Authorization": f"Basic {basic_auth_token}"}
+
+ def _make_request(self, url: str) -> List[Document]:
+ request = urllib.request.Request(url, headers=self.headers)
+
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ text = stringify_value(json_data)
+ metadata = {"source": url}
+ return [Document(page_content=text, metadata=metadata)]
+
+ def _get_resource(self) -> List[Document]:
+ endpoint = MODERN_TREASURY_ENDPOINTS.get(self.resource)
+ if endpoint is None:
+ return []
+ return self._make_request(endpoint)
+
+ def load(self) -> List[Document]:
+ return self._get_resource()
diff --git a/libs/community/langchain_community/document_loaders/mongodb.py b/libs/community/langchain_community/document_loaders/mongodb.py
new file mode 100644
index 00000000000..5ada325a273
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/mongodb.py
@@ -0,0 +1,77 @@
+import asyncio
+import logging
+from typing import Dict, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class MongodbLoader(BaseLoader):
+ """Load MongoDB documents."""
+
+ def __init__(
+ self,
+ connection_string: str,
+ db_name: str,
+ collection_name: str,
+ *,
+ filter_criteria: Optional[Dict] = None,
+ ) -> None:
+ try:
+ from motor.motor_asyncio import AsyncIOMotorClient
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import from motor, please install with `pip install motor`."
+ ) from e
+ if not connection_string:
+ raise ValueError("connection_string must be provided.")
+
+ if not db_name:
+ raise ValueError("db_name must be provided.")
+
+ if not collection_name:
+ raise ValueError("collection_name must be provided.")
+
+ self.client = AsyncIOMotorClient(connection_string)
+ self.db_name = db_name
+ self.collection_name = collection_name
+ self.filter_criteria = filter_criteria or {}
+
+ self.db = self.client.get_database(db_name)
+ self.collection = self.db.get_collection(collection_name)
+
+ def load(self) -> List[Document]:
+ """Load data into Document objects.
+
+ Attention:
+
+ This implementation starts an asyncio event loop which
+ will only work if running in a sync env. In an async env, it should
+ fail since there is already an event loop running.
+
+ This code should be updated to kick off the event loop from a separate
+ thread if running within an async context.
+ """
+ return asyncio.run(self.aload())
+
+ async def aload(self) -> List[Document]:
+ """Load data into Document objects."""
+ result = []
+ total_docs = await self.collection.count_documents(self.filter_criteria)
+ async for doc in self.collection.find(self.filter_criteria):
+ metadata = {
+ "database": self.db_name,
+ "collection": self.collection_name,
+ }
+ result.append(Document(page_content=str(doc), metadata=metadata))
+
+ if len(result) != total_docs:
+ logger.warning(
+ f"Only partial collection of documents returned. Loaded {len(result)} "
+ f"docs, expected {total_docs}."
+ )
+
+ return result
diff --git a/libs/community/langchain_community/document_loaders/news.py b/libs/community/langchain_community/document_loaders/news.py
new file mode 100644
index 00000000000..7f59601db62
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/news.py
@@ -0,0 +1,125 @@
+"""Loader that uses unstructured to load HTML files."""
+import logging
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class NewsURLLoader(BaseLoader):
+ """Load news articles from URLs using `Unstructured`.
+
+ Args:
+ urls: URLs to load. Each is loaded into its own document.
+ text_mode: If True, extract text from URL and use that for page content.
+ Otherwise, extract raw HTML.
+ nlp: If True, perform NLP on the extracted contents, like providing a summary
+ and extracting keywords.
+ continue_on_failure: If True, continue loading documents even if
+ loading fails for a particular URL.
+ show_progress_bar: If True, use tqdm to show a loading progress bar. Requires
+ tqdm to be installed, ``pip install tqdm``.
+ **newspaper_kwargs: Any additional named arguments to pass to
+ newspaper.Article().
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import NewsURLLoader
+
+ loader = NewsURLLoader(
+ urls=["", ""],
+ )
+ docs = loader.load()
+
+ Newspaper reference:
+ https://newspaper.readthedocs.io/en/latest/
+ """
+
+ def __init__(
+ self,
+ urls: List[str],
+ text_mode: bool = True,
+ nlp: bool = False,
+ continue_on_failure: bool = True,
+ show_progress_bar: bool = False,
+ **newspaper_kwargs: Any,
+ ) -> None:
+ """Initialize with file path."""
+ try:
+ import newspaper # noqa:F401
+
+ self.__version = newspaper.__version__
+ except ImportError:
+ raise ImportError(
+ "newspaper package not found, please install it with "
+ "`pip install newspaper3k`"
+ )
+
+ self.urls = urls
+ self.text_mode = text_mode
+ self.nlp = nlp
+ self.continue_on_failure = continue_on_failure
+ self.newspaper_kwargs = newspaper_kwargs
+ self.show_progress_bar = show_progress_bar
+
+ def load(self) -> List[Document]:
+ iter = self.lazy_load()
+ if self.show_progress_bar:
+ try:
+ from tqdm import tqdm
+ except ImportError as e:
+ raise ImportError(
+ "Package tqdm must be installed if show_progress_bar=True. "
+ "Please install with 'pip install tqdm' or set "
+ "show_progress_bar=False."
+ ) from e
+ iter = tqdm(iter)
+ return list(iter)
+
+ def lazy_load(self) -> Iterator[Document]:
+ try:
+ from newspaper import Article
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import newspaper, please install with `pip install newspaper3k`"
+ ) from e
+
+ for url in self.urls:
+ try:
+ article = Article(url, **self.newspaper_kwargs)
+ article.download()
+ article.parse()
+
+ if self.nlp:
+ article.nlp()
+
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error fetching or processing {url}, exception: {e}")
+ continue
+ else:
+ raise e
+
+ metadata = {
+ "title": getattr(article, "title", ""),
+ "link": getattr(article, "url", getattr(article, "canonical_link", "")),
+ "authors": getattr(article, "authors", []),
+ "language": getattr(article, "meta_lang", ""),
+ "description": getattr(article, "meta_description", ""),
+ "publish_date": getattr(article, "publish_date", ""),
+ }
+
+ if self.text_mode:
+ content = article.text
+ else:
+ content = article.html
+
+ if self.nlp:
+ metadata["keywords"] = getattr(article, "keywords", [])
+ metadata["summary"] = getattr(article, "summary", "")
+
+ yield Document(page_content=content, metadata=metadata)
diff --git a/libs/community/langchain_community/document_loaders/notebook.py b/libs/community/langchain_community/document_loaders/notebook.py
new file mode 100644
index 00000000000..51eec597a52
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/notebook.py
@@ -0,0 +1,133 @@
+"""Loads .ipynb notebook files."""
+import json
+from pathlib import Path
+from typing import Any, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def concatenate_cells(
+ cell: dict, include_outputs: bool, max_output_length: int, traceback: bool
+) -> str:
+ """Combine cells information in a readable format ready to be used.
+
+ Args:
+        cell: A dictionary with the cell's type, source, and outputs.
+ include_outputs: Whether to include the outputs of the cell.
+ max_output_length: Maximum length of the output to be displayed.
+ traceback: Whether to return a traceback of the error.
+
+ Returns:
+ A string with the cell information.
+
+ """
+ cell_type = cell["cell_type"]
+ source = cell["source"]
+ output = cell["outputs"]
+
+ if include_outputs and cell_type == "code" and output:
+ if "ename" in output[0].keys():
+ error_name = output[0]["ename"]
+ error_value = output[0]["evalue"]
+ if traceback:
+ traceback = output[0]["traceback"]
+ return (
+ f"'{cell_type}' cell: '{source}'\n, gives error '{error_name}',"
+ f" with description '{error_value}'\n"
+ f"and traceback '{traceback}'\n\n"
+ )
+ else:
+ return (
+ f"'{cell_type}' cell: '{source}'\n, gives error '{error_name}',"
+ f"with description '{error_value}'\n\n"
+ )
+ elif output[0]["output_type"] == "stream":
+ output = output[0]["text"]
+ min_output = min(max_output_length, len(output))
+ return (
+ f"'{cell_type}' cell: '{source}'\n with "
+ f"output: '{output[:min_output]}'\n\n"
+ )
+    else:
+        return f"'{cell_type}' cell: '{source}'\n\n"
+
+ return ""
+
+
+def remove_newlines(x: Any) -> Any:
+ """Recursively remove newlines, no matter the data structure they are stored in."""
+ import pandas as pd
+
+ if isinstance(x, str):
+ return x.replace("\n", "")
+ elif isinstance(x, list):
+ return [remove_newlines(elem) for elem in x]
+ elif isinstance(x, pd.DataFrame):
+ return x.applymap(remove_newlines)
+ else:
+ return x
+
+
+class NotebookLoader(BaseLoader):
+ """Load `Jupyter notebook` (.ipynb) files."""
+
+ def __init__(
+ self,
+ path: str,
+ include_outputs: bool = False,
+ max_output_length: int = 10,
+ remove_newline: bool = False,
+ traceback: bool = False,
+ ):
+ """Initialize with a path.
+
+ Args:
+ path: The path to load the notebook from.
+ include_outputs: Whether to include the outputs of the cell.
+ Defaults to False.
+ max_output_length: Maximum length of the output to be displayed.
+ Defaults to 10.
+ remove_newline: Whether to remove newlines from the notebook.
+ Defaults to False.
+ traceback: Whether to return a traceback of the error.
+ Defaults to False.
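+
+        Example (an illustrative sketch; the notebook path is a placeholder):
+            .. code-block:: python
+
+                from langchain_community.document_loaders.notebook import (
+                    NotebookLoader,
+                )
+
+                loader = NotebookLoader("analysis.ipynb", include_outputs=True)
+                docs = loader.load()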
+ """
+ self.file_path = path
+ self.include_outputs = include_outputs
+ self.max_output_length = max_output_length
+ self.remove_newline = remove_newline
+ self.traceback = traceback
+
+ def load(
+ self,
+ ) -> List[Document]:
+ """Load documents."""
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ "pandas is needed for Notebook Loader, "
+ "please install with `pip install pandas`"
+ )
+ p = Path(self.file_path)
+
+ with open(p, encoding="utf8") as f:
+ d = json.load(f)
+
+ data = pd.json_normalize(d["cells"])
+ filtered_data = data[["cell_type", "source", "outputs"]]
+ if self.remove_newline:
+ filtered_data = filtered_data.applymap(remove_newlines)
+
+ text = filtered_data.apply(
+ lambda x: concatenate_cells(
+ x, self.include_outputs, self.max_output_length, self.traceback
+ ),
+ axis=1,
+ ).str.cat(sep=" ")
+
+ metadata = {"source": str(p)}
+
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/notion.py b/libs/community/langchain_community/document_loaders/notion.py
new file mode 100644
index 00000000000..c42bf568f32
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/notion.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class NotionDirectoryLoader(BaseLoader):
+ """Load `Notion directory` dump."""
+
+ def __init__(self, path: str, *, encoding: str = "utf-8") -> None:
+ """Initialize with a file path."""
+ self.file_path = path
+ self.encoding = encoding
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ paths = list(Path(self.file_path).glob("**/*.md"))
+ docs = []
+ for p in paths:
+ with open(p, encoding=self.encoding) as f:
+ text = f.read()
+ metadata = {"source": str(p)}
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/notiondb.py b/libs/community/langchain_community/document_loaders/notiondb.py
new file mode 100644
index 00000000000..0987d532c97
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/notiondb.py
@@ -0,0 +1,195 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+NOTION_BASE_URL = "https://api.notion.com/v1"
+DATABASE_URL = NOTION_BASE_URL + "/databases/{database_id}/query"
+PAGE_URL = NOTION_BASE_URL + "/pages/{page_id}"
+BLOCK_URL = NOTION_BASE_URL + "/blocks/{block_id}/children"
+
+
+class NotionDBLoader(BaseLoader):
+ """Load from `Notion DB`.
+
+ Reads content from pages within a Notion Database.
+ Args:
+ integration_token (str): Notion integration token.
+ database_id (str): Notion database id.
+ request_timeout_sec (int): Timeout for Notion requests in seconds.
+ Defaults to 10.
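+
+    Example (an illustrative sketch; the token and database id are placeholders):
+        .. code-block:: python
+
+            from langchain_community.document_loaders import NotionDBLoader
+
+            loader = NotionDBLoader(
+                integration_token="your-integration-token",
+                database_id="your-database-id",
+            )
+            docs = loader.load()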
+ """
+
+ def __init__(
+ self,
+ integration_token: str,
+ database_id: str,
+ request_timeout_sec: Optional[int] = 10,
+ ) -> None:
+ """Initialize with parameters."""
+ if not integration_token:
+ raise ValueError("integration_token must be provided")
+ if not database_id:
+ raise ValueError("database_id must be provided")
+
+ self.token = integration_token
+ self.database_id = database_id
+ self.headers = {
+ "Authorization": "Bearer " + self.token,
+ "Content-Type": "application/json",
+ "Notion-Version": "2022-06-28",
+ }
+ self.request_timeout_sec = request_timeout_sec
+
+ def load(self) -> List[Document]:
+ """Load documents from the Notion database.
+ Returns:
+ List[Document]: List of documents.
+ """
+ page_summaries = self._retrieve_page_summaries()
+
+ return list(self.load_page(page_summary) for page_summary in page_summaries)
+
+ def _retrieve_page_summaries(
+ self, query_dict: Dict[str, Any] = {"page_size": 100}
+ ) -> List[Dict[str, Any]]:
+ """Get all the pages from a Notion database."""
+ pages: List[Dict[str, Any]] = []
+
+ while True:
+ data = self._request(
+ DATABASE_URL.format(database_id=self.database_id),
+ method="POST",
+ query_dict=query_dict,
+ )
+
+ pages.extend(data.get("results"))
+
+ if not data.get("has_more"):
+ break
+
+ query_dict["start_cursor"] = data.get("next_cursor")
+
+ return pages
+
+ def load_page(self, page_summary: Dict[str, Any]) -> Document:
+ """Read a page.
+
+ Args:
+ page_summary: Page summary from Notion API.
+ """
+ page_id = page_summary["id"]
+
+ # load properties as metadata
+ metadata: Dict[str, Any] = {}
+
+ for prop_name, prop_data in page_summary["properties"].items():
+ prop_type = prop_data["type"]
+
+ if prop_type == "rich_text":
+ value = (
+ prop_data["rich_text"][0]["plain_text"]
+ if prop_data["rich_text"]
+ else None
+ )
+ elif prop_type == "title":
+ value = (
+ prop_data["title"][0]["plain_text"] if prop_data["title"] else None
+ )
+ elif prop_type == "multi_select":
+ value = (
+ [item["name"] for item in prop_data["multi_select"]]
+ if prop_data["multi_select"]
+ else []
+ )
+ elif prop_type == "url":
+ value = prop_data["url"]
+ elif prop_type == "unique_id":
+ value = (
+ f'{prop_data["unique_id"]["prefix"]}-{prop_data["unique_id"]["number"]}'
+ if prop_data["unique_id"]
+ else None
+ )
+ elif prop_type == "status":
+ value = prop_data["status"]["name"] if prop_data["status"] else None
+ elif prop_type == "people":
+ value = (
+ [item["name"] for item in prop_data["people"]]
+ if prop_data["people"]
+ else []
+ )
+ elif prop_type == "date":
+ value = prop_data["date"] if prop_data["date"] else None
+ elif prop_type == "last_edited_time":
+ value = (
+ prop_data["last_edited_time"]
+ if prop_data["last_edited_time"]
+ else None
+ )
+ elif prop_type == "created_time":
+ value = prop_data["created_time"] if prop_data["created_time"] else None
+ elif prop_type == "checkbox":
+ value = prop_data["checkbox"]
+ elif prop_type == "email":
+ value = prop_data["email"]
+ elif prop_type == "number":
+ value = prop_data["number"]
+ elif prop_type == "select":
+ value = prop_data["select"]["name"] if prop_data["select"] else None
+ else:
+ value = None
+
+ metadata[prop_name.lower()] = value
+
+ metadata["id"] = page_id
+
+ return Document(page_content=self._load_blocks(page_id), metadata=metadata)
+
+ def _load_blocks(self, block_id: str, num_tabs: int = 0) -> str:
+ """Read a block and its children."""
+ result_lines_arr: List[str] = []
+ cur_block_id: str = block_id
+
+ while cur_block_id:
+ data = self._request(BLOCK_URL.format(block_id=cur_block_id))
+
+ for result in data["results"]:
+ result_obj = result[result["type"]]
+
+ if "rich_text" not in result_obj:
+ continue
+
+ cur_result_text_arr: List[str] = []
+
+ for rich_text in result_obj["rich_text"]:
+ if "text" in rich_text:
+ cur_result_text_arr.append(
+ "\t" * num_tabs + rich_text["text"]["content"]
+ )
+
+ if result["has_children"]:
+ children_text = self._load_blocks(
+ result["id"], num_tabs=num_tabs + 1
+ )
+ cur_result_text_arr.append(children_text)
+
+ result_lines_arr.append("\n".join(cur_result_text_arr))
+
+ cur_block_id = data.get("next_cursor")
+
+ return "\n".join(result_lines_arr)
+
+ def _request(
+ self, url: str, method: str = "GET", query_dict: Dict[str, Any] = {}
+ ) -> Any:
+ res = requests.request(
+ method,
+ url,
+ headers=self.headers,
+ json=query_dict,
+ timeout=self.request_timeout_sec,
+ )
+ res.raise_for_status()
+ return res.json()
diff --git a/libs/community/langchain_community/document_loaders/nuclia.py b/libs/community/langchain_community/document_loaders/nuclia.py
new file mode 100644
index 00000000000..33128588994
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/nuclia.py
@@ -0,0 +1,33 @@
+import json
+import uuid
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
+
+
+class NucliaLoader(BaseLoader):
+ """Load from any file type using `Nuclia Understanding API`."""
+
+ def __init__(self, path: str, nuclia_tool: NucliaUnderstandingAPI):
+ self.nua = nuclia_tool
+ self.id = str(uuid.uuid4())
+ self.nua.run({"action": "push", "id": self.id, "path": path, "text": None})
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ data = self.nua.run(
+ {"action": "pull", "id": self.id, "path": None, "text": None}
+ )
+ if not data:
+ return []
+ obj = json.loads(data)
+ text = obj["extracted_text"][0]["body"]["text"]
+ print(text)
+ metadata = {
+ "file": obj["file_extracted_data"][0],
+ "metadata": obj["field_metadata"][0],
+ }
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/obs_directory.py b/libs/community/langchain_community/document_loaders/obs_directory.py
new file mode 100644
index 00000000000..24b67149788
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/obs_directory.py
@@ -0,0 +1,83 @@
+# coding:utf-8
+from typing import List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.obs_file import OBSFileLoader
+
+
+class OBSDirectoryLoader(BaseLoader):
+ """Load from `Huawei OBS directory`."""
+
+ def __init__(
+ self,
+ bucket: str,
+ endpoint: str,
+ config: Optional[dict] = None,
+ prefix: str = "",
+ ):
+ """Initialize the OBSDirectoryLoader with the specified settings.
+
+ Args:
+ bucket (str): The name of the OBS bucket to be used.
+ endpoint (str): The endpoint URL of your OBS bucket.
+ config (dict): The parameters for connecting to OBS, provided as a dictionary. The dictionary could have the following keys:
+ - "ak" (str, optional): Your OBS access key (required if `get_token_from_ecs` is False and bucket policy is not public read).
+ - "sk" (str, optional): Your OBS secret key (required if `get_token_from_ecs` is False and bucket policy is not public read).
+ - "token" (str, optional): Your security token (required if using temporary credentials).
+ - "get_token_from_ecs" (bool, optional): Whether to retrieve the security token from ECS. Defaults to False if not provided. If set to True, `ak`, `sk`, and `token` will be ignored.
+ prefix (str, optional): The prefix to be added to the OBS key. Defaults to "".
+
+ Note:
+ Before using this class, make sure you have registered with OBS and have the necessary credentials. The `ak`, `sk`, and `endpoint` values are mandatory unless `get_token_from_ecs` is True or the bucket policy is public read. `token` is required when using temporary credentials.
+ Example:
+ To create a new OBSDirectoryLoader:
+ ```
+ config = {
+ "ak": "your-access-key",
+ "sk": "your-secret-key"
+ }
+            directory_loader = OBSDirectoryLoader("your-bucket-name", "your-endpoint", config, "your-prefix")
+            ```
+ """ # noqa: E501
+ try:
+ from obs import ObsClient
+ except ImportError:
+ raise ImportError(
+ "Could not import esdk-obs-python python package. "
+ "Please install it with `pip install esdk-obs-python`."
+ )
+ if not config:
+ config = dict()
+ if config.get("get_token_from_ecs"):
+ self.client = ObsClient(server=endpoint, security_provider_policy="ECS")
+ else:
+ self.client = ObsClient(
+ access_key_id=config.get("ak"),
+ secret_access_key=config.get("sk"),
+ security_token=config.get("token"),
+ server=endpoint,
+ )
+
+ self.bucket = bucket
+ self.prefix = prefix
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ max_num = 1000
+ mark = None
+ docs = []
+ while True:
+ resp = self.client.listObjects(
+ self.bucket, prefix=self.prefix, marker=mark, max_keys=max_num
+ )
+ if resp.status < 300:
+ for content in resp.body.contents:
+ loader = OBSFileLoader(self.bucket, content.key, client=self.client)
+ docs.extend(loader.load())
+ if resp.body.is_truncated is True:
+ mark = resp.body.next_marker
+ else:
+ break
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/obs_file.py b/libs/community/langchain_community/document_loaders/obs_file.py
new file mode 100644
index 00000000000..d6add62c33b
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/obs_file.py
@@ -0,0 +1,105 @@
+# coding:utf-8
+
+import os
+import tempfile
+from typing import Any, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class OBSFileLoader(BaseLoader):
+ """Load from the `Huawei OBS file`."""
+
+ def __init__(
+ self,
+ bucket: str,
+ key: str,
+ client: Any = None,
+ endpoint: str = "",
+ config: Optional[dict] = None,
+ ) -> None:
+ """Initialize the OBSFileLoader with the specified settings.
+
+ Args:
+ bucket (str): The name of the OBS bucket to be used.
+ key (str): The name of the object in the OBS bucket.
+ client (ObsClient, optional): An instance of the ObsClient to connect to OBS.
+ endpoint (str, optional): The endpoint URL of your OBS bucket. This parameter is mandatory if `client` is not provided.
+ config (dict, optional): The parameters for connecting to OBS, provided as a dictionary. This parameter is ignored if `client` is provided. The dictionary could have the following keys:
+ - "ak" (str, optional): Your OBS access key (required if `get_token_from_ecs` is False and bucket policy is not public read).
+ - "sk" (str, optional): Your OBS secret key (required if `get_token_from_ecs` is False and bucket policy is not public read).
+ - "token" (str, optional): Your security token (required if using temporary credentials).
+ - "get_token_from_ecs" (bool, optional): Whether to retrieve the security token from ECS. Defaults to False if not provided. If set to True, `ak`, `sk`, and `token` will be ignored.
+
+ Raises:
+            ImportError: If the `esdk-obs-python` package is not installed.
+ TypeError: If the provided `client` is not an instance of ObsClient.
+ ValueError: If `client` is not provided, but `endpoint` is missing.
+
+ Note:
+ Before using this class, make sure you have registered with OBS and have the necessary credentials. The `ak`, `sk`, and `endpoint` values are mandatory unless `get_token_from_ecs` is True or the bucket policy is public read. `token` is required when using temporary credentials.
+
+ Example:
+ To create a new OBSFileLoader with a new client:
+ ```
+ config = {
+ "ak": "your-access-key",
+ "sk": "your-secret-key"
+ }
+            obs_loader = OBSFileLoader("your-bucket-name", "your-object-key", endpoint="your-endpoint-url", config=config)
+ ```
+
+ To create a new OBSFileLoader with an existing client:
+ ```
+ from obs import ObsClient
+
+ # Assuming you have an existing ObsClient object 'obs_client'
+ obs_loader = OBSFileLoader("your-bucket-name", "your-object-key", client=obs_client)
+ ```
+
+ To create a new OBSFileLoader without an existing client:
+ ```
+ obs_loader = OBSFileLoader("your-bucket-name", "your-object-key", endpoint="your-endpoint-url")
+ ```
+ """ # noqa: E501
+ try:
+ from obs import ObsClient
+ except ImportError:
+ raise ImportError(
+ "Could not import esdk-obs-python python package. "
+ "Please install it with `pip install esdk-obs-python`."
+ )
+ if not client:
+ if not endpoint:
+ raise ValueError("Either OBSClient or endpoint must be provided.")
+ if not config:
+ config = dict()
+ if config.get("get_token_from_ecs"):
+ client = ObsClient(server=endpoint, security_provider_policy="ECS")
+ else:
+ client = ObsClient(
+ access_key_id=config.get("ak"),
+ secret_access_key=config.get("sk"),
+ security_token=config.get("token"),
+ server=endpoint,
+ )
+ if not isinstance(client, ObsClient):
+ raise TypeError("Client must be ObsClient type")
+ self.client = client
+ self.bucket = bucket
+ self.key = key
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.bucket}/{self.key}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ # Download the file to a destination
+ self.client.downloadFile(
+ bucketName=self.bucket, objectKey=self.key, downloadFile=file_path
+ )
+ loader = UnstructuredFileLoader(file_path)
+ return loader.load()
diff --git a/libs/community/langchain_community/document_loaders/obsidian.py b/libs/community/langchain_community/document_loaders/obsidian.py
new file mode 100644
index 00000000000..e4b69341b67
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/obsidian.py
@@ -0,0 +1,168 @@
+import functools
+import logging
+import re
+from pathlib import Path
+from typing import Any, Dict, List
+
+import yaml
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class ObsidianLoader(BaseLoader):
+ """Load `Obsidian` files from directory."""
+
+ FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
+ TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
+ TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
+ DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
+ DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
+ DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE)
+
+ def __init__(
+ self, path: str, encoding: str = "UTF-8", collect_metadata: bool = True
+ ):
+ """Initialize with a path.
+
+ Args:
+ path: Path to the directory containing the Obsidian files.
+ encoding: Charset encoding, defaults to "UTF-8"
+ collect_metadata: Whether to collect metadata from the front matter.
+ Defaults to True.
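+
+        Example (an illustrative sketch; the vault path is a placeholder):
+            .. code-block:: python
+
+                from langchain_community.document_loaders import ObsidianLoader
+
+                loader = ObsidianLoader("/path/to/obsidian/vault")
+                docs = loader.load()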
+ """
+ self.file_path = path
+ self.encoding = encoding
+ self.collect_metadata = collect_metadata
+
+ def _replace_template_var(
+ self, placeholders: Dict[str, str], match: re.Match
+ ) -> str:
+ """Replace a template variable with a placeholder."""
+ placeholder = f"__TEMPLATE_VAR_{len(placeholders)}__"
+ placeholders[placeholder] = match.group(1)
+ return placeholder
+
+ def _restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any:
+ """Restore template variables replaced with placeholders to original values."""
+ if isinstance(obj, str):
+ for placeholder, value in placeholders.items():
+ obj = obj.replace(placeholder, f"{{{{{value}}}}}")
+ elif isinstance(obj, dict):
+ for key, value in obj.items():
+ obj[key] = self._restore_template_vars(value, placeholders)
+ elif isinstance(obj, list):
+ for i, item in enumerate(obj):
+ obj[i] = self._restore_template_vars(item, placeholders)
+ return obj
+
+ def _parse_front_matter(self, content: str) -> dict:
+ """Parse front matter metadata from the content and return it as a dict."""
+ if not self.collect_metadata:
+ return {}
+
+ match = self.FRONT_MATTER_REGEX.search(content)
+ if not match:
+ return {}
+
+ placeholders: Dict[str, str] = {}
+ replace_template_var = functools.partial(
+ self._replace_template_var, placeholders
+ )
+ front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub(
+ replace_template_var, match.group(1)
+ )
+
+ try:
+ front_matter = yaml.safe_load(front_matter_text)
+ front_matter = self._restore_template_vars(front_matter, placeholders)
+
+ # If tags are a string, split them into a list
+ if "tags" in front_matter and isinstance(front_matter["tags"], str):
+ front_matter["tags"] = front_matter["tags"].split(", ")
+
+ return front_matter
+ except yaml.parser.ParserError:
+ logger.warning("Encountered non-yaml frontmatter")
+ return {}
+
+ def _to_langchain_compatible_metadata(self, metadata: dict) -> dict:
+ """Convert a dictionary to a compatible with langchain."""
+ result = {}
+ for key, value in metadata.items():
+ if type(value) in {str, int, float}:
+ result[key] = value
+ else:
+ result[key] = str(value)
+ return result
+
+ def _parse_document_tags(self, content: str) -> set:
+ """Return a set of all tags in within the document."""
+ if not self.collect_metadata:
+ return set()
+
+ match = self.TAG_REGEX.findall(content)
+ if not match:
+ return set()
+
+ return {tag for tag in match}
+
+ def _parse_dataview_fields(self, content: str) -> dict:
+ """Parse obsidian dataview plugin fields from the content and return it
+ as a dict."""
+ if not self.collect_metadata:
+ return {}
+
+ return {
+ **{
+ match[0]: match[1]
+ for match in self.DATAVIEW_LINE_REGEX.findall(content)
+ },
+ **{
+ match[0]: match[1]
+ for match in self.DATAVIEW_INLINE_PAREN_REGEX.findall(content)
+ },
+ **{
+ match[0]: match[1]
+ for match in self.DATAVIEW_INLINE_BRACKET_REGEX.findall(content)
+ },
+ }
+
+ def _remove_front_matter(self, content: str) -> str:
+ """Remove front matter metadata from the given content."""
+ if not self.collect_metadata:
+ return content
+ return self.FRONT_MATTER_REGEX.sub("", content)
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ paths = list(Path(self.file_path).glob("**/*.md"))
+ docs = []
+ for path in paths:
+ with open(path, encoding=self.encoding) as f:
+ text = f.read()
+
+ front_matter = self._parse_front_matter(text)
+ tags = self._parse_document_tags(text)
+ dataview_fields = self._parse_dataview_fields(text)
+ text = self._remove_front_matter(text)
+ metadata = {
+ "source": str(path.name),
+ "path": str(path),
+ "created": path.stat().st_ctime,
+ "last_modified": path.stat().st_mtime,
+ "last_accessed": path.stat().st_atime,
+ **self._to_langchain_compatible_metadata(front_matter),
+ **dataview_fields,
+ }
+
+ if tags or front_matter.get("tags"):
+ metadata["tags"] = ",".join(
+ tags | set(front_matter.get("tags", []) or [])
+ )
+
+ docs.append(Document(page_content=text, metadata=metadata))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/odt.py b/libs/community/langchain_community/document_loaders/odt.py
new file mode 100644
index 00000000000..6d2cc3474ea
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/odt.py
@@ -0,0 +1,50 @@
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredODTLoader(UnstructuredFileLoader):
+ """Load `OpenOffice ODT` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredODTLoader
+
+ loader = UnstructuredODTLoader(
+ "example.odt", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-odt
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+
+ Args:
+ file_path: The path to the file to load.
+            mode: The mode to use when loading the file. Can be one of "single"
+                or "elements". Default is "single".
+ **unstructured_kwargs: Any kwargs to pass to the unstructured.
+ """
+ validate_unstructured_version(min_unstructured_version="0.6.3")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.odt import partition_odt
+
+ return partition_odt(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/onedrive.py b/libs/community/langchain_community/document_loaders/onedrive.py
new file mode 100644
index 00000000000..3d899d0941f
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/onedrive.py
@@ -0,0 +1,97 @@
+"""Loads data from OneDrive"""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.document_loaders.base_o365 import (
+ O365BaseLoader,
+ _FileType,
+)
+from langchain_community.document_loaders.parsers.registry import get_parser
+
+if TYPE_CHECKING:
+ from O365.drive import Drive, Folder
+
+logger = logging.getLogger(__name__)
+
+
+class OneDriveLoader(O365BaseLoader):
+ """Load from `Microsoft OneDrive`."""
+
+ drive_id: str = Field(...)
+ """ The ID of the OneDrive drive to load data from."""
+ folder_path: Optional[str] = None
+ """ The path to the folder to load data from."""
+ object_ids: Optional[List[str]] = None
+ """ The IDs of the objects to load data from."""
+
+ @property
+ def _file_types(self) -> Sequence[_FileType]:
+ """Return supported file types."""
+ return _FileType.DOC, _FileType.DOCX, _FileType.PDF
+
+ @property
+ def _scopes(self) -> List[str]:
+ """Return required scopes."""
+ return ["offline_access", "Files.Read.All"]
+
+ def _get_folder_from_path(self, drive: Drive) -> Union[Folder, Drive]:
+ """
+ Returns the folder or drive object located at the
+ specified path relative to the given drive.
+
+ Args:
+ drive (Drive): The root drive from which the folder path is relative.
+
+ Returns:
+ Union[Folder, Drive]: The folder or drive object
+ located at the specified path.
+
+ Raises:
+ FileNotFoundError: If the path does not exist.
+ """
+
+ subfolder_drive = drive
+ if self.folder_path is None:
+ return subfolder_drive
+
+ subfolders = [f for f in self.folder_path.split("/") if f != ""]
+ if len(subfolders) == 0:
+ return subfolder_drive
+
+ items = subfolder_drive.get_items()
+ for subfolder in subfolders:
+ try:
+ subfolder_drive = list(filter(lambda x: subfolder in x.name, items))[0]
+ items = subfolder_drive.get_items()
+ except (IndexError, AttributeError):
+                raise FileNotFoundError(
+                    "Path {} does not exist.".format(self.folder_path)
+                )
+ return subfolder_drive
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents lazily. Use this when working at a large scale."""
+ try:
+ from O365.drive import Drive
+ except ImportError:
+ raise ImportError(
+ "O365 package not found, please install it with `pip install o365`"
+ )
+ drive = self._auth().storage().get_drive(self.drive_id)
+ if not isinstance(drive, Drive):
+ raise ValueError(f"There isn't a Drive with id {self.drive_id}.")
+ blob_parser = get_parser("default")
+ if self.folder_path:
+ folder = self._get_folder_from_path(drive)
+ for blob in self._load_from_folder(folder):
+ yield from blob_parser.lazy_parse(blob)
+ if self.object_ids:
+ for blob in self._load_from_object_ids(drive, self.object_ids):
+ yield from blob_parser.lazy_parse(blob)
+
+ def load(self) -> List[Document]:
+ """Load all documents."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/onedrive_file.py b/libs/community/langchain_community/document_loaders/onedrive_file.py
new file mode 100644
index 00000000000..2678b04a057
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/onedrive_file.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import tempfile
+from typing import TYPE_CHECKING, List
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+if TYPE_CHECKING:
+ from O365.drive import File
+
+CHUNK_SIZE = 1024 * 1024 * 5
+
+
+class OneDriveFileLoader(BaseLoader, BaseModel):
+ """Load a file from `Microsoft OneDrive`."""
+
+ file: File = Field(...)
+ """The file to load."""
+
+ class Config:
+ arbitrary_types_allowed = True
+ """Allow arbitrary types. This is needed for the File type. Default is True.
+ See https://pydantic-docs.helpmanual.io/usage/types/#arbitrary-types-allowed"""
+
+ def load(self) -> List[Document]:
+ """Load Documents"""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.file.name}"
+ self.file.download(to_path=temp_dir, chunk_size=CHUNK_SIZE)
+ loader = UnstructuredFileLoader(file_path)
+ return loader.load()
diff --git a/libs/community/langchain_community/document_loaders/onenote.py b/libs/community/langchain_community/document_loaders/onenote.py
new file mode 100644
index 00000000000..2c96eb10e3d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/onenote.py
@@ -0,0 +1,222 @@
+"""Loads data from OneNote Notebooks"""
+from pathlib import Path
+from typing import Dict, Iterator, List, Optional
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import (
+ BaseModel,
+ BaseSettings,
+ Field,
+ FilePath,
+ SecretStr,
+)
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class _OneNoteGraphSettings(BaseSettings):
+ client_id: str = Field(..., env="MS_GRAPH_CLIENT_ID")
+ client_secret: SecretStr = Field(..., env="MS_GRAPH_CLIENT_SECRET")
+
+ class Config:
+ """Config for OneNoteGraphSettings."""
+
+ env_prefix = ""
+        case_sensitive = False
+ env_file = ".env"
+
+
+class OneNoteLoader(BaseLoader, BaseModel):
+ """Load pages from OneNote notebooks."""
+
+ settings: _OneNoteGraphSettings = Field(default_factory=_OneNoteGraphSettings)
+ """Settings for the Microsoft Graph API client."""
+ auth_with_token: bool = False
+ """Whether to authenticate with a token or not. Defaults to False."""
+ access_token: str = ""
+ """Personal access token"""
+ onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
+ """URL of Microsoft Graph API for OneNote"""
+ authority_url = "https://login.microsoftonline.com/consumers/"
+ """A URL that identifies a token authority"""
+ token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
+ """Path to the file where the access token is stored"""
+ notebook_name: Optional[str] = None
+ """Filter on notebook name"""
+ section_name: Optional[str] = None
+ """Filter on section name"""
+ page_title: Optional[str] = None
+ """Filter on section name"""
+ object_ids: Optional[List[str]] = None
+ """ The IDs of the objects to load data from."""
+
+ def lazy_load(self) -> Iterator[Document]:
+ """
+ Get pages from OneNote notebooks.
+
+ Returns:
+ A list of Documents with attributes:
+ - page_content
+ - metadata
+ - title
+ """
+ self._auth()
+
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ raise ImportError(
+ "beautifulsoup4 package not found, please install it with "
+ "`pip install bs4`"
+ )
+
+ if self.object_ids is not None:
+ for object_id in self.object_ids:
+ page_content_html = self._get_page_content(object_id)
+ soup = BeautifulSoup(page_content_html, "html.parser")
+ page_title = ""
+ title_tag = soup.title
+ if title_tag:
+ page_title = title_tag.get_text(strip=True)
+ page_content = soup.get_text(separator="\n", strip=True)
+ yield Document(
+ page_content=page_content, metadata={"title": page_title}
+ )
+ else:
+ request_url = self._url
+
+ while request_url != "":
+ response = requests.get(request_url, headers=self._headers, timeout=10)
+ response.raise_for_status()
+ pages = response.json()
+
+ for page in pages["value"]:
+ page_id = page["id"]
+ page_content_html = self._get_page_content(page_id)
+ soup = BeautifulSoup(page_content_html, "html.parser")
+ page_title = ""
+ title_tag = soup.title
+                    if title_tag:
+                        page_title = title_tag.get_text(strip=True)
+                    page_content = soup.get_text(separator="\n", strip=True)
+ yield Document(
+ page_content=page_content, metadata={"title": page_title}
+ )
+
+ if "@odata.nextLink" in pages:
+ request_url = pages["@odata.nextLink"]
+ else:
+ request_url = ""
+
+ def load(self) -> List[Document]:
+ """
+ Get pages from OneNote notebooks.
+
+ Returns:
+ A list of Documents with attributes:
+ - page_content
+ - metadata
+ - title
+ """
+ return list(self.lazy_load())
+
+ def _get_page_content(self, page_id: str) -> str:
+ """Get page content from OneNote API"""
+ request_url = self.onenote_api_base_url + f"/pages/{page_id}/content"
+ response = requests.get(request_url, headers=self._headers, timeout=10)
+ response.raise_for_status()
+ return response.text
+
+ @property
+ def _headers(self) -> Dict[str, str]:
+ """Return headers for requests to OneNote API"""
+ return {
+ "Authorization": f"Bearer {self.access_token}",
+ }
+
+ @property
+ def _scopes(self) -> List[str]:
+ """Return required scopes."""
+ return ["Notes.Read"]
+
+ def _auth(self) -> None:
+ """Authenticate with Microsoft Graph API"""
+ if self.access_token != "":
+ return
+
+ if self.auth_with_token:
+ with self.token_path.open("r") as token_file:
+ self.access_token = token_file.read()
+ else:
+ try:
+ from msal import ConfidentialClientApplication
+ except ImportError as e:
+ raise ImportError(
+ "MSAL package not found, please install it with `pip install msal`"
+ ) from e
+
+ client_instance = ConfidentialClientApplication(
+ client_id=self.settings.client_id,
+ client_credential=self.settings.client_secret.get_secret_value(),
+ authority=self.authority_url,
+ )
+
+ authorization_request_url = client_instance.get_authorization_request_url(
+ self._scopes
+ )
+ print("Visit the following url to give consent:")
+ print(authorization_request_url)
+ authorization_url = input("Paste the authenticated url here:\n")
+
+ authorization_code = authorization_url.split("code=")[1].split("&")[0]
+ access_token_json = client_instance.acquire_token_by_authorization_code(
+ code=authorization_code, scopes=self._scopes
+ )
+ self.access_token = access_token_json["access_token"]
+
+ try:
+ if not self.token_path.parent.exists():
+ self.token_path.parent.mkdir(parents=True)
+ except Exception as e:
+ raise Exception(
+ f"Could not create the folder {self.token_path.parent} "
+ + "to store the access token."
+ ) from e
+
+ with self.token_path.open("w") as token_file:
+ token_file.write(self.access_token)
+
+ @property
+ def _url(self) -> str:
+ """Create URL for getting page ids from the OneNoteApi API."""
+ query_params_list = []
+ filter_list = []
+ expand_list = []
+
+ query_params_list.append("$select=id")
+ if self.notebook_name is not None:
+ filter_list.append(
+ "parentNotebook/displayName%20eq%20"
+ + f"'{self.notebook_name.replace(' ', '%20')}'"
+ )
+ expand_list.append("parentNotebook")
+ if self.section_name is not None:
+ filter_list.append(
+ "parentSection/displayName%20eq%20"
+ + f"'{self.section_name.replace(' ', '%20')}'"
+ )
+ expand_list.append("parentSection")
+ if self.page_title is not None:
+ filter_list.append(
+ "title%20eq%20" + f"'{self.page_title.replace(' ', '%20')}'"
+ )
+
+ if len(expand_list) > 0:
+ query_params_list.append("$expand=" + ",".join(expand_list))
+ if len(filter_list) > 0:
+ query_params_list.append("$filter=" + "%20and%20".join(filter_list))
+
+ query_params = "&".join(query_params_list)
+ if query_params != "":
+ query_params = "?" + query_params
+ return f"{self.onenote_api_base_url}/pages{query_params}"
diff --git a/libs/community/langchain_community/document_loaders/open_city_data.py b/libs/community/langchain_community/document_loaders/open_city_data.py
new file mode 100644
index 00000000000..ad75421bd38
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/open_city_data.py
@@ -0,0 +1,44 @@
+from typing import Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class OpenCityDataLoader(BaseLoader):
+ """Load from `Open City`."""
+
+ def __init__(self, city_id: str, dataset_id: str, limit: int):
+ """Initialize with dataset_id.
+ Example: https://dev.socrata.com/foundry/data.sfgov.org/vw6y-z8j6
+ e.g., city_id = data.sfgov.org
+ e.g., dataset_id = vw6y-z8j6
+
+ Args:
+ city_id: The Open City city identifier.
+ dataset_id: The Open City dataset identifier.
+ limit: The maximum number of documents to load.
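+
+        Example (a minimal sketch using the dataset referenced above):
+            .. code-block:: python
+
+                from langchain_community.document_loaders import OpenCityDataLoader
+
+                loader = OpenCityDataLoader(
+                    city_id="data.sfgov.org",
+                    dataset_id="vw6y-z8j6",
+                    limit=100,
+                )
+                docs = loader.load()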
+ """
+ self.city_id = city_id
+ self.dataset_id = dataset_id
+ self.limit = limit
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records."""
+
+ from sodapy import Socrata
+
+ client = Socrata(self.city_id, None)
+ results = client.get(self.dataset_id, limit=self.limit)
+ for record in results:
+ yield Document(
+ page_content=str(record),
+ metadata={
+ "source": self.city_id + "_" + self.dataset_id,
+ },
+ )
+
+ def load(self) -> List[Document]:
+ """Load records."""
+
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/org_mode.py b/libs/community/langchain_community/document_loaders/org_mode.py
new file mode 100644
index 00000000000..e926e6f6289
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/org_mode.py
@@ -0,0 +1,50 @@
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredOrgModeLoader(UnstructuredFileLoader):
+ """Load `Org-Mode` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredOrgModeLoader
+
+ loader = UnstructuredOrgModeLoader(
+ "example.org", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-org
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+
+ Args:
+ file_path: The path to the file to load.
+            mode: The mode to use when loading the file. Can be one of "single"
+                or "elements". Default is "single".
+ **unstructured_kwargs: Any additional keyword arguments to pass
+ to the unstructured.
+ """
+ validate_unstructured_version(min_unstructured_version="0.7.9")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.org import partition_org
+
+ return partition_org(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/parsers/__init__.py b/libs/community/langchain_community/document_loaders/parsers/__init__.py
new file mode 100644
index 00000000000..c7bd6d73dff
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/__init__.py
@@ -0,0 +1,25 @@
+from langchain_community.document_loaders.parsers.audio import OpenAIWhisperParser
+from langchain_community.document_loaders.parsers.docai import DocAIParser
+from langchain_community.document_loaders.parsers.grobid import GrobidParser
+from langchain_community.document_loaders.parsers.html import BS4HTMLParser
+from langchain_community.document_loaders.parsers.language import LanguageParser
+from langchain_community.document_loaders.parsers.pdf import (
+ PDFMinerParser,
+ PDFPlumberParser,
+ PyMuPDFParser,
+ PyPDFium2Parser,
+ PyPDFParser,
+)
+
+__all__ = [
+ "BS4HTMLParser",
+ "DocAIParser",
+ "GrobidParser",
+ "LanguageParser",
+ "OpenAIWhisperParser",
+ "PDFMinerParser",
+ "PDFPlumberParser",
+ "PyMuPDFParser",
+ "PyPDFium2Parser",
+ "PyPDFParser",
+]
diff --git a/libs/community/langchain_community/document_loaders/parsers/audio.py b/libs/community/langchain_community/document_loaders/parsers/audio.py
new file mode 100644
index 00000000000..ab54c67ed37
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/audio.py
@@ -0,0 +1,310 @@
+import logging
+import time
+from typing import Dict, Iterator, Optional, Tuple
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+from langchain_community.utils.openai import is_openai_v1
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIWhisperParser(BaseBlobParser):
+ """Transcribe and parse audio files.
+ Audio transcription is with OpenAI Whisper model."""
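+
+    Example (an illustrative sketch; the audio file name is a placeholder and an
+    OpenAI API key is assumed to be available via ``OPENAI_API_KEY``):
+        .. code-block:: python
+
+            from langchain_community.document_loaders.blob_loaders import Blob
+            from langchain_community.document_loaders.parsers.audio import (
+                OpenAIWhisperParser,
+            )
+
+            parser = OpenAIWhisperParser()
+            docs = list(parser.lazy_parse(Blob.from_path("meeting.mp3")))
+    """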
+
+ def __init__(self, api_key: Optional[str] = None):
+ self.api_key = api_key
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+
+ import io
+
+ try:
+ import openai
+ except ImportError:
+ raise ImportError(
+ "openai package not found, please install it with "
+ "`pip install openai`"
+ )
+ try:
+ from pydub import AudioSegment
+ except ImportError:
+ raise ImportError(
+ "pydub package not found, please install it with " "`pip install pydub`"
+ )
+
+ if is_openai_v1():
+ # api_key optional, defaults to `os.environ['OPENAI_API_KEY']`
+ client = openai.OpenAI(api_key=self.api_key)
+ else:
+ # Set the API key if provided
+ if self.api_key:
+ openai.api_key = self.api_key
+
+ # Audio file from disk
+ audio = AudioSegment.from_file(blob.path)
+
+ # Define the duration of each chunk in minutes
+ # Need to meet 25MB size limit for Whisper API
+ chunk_duration = 20
+ chunk_duration_ms = chunk_duration * 60 * 1000
+
+ # Split the audio into chunk_duration_ms chunks
+ for split_number, i in enumerate(range(0, len(audio), chunk_duration_ms)):
+ # Audio chunk
+ chunk = audio[i : i + chunk_duration_ms]
+ file_obj = io.BytesIO(chunk.export(format="mp3").read())
+ if blob.source is not None:
+ file_obj.name = blob.source + f"_part_{split_number}.mp3"
+ else:
+ file_obj.name = f"part_{split_number}.mp3"
+
+ # Transcribe
+ print(f"Transcribing part {split_number+1}!")
+ attempts = 0
+ while attempts < 3:
+ try:
+ if is_openai_v1():
+ transcript = client.audio.transcriptions.create(
+ model="whisper-1", file=file_obj
+ )
+ else:
+ transcript = openai.Audio.transcribe("whisper-1", file_obj)
+ break
+ except Exception as e:
+ attempts += 1
+ print(f"Attempt {attempts} failed. Exception: {str(e)}")
+ time.sleep(5)
+ else:
+ print("Failed to transcribe after 3 attempts.")
+ continue
+
+ yield Document(
+ page_content=transcript.text,
+ metadata={"source": blob.source, "chunk": split_number},
+ )
+
+
+class OpenAIWhisperParserLocal(BaseBlobParser):
+ """Transcribe and parse audio files with OpenAI Whisper model.
+
+ Audio transcription with OpenAI Whisper model locally from transformers.
+
+ Parameters:
+ device - device to use
+ NOTE: By default uses the gpu if available,
+ if you want to use cpu, please set device = "cpu"
+ lang_model - whisper model to use, for example "openai/whisper-medium"
+ forced_decoder_ids - id states for decoder in multilanguage model,
+ usage example:
+ from transformers import WhisperProcessor
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+ forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
+ task="transcribe")
+ forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
+ task="translate")
+
+
+
+ """
+
+ def __init__(
+ self,
+ device: str = "0",
+ lang_model: Optional[str] = None,
+ forced_decoder_ids: Optional[Tuple[Dict]] = None,
+ ):
+ """Initialize the parser.
+
+ Args:
+ device: device to use.
+ lang_model: whisper model to use, for example "openai/whisper-medium".
+ Defaults to None.
+ forced_decoder_ids: id states for decoder in a multilanguage model.
+ Defaults to None.
+ """
+ try:
+ from transformers import pipeline
+ except ImportError:
+ raise ImportError(
+ "transformers package not found, please install it with "
+ "`pip install transformers`"
+ )
+ try:
+ import torch
+ except ImportError:
+ raise ImportError(
+ "torch package not found, please install it with " "`pip install torch`"
+ )
+
+        # set the device: CPU if requested, otherwise use the GPU when available
+ if device == "cpu":
+ self.device = "cpu"
+ if lang_model is not None:
+ self.lang_model = lang_model
+ print("WARNING! Model override. Using model: ", self.lang_model)
+ else:
+ # unless overridden, use the small base model on cpu
+ self.lang_model = "openai/whisper-base"
+ else:
+ if torch.cuda.is_available():
+ self.device = "cuda:0"
+ # check GPU memory and select automatically the model
+ mem = torch.cuda.get_device_properties(self.device).total_memory / (
+ 1024**2
+ )
+ if mem < 5000:
+ rec_model = "openai/whisper-base"
+ elif mem < 7000:
+ rec_model = "openai/whisper-small"
+ elif mem < 12000:
+ rec_model = "openai/whisper-medium"
+ else:
+ rec_model = "openai/whisper-large"
+
+ # check if model is overridden
+ if lang_model is not None:
+ self.lang_model = lang_model
+ print("WARNING! Model override. Might not fit in your GPU")
+ else:
+ self.lang_model = rec_model
+            else:
+                # no CUDA device available; fall back to CPU with the base model
+                # unless a model was explicitly requested
+                self.device = "cpu"
+                self.lang_model = (
+                    lang_model if lang_model is not None else "openai/whisper-base"
+                )
+
+ print("Using the following model: ", self.lang_model)
+
+ # load model for inference
+ self.pipe = pipeline(
+ "automatic-speech-recognition",
+ model=self.lang_model,
+ chunk_length_s=30,
+ device=self.device,
+ )
+ if forced_decoder_ids is not None:
+ try:
+ self.pipe.model.config.forced_decoder_ids = forced_decoder_ids
+ except Exception as exception_text:
+ logger.info(
+ "Unable to set forced_decoder_ids parameter for whisper model"
+ f"Text of exception: {exception_text}"
+ "Therefore whisper model will use default mode for decoder"
+ )
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+
+ import io
+
+ try:
+ from pydub import AudioSegment
+ except ImportError:
+ raise ImportError(
+ "pydub package not found, please install it with `pip install pydub`"
+ )
+
+ try:
+ import librosa
+ except ImportError:
+ raise ImportError(
+ "librosa package not found, please install it with "
+ "`pip install librosa`"
+ )
+
+ # Audio file from disk
+ audio = AudioSegment.from_file(blob.path)
+
+ file_obj = io.BytesIO(audio.export(format="mp3").read())
+
+ # Transcribe
+ print(f"Transcribing part {blob.path}!")
+
+ y, sr = librosa.load(file_obj, sr=16000)
+
+ prediction = self.pipe(y.copy(), batch_size=8)["text"]
+
+ yield Document(
+ page_content=prediction,
+ metadata={"source": blob.source},
+ )
+
+
+class YandexSTTParser(BaseBlobParser):
+ """Transcribe and parse audio files.
+ Audio transcription is with OpenAI Whisper model."""
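+
+    Example (a minimal sketch; the API key and ``call.ogg`` file are placeholders,
+    and the ``yandex-speechkit`` package is assumed to be installed):
+
+    .. code-block:: python
+
+        from langchain_community.document_loaders.blob_loaders import Blob
+
+        parser = YandexSTTParser(api_key="<speechkit-api-key>")
+        docs = list(parser.lazy_parse(Blob.from_path("call.ogg")))
+    """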
+
+ def __init__(
+ self,
+ *,
+ api_key: Optional[str] = None,
+ iam_token: Optional[str] = None,
+ model: str = "general",
+ language: str = "auto",
+ ):
+ """Initialize the parser.
+
+ Args:
+ api_key: API key for a service account
+ with the `ai.speechkit-stt.user` role.
+ iam_token: IAM token for a service account
+ with the `ai.speechkit-stt.user` role.
+ model: Recognition model name.
+ Defaults to general.
+ language: The language in ISO 639-1 format.
+ Defaults to automatic language recognition.
+ Either `api_key` or `iam_token` must be provided, but not both.
+ """
+ if (api_key is None) == (iam_token is None):
+ raise ValueError(
+ "Either 'api_key' or 'iam_token' must be provided, but not both."
+ )
+ self.api_key = api_key
+ self.iam_token = iam_token
+ self.model = model
+ self.language = language
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+
+ try:
+ from speechkit import configure_credentials, creds, model_repository
+ from speechkit.stt import AudioProcessingType
+ except ImportError:
+ raise ImportError(
+ "yandex-speechkit package not found, please install it with "
+ "`pip install yandex-speechkit`"
+ )
+ try:
+ from pydub import AudioSegment
+ except ImportError:
+ raise ImportError(
+ "pydub package not found, please install it with " "`pip install pydub`"
+ )
+
+ if self.api_key:
+ configure_credentials(
+ yandex_credentials=creds.YandexCredentials(api_key=self.api_key)
+ )
+ else:
+ configure_credentials(
+ yandex_credentials=creds.YandexCredentials(iam_token=self.iam_token)
+ )
+
+ audio = AudioSegment.from_file(blob.path)
+
+ model = model_repository.recognition_model()
+
+ model.model = self.model
+ model.language = self.language
+ model.audio_processing_type = AudioProcessingType.Full
+
+ result = model.transcribe(audio)
+
+ for res in result:
+ yield Document(
+ page_content=res.normalized_text,
+ metadata={"source": blob.source},
+ )
diff --git a/libs/community/langchain_community/document_loaders/parsers/docai.py b/libs/community/langchain_community/document_loaders/parsers/docai.py
new file mode 100644
index 00000000000..ca2bc8e5350
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/docai.py
@@ -0,0 +1,388 @@
+"""Module contains a PDF parser based on Document AI from Google Cloud.
+
+You need to install two libraries to use this parser:
+pip install google-cloud-documentai
+pip install google-cloud-documentai-toolbox
+"""
+import logging
+import re
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+from langchain_core.utils.iter import batch_iterate
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.api_core.operation import Operation
+ from google.cloud.documentai import DocumentProcessorServiceClient
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DocAIParsingResults:
+ """A dataclass to store Document AI parsing results."""
+
+ source_path: str
+ parsed_path: str
+
+
+class DocAIParser(BaseBlobParser):
+ """`Google Cloud Document AI` parser.
+
+    For a detailed explanation of Document AI, refer to the product documentation:
+    https://cloud.google.com/document-ai/docs/overview
+
+ def __init__(
+ self,
+ *,
+ client: Optional["DocumentProcessorServiceClient"] = None,
+ location: Optional[str] = None,
+ gcs_output_path: Optional[str] = None,
+ processor_name: Optional[str] = None,
+ ):
+ """Initializes the parser.
+
+ Args:
+ client: a DocumentProcessorServiceClient to use
+ location: a Google Cloud location where a Document AI processor is located
+ gcs_output_path: a path on Google Cloud Storage to store parsing results
+ processor_name: full resource name of a Document AI processor or processor
+ version
+
+ You should provide either a client or location (and then a client
+ would be instantiated).
+ """
+
+ if bool(client) == bool(location):
+ raise ValueError(
+ "You must specify either a client or a location to instantiate "
+ "a client."
+ )
+
+ pattern = r"projects\/[0-9]+\/locations\/[a-z\-0-9]+\/processors\/[a-z0-9]+"
+ if processor_name and not re.fullmatch(pattern, processor_name):
+ raise ValueError(
+ f"Processor name {processor_name} has the wrong format. If your "
+ "prediction endpoint looks like https://us-documentai.googleapis.com"
+ "/v1/projects/PROJECT_ID/locations/us/processors/PROCESSOR_ID:process,"
+ " use only projects/PROJECT_ID/locations/us/processors/PROCESSOR_ID "
+ "part."
+ )
+
+ self._gcs_output_path = gcs_output_path
+ self._processor_name = processor_name
+ if client:
+ self._client = client
+ else:
+ try:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud.documentai import DocumentProcessorServiceClient
+ except ImportError as exc:
+ raise ImportError(
+ "documentai package not found, please install it with"
+ " `pip install google-cloud-documentai`"
+ ) from exc
+ options = ClientOptions(
+ api_endpoint=f"{location}-documentai.googleapis.com"
+ )
+ self._client = DocumentProcessorServiceClient(
+ client_options=options,
+ client_info=get_client_info(module="document-ai"),
+ )
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Parses a blob lazily.
+
+ Args:
+            blob: a Blob to parse
+
+ This is a long-running operation. A recommended way is to batch
+ documents together and use the `batch_parse()` method.
+ """
+ yield from self.batch_parse([blob], gcs_output_path=self._gcs_output_path)
+
+ def online_process(
+ self,
+ blob: Blob,
+ enable_native_pdf_parsing: bool = True,
+ field_mask: Optional[str] = None,
+ page_range: Optional[List[int]] = None,
+ ) -> Iterator[Document]:
+ """Parses a blob lazily using online processing.
+
+ Args:
+ blob: a blob to parse.
+ enable_native_pdf_parsing: enable pdf embedded text extraction
+ field_mask: a comma-separated list of which fields to include in the
+ Document AI response.
+ suggested: "text,pages.pageNumber,pages.layout"
+ page_range: list of page numbers to parse. If `None`,
+ entire document will be parsed.
+ """
+ try:
+ from google.cloud import documentai
+ from google.cloud.documentai_v1.types import (
+ IndividualPageSelector,
+ OcrConfig,
+ ProcessOptions,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "documentai package not found, please install it with"
+ " `pip install google-cloud-documentai`"
+ ) from exc
+ try:
+ from google.cloud.documentai_toolbox.wrappers.page import _text_from_layout
+ except ImportError as exc:
+ raise ImportError(
+ "documentai_toolbox package not found, please install it with"
+ " `pip install google-cloud-documentai-toolbox`"
+ ) from exc
+ ocr_config = (
+ OcrConfig(enable_native_pdf_parsing=enable_native_pdf_parsing)
+ if enable_native_pdf_parsing
+ else None
+ )
+ individual_page_selector = (
+ IndividualPageSelector(pages=page_range) if page_range else None
+ )
+
+ response = self._client.process_document(
+ documentai.ProcessRequest(
+ name=self._processor_name,
+ gcs_document=documentai.GcsDocument(
+ gcs_uri=blob.path,
+ mime_type=blob.mimetype or "application/pdf",
+ ),
+ process_options=ProcessOptions(
+ ocr_config=ocr_config,
+ individual_page_selector=individual_page_selector,
+ ),
+ skip_human_review=True,
+ field_mask=field_mask,
+ )
+ )
+ yield from (
+ Document(
+ page_content=_text_from_layout(page.layout, response.document.text),
+ metadata={
+ "page": page.page_number,
+ "source": blob.path,
+ },
+ )
+ for page in response.document.pages
+ )
+
+ def batch_parse(
+ self,
+ blobs: Sequence[Blob],
+ gcs_output_path: Optional[str] = None,
+ timeout_sec: int = 3600,
+ check_in_interval_sec: int = 60,
+ ) -> Iterator[Document]:
+ """Parses a list of blobs lazily.
+
+ Args:
+ blobs: a list of blobs to parse.
+ gcs_output_path: a path on Google Cloud Storage to store parsing results.
+            timeout_sec: a timeout to wait for Document AI to complete, in seconds.
+            check_in_interval_sec: an interval between checks of whether the
+                parsing operations have completed, in seconds.
+
+        This is a long-running operation. A recommended way is to decouple
+        parsing from creating LangChain Documents:
+
+        >>> operations = parser.docai_parse(blobs, gcs_path)
+        >>> parser.is_running(operations)
+
+        You can get operation names and save them:
+
+        >>> names = [op.operation.name for op in operations]
+
+        And when all operations are finished, you can use their results:
+
+        >>> operations = parser.operations_from_names(operation_names)
+        >>> results = parser.get_results(operations)
+        >>> docs = parser.parse_from_results(results)
+        """
+ output_path = gcs_output_path or self._gcs_output_path
+ if not output_path:
+ raise ValueError(
+ "An output path on Google Cloud Storage should be provided."
+ )
+ operations = self.docai_parse(blobs, gcs_output_path=output_path)
+ operation_names = [op.operation.name for op in operations]
+ logger.debug(
+ "Started parsing with Document AI, submitted operations %s", operation_names
+ )
+ time_elapsed = 0
+ while self.is_running(operations):
+ time.sleep(check_in_interval_sec)
+ time_elapsed += check_in_interval_sec
+ if time_elapsed > timeout_sec:
+ raise TimeoutError(
+ "Timeout exceeded! Check operations " f"{operation_names} later!"
+ )
+ logger.debug(".")
+
+ results = self.get_results(operations=operations)
+ yield from self.parse_from_results(results)
+
+ def parse_from_results(
+ self, results: List[DocAIParsingResults]
+ ) -> Iterator[Document]:
+ try:
+ from google.cloud.documentai_toolbox.utilities.gcs_utilities import (
+ split_gcs_uri,
+ )
+ from google.cloud.documentai_toolbox.wrappers.document import _get_shards
+ from google.cloud.documentai_toolbox.wrappers.page import _text_from_layout
+ except ImportError as exc:
+ raise ImportError(
+ "documentai_toolbox package not found, please install it with"
+ " `pip install google-cloud-documentai-toolbox`"
+ ) from exc
+ for result in results:
+ gcs_bucket_name, gcs_prefix = split_gcs_uri(result.parsed_path)
+ shards = _get_shards(gcs_bucket_name, gcs_prefix)
+ yield from (
+ Document(
+ page_content=_text_from_layout(page.layout, shard.text),
+ metadata={"page": page.page_number, "source": result.source_path},
+ )
+ for shard in shards
+ for page in shard.pages
+ )
+
+ def operations_from_names(self, operation_names: List[str]) -> List["Operation"]:
+ """Initializes Long-Running Operations from their names."""
+ try:
+ from google.longrunning.operations_pb2 import (
+ GetOperationRequest, # type: ignore
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "long running operations package not found, please install it with"
+ " `pip install gapic-google-longrunning`"
+ ) from exc
+
+ return [
+ self._client.get_operation(request=GetOperationRequest(name=name))
+ for name in operation_names
+ ]
+
+ def is_running(self, operations: List["Operation"]) -> bool:
+ return any(not op.done() for op in operations)
+
+ def docai_parse(
+ self,
+ blobs: Sequence[Blob],
+ *,
+ gcs_output_path: Optional[str] = None,
+ processor_name: Optional[str] = None,
+ batch_size: int = 1000,
+ enable_native_pdf_parsing: bool = True,
+ field_mask: Optional[str] = None,
+ ) -> List["Operation"]:
+ """Runs Google Document AI PDF Batch Processing on a list of blobs.
+
+ Args:
+ blobs: a list of blobs to be parsed
+ gcs_output_path: a path (folder) on GCS to store results
+ processor_name: name of a Document AI processor.
+ batch_size: amount of documents per batch
+ enable_native_pdf_parsing: a config option for the parser
+ field_mask: a comma-separated list of which fields to include in the
+ Document AI response.
+ suggested: "text,pages.pageNumber,pages.layout"
+
+ Document AI has a 1000 file limit per batch, so batches larger than that need
+ to be split into multiple requests.
+ Batch processing is an async long-running operation
+ and results are stored in a output GCS bucket.
+ """
+ try:
+ from google.cloud import documentai
+ from google.cloud.documentai_v1.types import OcrConfig, ProcessOptions
+ except ImportError as exc:
+ raise ImportError(
+ "documentai package not found, please install it with"
+ " `pip install google-cloud-documentai`"
+ ) from exc
+
+ output_path = gcs_output_path or self._gcs_output_path
+ if output_path is None:
+ raise ValueError(
+ "An output path on Google Cloud Storage should be provided."
+ )
+ processor_name = processor_name or self._processor_name
+ if processor_name is None:
+ raise ValueError("A Document AI processor name should be provided.")
+
+ operations = []
+ for batch in batch_iterate(size=batch_size, iterable=blobs):
+ input_config = documentai.BatchDocumentsInputConfig(
+ gcs_documents=documentai.GcsDocuments(
+ documents=[
+ documentai.GcsDocument(
+ gcs_uri=blob.path,
+ mime_type=blob.mimetype or "application/pdf",
+ )
+ for blob in batch
+ ]
+ )
+ )
+
+ output_config = documentai.DocumentOutputConfig(
+ gcs_output_config=documentai.DocumentOutputConfig.GcsOutputConfig(
+ gcs_uri=output_path, field_mask=field_mask
+ )
+ )
+
+ process_options = (
+ ProcessOptions(
+ ocr_config=OcrConfig(
+ enable_native_pdf_parsing=enable_native_pdf_parsing
+ )
+ )
+ if enable_native_pdf_parsing
+ else None
+ )
+ operations.append(
+ self._client.batch_process_documents(
+ documentai.BatchProcessRequest(
+ name=processor_name,
+ input_documents=input_config,
+ document_output_config=output_config,
+ process_options=process_options,
+ skip_human_review=True,
+ )
+ )
+ )
+ return operations
+
+ def get_results(self, operations: List["Operation"]) -> List[DocAIParsingResults]:
+ try:
+ from google.cloud.documentai_v1 import BatchProcessMetadata
+ except ImportError as exc:
+ raise ImportError(
+ "documentai package not found, please install it with"
+ " `pip install google-cloud-documentai`"
+ ) from exc
+
+ return [
+ DocAIParsingResults(
+ source_path=status.input_gcs_source,
+ parsed_path=status.output_gcs_destination,
+ )
+ for op in operations
+ for status in (
+ op.metadata.individual_process_statuses
+ if isinstance(op.metadata, BatchProcessMetadata)
+ else BatchProcessMetadata.deserialize(
+ op.metadata.value
+ ).individual_process_statuses
+ )
+ ]
diff --git a/libs/community/langchain_community/document_loaders/parsers/generic.py b/libs/community/langchain_community/document_loaders/parsers/generic.py
new file mode 100644
index 00000000000..6b6b91b93ee
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/generic.py
@@ -0,0 +1,70 @@
+"""Code for generic / auxiliary parsers.
+
+This module contains some logic to help assemble more sophisticated parsers.
+"""
+from typing import Iterator, Mapping, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders.schema import Blob
+
+
+class MimeTypeBasedParser(BaseBlobParser):
+ """Parser that uses `mime`-types to parse a blob.
+
+ This parser is useful for simple pipelines where the mime-type is sufficient
+ to determine how to parse a blob.
+
+ To use, configure handlers based on mime-types and pass them to the initializer.
+
+ Example:
+
+ .. code-block:: python
+
+ from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
+
+ parser = MimeTypeBasedParser(
+ handlers={
+ "application/pdf": ...,
+ },
+ fallback_parser=...,
+ )
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ handlers: Mapping[str, BaseBlobParser],
+ *,
+ fallback_parser: Optional[BaseBlobParser] = None,
+ ) -> None:
+ """Define a parser that uses mime-types to determine how to parse a blob.
+
+ Args:
+            handlers: A mapping from mime-types to parsers that take a blob, parse
+                it and return documents.
+ fallback_parser: A fallback_parser parser to use if the mime-type is not
+ found in the handlers. If provided, this parser will be
+ used to parse blobs with all mime-types not found in
+ the handlers.
+ If not provided, a ValueError will be raised if the
+ mime-type is not found in the handlers.
+ """
+ self.handlers = handlers
+ self.fallback_parser = fallback_parser
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Load documents from a blob."""
+ mimetype = blob.mimetype
+
+ if mimetype is None:
+ raise ValueError(f"{blob} does not have a mimetype.")
+
+ if mimetype in self.handlers:
+ handler = self.handlers[mimetype]
+ yield from handler.lazy_parse(blob)
+ else:
+ if self.fallback_parser is not None:
+ yield from self.fallback_parser.lazy_parse(blob)
+ else:
+ raise ValueError(f"Unsupported mime type: {mimetype}")
diff --git a/libs/community/langchain_community/document_loaders/parsers/grobid.py b/libs/community/langchain_community/document_loaders/parsers/grobid.py
new file mode 100644
index 00000000000..8eb9974479c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/grobid.py
@@ -0,0 +1,148 @@
+import logging
+from typing import Dict, Iterator, List, Union
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+
+logger = logging.getLogger(__name__)
+
+
+class ServerUnavailableException(Exception):
+ """Exception raised when the Grobid server is unavailable."""
+
+ pass
+
+
+class GrobidParser(BaseBlobParser):
+ """Load article `PDF` files using `Grobid`."""
+
+ def __init__(
+ self,
+ segment_sentences: bool,
+ grobid_server: str = "http://localhost:8070/api/processFulltextDocument",
+ ) -> None:
+ self.segment_sentences = segment_sentences
+ self.grobid_server = grobid_server
+ try:
+ requests.get(grobid_server)
+ except requests.exceptions.RequestException:
+ logger.error(
+ "GROBID server does not appear up and running, \
+ please ensure Grobid is installed and the server is running"
+ )
+ raise ServerUnavailableException
+
+ def process_xml(
+ self, file_path: str, xml_data: str, segment_sentences: bool
+ ) -> Iterator[Document]:
+ """Process the XML file from Grobin."""
+
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ raise ImportError(
+ "`bs4` package not found, please install it with " "`pip install bs4`"
+ )
+ soup = BeautifulSoup(xml_data, "xml")
+ sections = soup.find_all("div")
+ title = soup.find_all("title")[0].text
+ chunks = []
+ for section in sections:
+ sect = section.find("head")
+ if sect is not None:
+ for i, paragraph in enumerate(section.find_all("p")):
+ chunk_bboxes = []
+ paragraph_text = []
+ for i, sentence in enumerate(paragraph.find_all("s")):
+ paragraph_text.append(sentence.text)
+ sbboxes = []
+ for bbox in sentence.get("coords").split(";"):
+ box = bbox.split(",")
+ sbboxes.append(
+ {
+ "page": box[0],
+ "x": box[1],
+ "y": box[2],
+ "h": box[3],
+ "w": box[4],
+ }
+ )
+ chunk_bboxes.append(sbboxes)
+ if segment_sentences is True:
+ fpage, lpage = sbboxes[0]["page"], sbboxes[-1]["page"]
+ sentence_dict = {
+ "text": sentence.text,
+ "para": str(i),
+ "bboxes": [sbboxes],
+ "section_title": sect.text,
+ "section_number": sect.get("n"),
+ "pages": (fpage, lpage),
+ }
+ chunks.append(sentence_dict)
+ if segment_sentences is not True:
+ fpage, lpage = (
+ chunk_bboxes[0][0]["page"],
+ chunk_bboxes[-1][-1]["page"],
+ )
+ paragraph_dict = {
+ "text": "".join(paragraph_text),
+ "para": str(i),
+ "bboxes": chunk_bboxes,
+ "section_title": sect.text,
+ "section_number": sect.get("n"),
+ "pages": (fpage, lpage),
+ }
+ chunks.append(paragraph_dict)
+
+ yield from [
+ Document(
+ page_content=chunk["text"],
+ metadata=dict(
+ {
+ "text": str(chunk["text"]),
+ "para": str(chunk["para"]),
+ "bboxes": str(chunk["bboxes"]),
+ "pages": str(chunk["pages"]),
+ "section_title": str(chunk["section_title"]),
+ "section_number": str(chunk["section_number"]),
+ "paper_title": str(title),
+ "file_path": str(file_path),
+ }
+ ),
+ )
+ for chunk in chunks
+ ]
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ file_path = blob.source
+ if file_path is None:
+ raise ValueError("blob.source cannot be None.")
+ pdf = open(file_path, "rb")
+ files = {"input": (file_path, pdf, "application/pdf", {"Expires": "0"})}
+ try:
+ data: Dict[str, Union[str, List[str]]] = {}
+ for param in ["generateIDs", "consolidateHeader", "segmentSentences"]:
+ data[param] = "1"
+ data["teiCoordinates"] = ["head", "s"]
+ files = files or {}
+ r = requests.request(
+ "POST",
+ self.grobid_server,
+ headers=None,
+ params=None,
+ files=files,
+ data=data,
+ timeout=60,
+ )
+ xml_data = r.text
+ except requests.exceptions.ReadTimeout:
+ logger.error("GROBID server timed out. Return None.")
+ xml_data = None
+
+ if xml_data is None:
+ return iter([])
+ else:
+ return self.process_xml(file_path, xml_data, self.segment_sentences)
diff --git a/libs/community/langchain_community/document_loaders/parsers/html/__init__.py b/libs/community/langchain_community/document_loaders/parsers/html/__init__.py
new file mode 100644
index 00000000000..f59e804b30f
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/html/__init__.py
@@ -0,0 +1,3 @@
+from langchain_community.document_loaders.parsers.html.bs4 import BS4HTMLParser
+
+__all__ = ["BS4HTMLParser"]
diff --git a/libs/community/langchain_community/document_loaders/parsers/html/bs4.py b/libs/community/langchain_community/document_loaders/parsers/html/bs4.py
new file mode 100644
index 00000000000..6863b6d106f
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/html/bs4.py
@@ -0,0 +1,54 @@
+"""Loader that uses bs4 to load HTML files, enriching metadata with page title."""
+
+import logging
+from typing import Any, Dict, Iterator, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+
+logger = logging.getLogger(__name__)
+
+
+class BS4HTMLParser(BaseBlobParser):
+ """Pparse HTML files using `Beautiful Soup`."""
+
+ def __init__(
+ self,
+ *,
+ features: str = "lxml",
+ get_text_separator: str = "",
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a bs4 based HTML parser."""
+ try:
+ import bs4 # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "beautifulsoup4 package not found, please install it with "
+ "`pip install beautifulsoup4`"
+ )
+
+ self.bs_kwargs = {"features": features, **kwargs}
+ self.get_text_separator = get_text_separator
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Load HTML document into document objects."""
+ from bs4 import BeautifulSoup
+
+ with blob.as_bytes_io() as f:
+ soup = BeautifulSoup(f, **self.bs_kwargs)
+
+ text = soup.get_text(self.get_text_separator)
+
+ if soup.title:
+ title = str(soup.title.string)
+ else:
+ title = ""
+
+ metadata: Dict[str, Union[str, None]] = {
+ "source": blob.source,
+ "title": title,
+ }
+ yield Document(page_content=text, metadata=metadata)
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/__init__.py b/libs/community/langchain_community/document_loaders/parsers/language/__init__.py
new file mode 100644
index 00000000000..e56cc143cfd
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/__init__.py
@@ -0,0 +1,5 @@
+from langchain_community.document_loaders.parsers.language.language_parser import (
+ LanguageParser,
+)
+
+__all__ = ["LanguageParser"]
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/cobol.py b/libs/community/langchain_community/document_loaders/parsers/language/cobol.py
new file mode 100644
index 00000000000..3b33ab6d956
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/cobol.py
@@ -0,0 +1,98 @@
+import re
+from typing import Callable, List
+
+from langchain_community.document_loaders.parsers.language.code_segmenter import (
+ CodeSegmenter,
+)
+
+
+class CobolSegmenter(CodeSegmenter):
+ """Code segmenter for `COBOL`."""
+
+ PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
+ DIVISION_PATTERN = re.compile(
+ r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
+ )
+ SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
+
+ def __init__(self, code: str):
+ super().__init__(code)
+ self.source_lines: List[str] = self.code.splitlines()
+
+ def is_valid(self) -> bool:
+ # Identify presence of any division to validate COBOL code
+ return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)
+
+ def _extract_code(self, start_idx: int, end_idx: int) -> str:
+ return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")
+
+ def _is_relevant_code(self, line: str) -> bool:
+ """Check if a line is part of the procedure division or a relevant section."""
+ if "PROCEDURE DIVISION" in line.upper():
+ return True
+ # Add additional conditions for relevant sections if needed
+ return False
+
+ def _process_lines(self, func: Callable) -> List[str]:
+ """A generic function to process COBOL lines based on provided func."""
+ elements: List[str] = []
+ start_idx = None
+ inside_relevant_section = False
+
+ for i, line in enumerate(self.source_lines):
+ if self._is_relevant_code(line):
+ inside_relevant_section = True
+
+ if inside_relevant_section and (
+ self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
+ or self.SECTION_PATTERN.match(line.strip())
+ ):
+ if start_idx is not None:
+ func(elements, start_idx, i)
+ start_idx = i
+
+ # Handle the last element if exists
+ if start_idx is not None:
+ func(elements, start_idx, len(self.source_lines))
+
+ return elements
+
+ def extract_functions_classes(self) -> List[str]:
+ def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
+ elements.append(self._extract_code(start_idx, end_idx))
+
+ return self._process_lines(extract_func)
+
+ def simplify_code(self) -> str:
+ simplified_lines: List[str] = []
+ inside_relevant_section = False
+ omitted_code_added = (
+ False # To track if "* OMITTED CODE *" has been added after the last header
+ )
+
+ for line in self.source_lines:
+ is_header = (
+ "PROCEDURE DIVISION" in line
+ or "DATA DIVISION" in line
+ or "IDENTIFICATION DIVISION" in line
+ or self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
+ or self.SECTION_PATTERN.match(line.strip())
+ )
+
+ if is_header:
+ inside_relevant_section = True
+ # Reset the flag since we're entering a new section/division or
+ # paragraph
+ omitted_code_added = False
+
+ if inside_relevant_section:
+ if is_header:
+ # Add header and reset the omitted code added flag
+ simplified_lines.append(line)
+ elif not omitted_code_added:
+ # Add omitted code comment only if it hasn't been added directly
+ # after the last header
+ simplified_lines.append("* OMITTED CODE *")
+ omitted_code_added = True
+
+ return "\n".join(simplified_lines)
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/code_segmenter.py b/libs/community/langchain_community/document_loaders/parsers/language/code_segmenter.py
new file mode 100644
index 00000000000..2efb2add448
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/code_segmenter.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class CodeSegmenter(ABC):
+ """Abstract class for the code segmenter."""
+
+ def __init__(self, code: str):
+ self.code = code
+
+ def is_valid(self) -> bool:
+ return True
+
+ @abstractmethod
+ def simplify_code(self) -> str:
+ raise NotImplementedError() # pragma: no cover
+
+ @abstractmethod
+ def extract_functions_classes(self) -> List[str]:
+ raise NotImplementedError() # pragma: no cover
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/javascript.py b/libs/community/langchain_community/document_loaders/parsers/language/javascript.py
new file mode 100644
index 00000000000..0f2fea68fa2
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/javascript.py
@@ -0,0 +1,69 @@
+from typing import Any, List
+
+from langchain_community.document_loaders.parsers.language.code_segmenter import (
+ CodeSegmenter,
+)
+
+
+class JavaScriptSegmenter(CodeSegmenter):
+ """Code segmenter for JavaScript."""
+
+ def __init__(self, code: str):
+ super().__init__(code)
+ self.source_lines = self.code.splitlines()
+
+ try:
+ import esprima # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import esprima Python package. "
+ "Please install it with `pip install esprima`."
+ )
+
+ def is_valid(self) -> bool:
+ import esprima
+
+ try:
+ esprima.parseScript(self.code)
+ return True
+ except esprima.Error:
+ return False
+
+ def _extract_code(self, node: Any) -> str:
+ start = node.loc.start.line - 1
+ end = node.loc.end.line
+ return "\n".join(self.source_lines[start:end])
+
+ def extract_functions_classes(self) -> List[str]:
+ import esprima
+
+ tree = esprima.parseScript(self.code, loc=True)
+ functions_classes = []
+
+ for node in tree.body:
+ if isinstance(
+ node,
+ (esprima.nodes.FunctionDeclaration, esprima.nodes.ClassDeclaration),
+ ):
+ functions_classes.append(self._extract_code(node))
+
+ return functions_classes
+
+ def simplify_code(self) -> str:
+ import esprima
+
+ tree = esprima.parseScript(self.code, loc=True)
+ simplified_lines = self.source_lines[:]
+
+ for node in tree.body:
+ if isinstance(
+ node,
+ (esprima.nodes.FunctionDeclaration, esprima.nodes.ClassDeclaration),
+ ):
+ start = node.loc.start.line - 1
+ simplified_lines[start] = f"// Code for: {simplified_lines[start]}"
+
+ for line_num in range(start + 1, node.loc.end.line):
+ simplified_lines[line_num] = None # type: ignore
+
+ return "\n".join(line for line in simplified_lines if line is not None)
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py
new file mode 100644
index 00000000000..c53dd493fde
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+from langchain_community.document_loaders.parsers.language.cobol import CobolSegmenter
+from langchain_community.document_loaders.parsers.language.javascript import (
+ JavaScriptSegmenter,
+)
+from langchain_community.document_loaders.parsers.language.python import PythonSegmenter
+
+if TYPE_CHECKING:
+ from langchain.text_splitter import Language
+
+try:
+ from langchain.text_splitter import Language
+
+ LANGUAGE_EXTENSIONS: Dict[str, str] = {
+ "py": Language.PYTHON,
+ "js": Language.JS,
+ "cobol": Language.COBOL,
+ }
+
+ LANGUAGE_SEGMENTERS: Dict[str, Any] = {
+ Language.PYTHON: PythonSegmenter,
+ Language.JS: JavaScriptSegmenter,
+ Language.COBOL: CobolSegmenter,
+ }
+except ImportError:
+ LANGUAGE_EXTENSIONS = {}
+ LANGUAGE_SEGMENTERS = {}
+
+
+class LanguageParser(BaseBlobParser):
+ """Parse using the respective programming language syntax.
+
+ Each top-level function and class in the code is loaded into separate documents.
+ Furthermore, an extra document is generated, containing the remaining top-level code
+ that excludes the already segmented functions and classes.
+
+ This approach can potentially improve the accuracy of QA models over source code.
+
+ Currently, the supported languages for code parsing are Python and JavaScript.
+
+ The language used for parsing can be configured, along with the minimum number of
+ lines required to activate the splitting based on syntax.
+
+ Examples:
+
+ .. code-block:: python
+
+            from langchain.text_splitter import Language
+ from langchain_community.document_loaders.generic import GenericLoader
+ from langchain_community.document_loaders.parsers import LanguageParser
+
+ loader = GenericLoader.from_filesystem(
+ "./code",
+ glob="**/*",
+ suffixes=[".py", ".js"],
+ parser=LanguageParser()
+ )
+ docs = loader.load()
+
+ Example instantiations to manually select the language:
+
+ .. code-block:: python
+
+ from langchain.text_splitter import Language
+
+ loader = GenericLoader.from_filesystem(
+ "./code",
+ glob="**/*",
+ suffixes=[".py"],
+ parser=LanguageParser(language=Language.PYTHON)
+ )
+
+ Example instantiations to set number of lines threshold:
+
+ .. code-block:: python
+
+ loader = GenericLoader.from_filesystem(
+ "./code",
+ glob="**/*",
+ suffixes=[".py"],
+ parser=LanguageParser(parser_threshold=200)
+ )
+ """
+
+ def __init__(self, language: Optional[Language] = None, parser_threshold: int = 0):
+ """
+        Language parser that splits code using the respective language syntax.
+
+ Args:
+ language: If None (default), it will try to infer language from source.
+ parser_threshold: Minimum lines needed to activate parsing (0 by default).
+ """
+ self.language = language
+ self.parser_threshold = parser_threshold
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ code = blob.as_string()
+
+ language = self.language or (
+ LANGUAGE_EXTENSIONS.get(blob.source.rsplit(".", 1)[-1])
+ if isinstance(blob.source, str)
+ else None
+ )
+
+ if language is None:
+ yield Document(
+ page_content=code,
+ metadata={
+ "source": blob.source,
+ },
+ )
+ return
+
+ if self.parser_threshold >= len(code.splitlines()):
+ yield Document(
+ page_content=code,
+ metadata={
+ "source": blob.source,
+ "language": language,
+ },
+ )
+ return
+
+ self.Segmenter = LANGUAGE_SEGMENTERS[language]
+ segmenter = self.Segmenter(blob.as_string())
+ if not segmenter.is_valid():
+ yield Document(
+ page_content=code,
+ metadata={
+ "source": blob.source,
+ },
+ )
+ return
+
+ for functions_classes in segmenter.extract_functions_classes():
+ yield Document(
+ page_content=functions_classes,
+ metadata={
+ "source": blob.source,
+ "content_type": "functions_classes",
+ "language": language,
+ },
+ )
+ yield Document(
+ page_content=segmenter.simplify_code(),
+ metadata={
+ "source": blob.source,
+ "content_type": "simplified_code",
+ "language": language,
+ },
+ )
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/python.py b/libs/community/langchain_community/document_loaders/parsers/language/python.py
new file mode 100644
index 00000000000..ca810946bd4
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/python.py
@@ -0,0 +1,51 @@
+import ast
+from typing import Any, List
+
+from langchain_community.document_loaders.parsers.language.code_segmenter import (
+ CodeSegmenter,
+)
+
+
+class PythonSegmenter(CodeSegmenter):
+ """Code segmenter for `Python`."""
+
+ def __init__(self, code: str):
+ super().__init__(code)
+ self.source_lines = self.code.splitlines()
+
+ def is_valid(self) -> bool:
+ try:
+ ast.parse(self.code)
+ return True
+ except SyntaxError:
+ return False
+
+ def _extract_code(self, node: Any) -> str:
+ start = node.lineno - 1
+ end = node.end_lineno
+ return "\n".join(self.source_lines[start:end])
+
+ def extract_functions_classes(self) -> List[str]:
+ tree = ast.parse(self.code)
+ functions_classes = []
+
+ for node in ast.iter_child_nodes(tree):
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+ functions_classes.append(self._extract_code(node))
+
+ return functions_classes
+
+ def simplify_code(self) -> str:
+ tree = ast.parse(self.code)
+ simplified_lines = self.source_lines[:]
+
+ for node in ast.iter_child_nodes(tree):
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+ start = node.lineno - 1
+ simplified_lines[start] = f"# Code for: {simplified_lines[start]}"
+
+ assert isinstance(node.end_lineno, int)
+ for line_num in range(start + 1, node.end_lineno):
+ simplified_lines[line_num] = None # type: ignore
+
+ return "\n".join(line for line in simplified_lines if line is not None)
diff --git a/libs/community/langchain_community/document_loaders/parsers/msword.py b/libs/community/langchain_community/document_loaders/parsers/msword.py
new file mode 100644
index 00000000000..f2a03cc37da
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/msword.py
@@ -0,0 +1,45 @@
+from typing import Iterator
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+
+
+class MsWordParser(BaseBlobParser):
+ """Parse the Microsoft Word documents from a blob."""
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Parse a Microsoft Word document into the Document iterator.
+
+ Args:
+ blob: The blob to parse.
+
+ Returns: An iterator of Documents.
+
+ """
+ try:
+ from unstructured.partition.doc import partition_doc
+ from unstructured.partition.docx import partition_docx
+ except ImportError as e:
+ raise ImportError(
+ "Could not import unstructured, please install with `pip install "
+ "unstructured`."
+ ) from e
+
+ mime_type_parser = {
+ "application/msword": partition_doc,
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document": (
+ partition_docx
+ ),
+ }
+ if blob.mimetype not in (
+ "application/msword",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ ):
+ raise ValueError("This blob type is not supported for this parser.")
+ with blob.as_bytes_io() as word_document:
+ elements = mime_type_parser[blob.mimetype](file=word_document)
+ text = "\n\n".join([str(el) for el in elements])
+ metadata = {"source": blob.source}
+ yield Document(page_content=text, metadata=metadata)
diff --git a/libs/community/langchain_community/document_loaders/parsers/pdf.py b/libs/community/langchain_community/document_loaders/parsers/pdf.py
new file mode 100644
index 00000000000..93ba53527b9
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/pdf.py
@@ -0,0 +1,573 @@
+"""Module contains common parsers for PDFs."""
+from __future__ import annotations
+
+import warnings
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterable,
+ Iterator,
+ Mapping,
+ Optional,
+ Sequence,
+ Union,
+)
+from urllib.parse import urlparse
+
+import numpy as np
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+
+if TYPE_CHECKING:
+ import fitz.fitz
+ import pdfminer.layout
+ import pdfplumber.page
+ import pypdf._page
+ import pypdfium2._helpers.page
+
+
+_PDF_FILTER_WITH_LOSS = ["DCTDecode", "DCT", "JPXDecode"]
+_PDF_FILTER_WITHOUT_LOSS = [
+ "LZWDecode",
+ "LZW",
+ "FlateDecode",
+ "Fl",
+ "ASCII85Decode",
+ "A85",
+ "ASCIIHexDecode",
+ "AHx",
+ "RunLengthDecode",
+ "RL",
+ "CCITTFaxDecode",
+ "CCF",
+ "JBIG2Decode",
+]
+
+
+def extract_from_images_with_rapidocr(
+ images: Sequence[Union[Iterable[np.ndarray], bytes]],
+) -> str:
+ """Extract text from images with RapidOCR.
+
+ Args:
+ images: Images to extract text from.
+
+ Returns:
+ Text extracted from images.
+
+ Raises:
+ ImportError: If `rapidocr-onnxruntime` package is not installed.
+ """
+ try:
+ from rapidocr_onnxruntime import RapidOCR
+ except ImportError:
+ raise ImportError(
+ "`rapidocr-onnxruntime` package not found, please install it with "
+ "`pip install rapidocr-onnxruntime`"
+ )
+ ocr = RapidOCR()
+ text = ""
+ for img in images:
+ result, _ = ocr(img)
+ if result:
+ result = [text[1] for text in result]
+ text += "\n".join(result)
+ return text
+
+
+class PyPDFParser(BaseBlobParser):
+ """Load `PDF` using `pypdf`"""
+
+ def __init__(
+ self, password: Optional[Union[str, bytes]] = None, extract_images: bool = False
+ ):
+ self.password = password
+ self.extract_images = extract_images
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+ import pypdf
+
+ with blob.as_bytes_io() as pdf_file_obj:
+ pdf_reader = pypdf.PdfReader(pdf_file_obj, password=self.password)
+ yield from [
+ Document(
+ page_content=page.extract_text()
+ + self._extract_images_from_page(page),
+ metadata={"source": blob.source, "page": page_number},
+ )
+ for page_number, page in enumerate(pdf_reader.pages)
+ ]
+
+ def _extract_images_from_page(self, page: pypdf._page.PageObject) -> str:
+ """Extract images from page and get the text with RapidOCR."""
+ if not self.extract_images or "/XObject" not in page["/Resources"].keys():
+ return ""
+
+ xObject = page["/Resources"]["/XObject"].get_object() # type: ignore
+ images = []
+ for obj in xObject:
+ if xObject[obj]["/Subtype"] == "/Image":
+ if xObject[obj]["/Filter"][1:] in _PDF_FILTER_WITHOUT_LOSS:
+ height, width = xObject[obj]["/Height"], xObject[obj]["/Width"]
+
+ images.append(
+ np.frombuffer(xObject[obj].get_data(), dtype=np.uint8).reshape(
+ height, width, -1
+ )
+ )
+ elif xObject[obj]["/Filter"][1:] in _PDF_FILTER_WITH_LOSS:
+ images.append(xObject[obj].get_data())
+ else:
+ warnings.warn("Unknown PDF Filter!")
+ return extract_from_images_with_rapidocr(images)
+
+
+class PDFMinerParser(BaseBlobParser):
+ """Parse `PDF` using `PDFMiner`."""
+
+ def __init__(self, extract_images: bool = False, *, concatenate_pages: bool = True):
+ """Initialize a parser based on PDFMiner.
+
+ Args:
+ extract_images: Whether to extract images from PDF.
+ concatenate_pages: If True, concatenate all PDF pages into one a single
+ document. Otherwise, return one document per page.
+ """
+ self.extract_images = extract_images
+ self.concatenate_pages = concatenate_pages
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+
+ if not self.extract_images:
+ from pdfminer.high_level import extract_text
+
+ with blob.as_bytes_io() as pdf_file_obj:
+ if self.concatenate_pages:
+ text = extract_text(pdf_file_obj)
+ metadata = {"source": blob.source}
+ yield Document(page_content=text, metadata=metadata)
+ else:
+ from pdfminer.pdfpage import PDFPage
+
+ pages = PDFPage.get_pages(pdf_file_obj)
+ for i, _ in enumerate(pages):
+ text = extract_text(pdf_file_obj, page_numbers=[i])
+ metadata = {"source": blob.source, "page": str(i)}
+ yield Document(page_content=text, metadata=metadata)
+ else:
+ import io
+
+ from pdfminer.converter import PDFPageAggregator, TextConverter
+ from pdfminer.layout import LAParams
+ from pdfminer.pdfinterp import PDFPageInterpreter, PDFResourceManager
+ from pdfminer.pdfpage import PDFPage
+
+ text_io = io.StringIO()
+ with blob.as_bytes_io() as pdf_file_obj:
+ pages = PDFPage.get_pages(pdf_file_obj)
+ rsrcmgr = PDFResourceManager()
+ device_for_text = TextConverter(rsrcmgr, text_io, laparams=LAParams())
+ device_for_image = PDFPageAggregator(rsrcmgr, laparams=LAParams())
+ interpreter_for_text = PDFPageInterpreter(rsrcmgr, device_for_text)
+ interpreter_for_image = PDFPageInterpreter(rsrcmgr, device_for_image)
+ for i, page in enumerate(pages):
+ interpreter_for_text.process_page(page)
+ interpreter_for_image.process_page(page)
+ content = text_io.getvalue() + self._extract_images_from_page(
+ device_for_image.get_result()
+ )
+ text_io.truncate(0)
+ text_io.seek(0)
+ metadata = {"source": blob.source, "page": str(i)}
+ yield Document(page_content=content, metadata=metadata)
+
+ def _extract_images_from_page(self, page: pdfminer.layout.LTPage) -> str:
+ """Extract images from page and get the text with RapidOCR."""
+ import pdfminer
+
+ def get_image(layout_object: Any) -> Any:
+ if isinstance(layout_object, pdfminer.layout.LTImage):
+ return layout_object
+ if isinstance(layout_object, pdfminer.layout.LTContainer):
+ for child in layout_object:
+ return get_image(child)
+ else:
+ return None
+
+ images = []
+
+ for img in list(filter(bool, map(get_image, page))):
+ if img.stream["Filter"].name in _PDF_FILTER_WITHOUT_LOSS:
+ images.append(
+ np.frombuffer(img.stream.get_data(), dtype=np.uint8).reshape(
+ img.stream["Height"], img.stream["Width"], -1
+ )
+ )
+ elif img.stream["Filter"].name in _PDF_FILTER_WITH_LOSS:
+ images.append(img.stream.get_data())
+ else:
+ warnings.warn("Unknown PDF Filter!")
+ return extract_from_images_with_rapidocr(images)
+
+
+class PyMuPDFParser(BaseBlobParser):
+ """Parse `PDF` using `PyMuPDF`."""
+
+ def __init__(
+ self,
+ text_kwargs: Optional[Mapping[str, Any]] = None,
+ extract_images: bool = False,
+ ) -> None:
+ """Initialize the parser.
+
+ Args:
+            text_kwargs: Keyword arguments to pass to ``fitz.Page.get_text()``.
+            extract_images: Whether to extract images from the PDF.
+        """
+ self.text_kwargs = text_kwargs or {}
+ self.extract_images = extract_images
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+ import fitz
+
+ with blob.as_bytes_io() as file_path:
+ if blob.data is None:
+ doc = fitz.open(file_path)
+ else:
+ doc = fitz.open(stream=file_path, filetype="pdf")
+
+ yield from [
+ Document(
+ page_content=page.get_text(**self.text_kwargs)
+ + self._extract_images_from_page(doc, page),
+ metadata=dict(
+ {
+ "source": blob.source,
+ "file_path": blob.source,
+ "page": page.number,
+ "total_pages": len(doc),
+ },
+ **{
+ k: doc.metadata[k]
+ for k in doc.metadata
+ if type(doc.metadata[k]) in [str, int]
+ },
+ ),
+ )
+ for page in doc
+ ]
+
+ def _extract_images_from_page(
+ self, doc: fitz.fitz.Document, page: fitz.fitz.Page
+ ) -> str:
+ """Extract images from page and get the text with RapidOCR."""
+ if not self.extract_images:
+ return ""
+ import fitz
+
+ img_list = page.get_images()
+ imgs = []
+ for img in img_list:
+ xref = img[0]
+ pix = fitz.Pixmap(doc, xref)
+ imgs.append(
+ np.frombuffer(pix.samples, dtype=np.uint8).reshape(
+ pix.height, pix.width, -1
+ )
+ )
+ return extract_from_images_with_rapidocr(imgs)
+
+
+class PyPDFium2Parser(BaseBlobParser):
+ """Parse `PDF` with `PyPDFium2`."""
+
+ def __init__(self, extract_images: bool = False) -> None:
+ """Initialize the parser."""
+ try:
+ import pypdfium2 # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "pypdfium2 package not found, please install it with"
+ " `pip install pypdfium2`"
+ )
+ self.extract_images = extract_images
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+ import pypdfium2
+
+        # pypdfium2 is really finicky with respect to closing things;
+        # if done incorrectly, it creates segfaults.
+ with blob.as_bytes_io() as file_path:
+ pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
+ try:
+ for page_number, page in enumerate(pdf_reader):
+ text_page = page.get_textpage()
+ content = text_page.get_text_range()
+ text_page.close()
+ content += "\n" + self._extract_images_from_page(page)
+ page.close()
+ metadata = {"source": blob.source, "page": page_number}
+ yield Document(page_content=content, metadata=metadata)
+ finally:
+ pdf_reader.close()
+
+ def _extract_images_from_page(self, page: pypdfium2._helpers.page.PdfPage) -> str:
+ """Extract images from page and get the text with RapidOCR."""
+ if not self.extract_images:
+ return ""
+
+ import pypdfium2.raw as pdfium_c
+
+ images = list(page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)))
+
+ images = list(map(lambda x: x.get_bitmap().to_numpy(), images))
+ return extract_from_images_with_rapidocr(images)
+
+
+class PDFPlumberParser(BaseBlobParser):
+ """Parse `PDF` with `PDFPlumber`."""
+
+ def __init__(
+ self,
+ text_kwargs: Optional[Mapping[str, Any]] = None,
+ dedupe: bool = False,
+ extract_images: bool = False,
+ ) -> None:
+ """Initialize the parser.
+
+ Args:
+ text_kwargs: Keyword arguments to pass to ``pdfplumber.Page.extract_text()``
+            dedupe: Avoid the error of duplicate characters if ``dedupe=True``.
+            extract_images: Whether to extract images from the PDF.
+        """
+ self.text_kwargs = text_kwargs or {}
+ self.dedupe = dedupe
+ self.extract_images = extract_images
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+ import pdfplumber
+
+ with blob.as_bytes_io() as file_path:
+ doc = pdfplumber.open(file_path) # open document
+
+ yield from [
+ Document(
+ page_content=self._process_page_content(page)
+ + "\n"
+ + self._extract_images_from_page(page),
+ metadata=dict(
+ {
+ "source": blob.source,
+ "file_path": blob.source,
+ "page": page.page_number - 1,
+ "total_pages": len(doc.pages),
+ },
+ **{
+ k: doc.metadata[k]
+ for k in doc.metadata
+ if type(doc.metadata[k]) in [str, int]
+ },
+ ),
+ )
+ for page in doc.pages
+ ]
+
+ def _process_page_content(self, page: pdfplumber.page.Page) -> str:
+ """Process the page content based on dedupe."""
+ if self.dedupe:
+ return page.dedupe_chars().extract_text(**self.text_kwargs)
+ return page.extract_text(**self.text_kwargs)
+
+ def _extract_images_from_page(self, page: pdfplumber.page.Page) -> str:
+ """Extract images from page and get the text with RapidOCR."""
+ if not self.extract_images:
+ return ""
+
+ images = []
+ for img in page.images:
+ if img["stream"]["Filter"].name in _PDF_FILTER_WITHOUT_LOSS:
+ images.append(
+ np.frombuffer(img["stream"].get_data(), dtype=np.uint8).reshape(
+ img["stream"]["Height"], img["stream"]["Width"], -1
+ )
+ )
+ elif img["stream"]["Filter"].name in _PDF_FILTER_WITH_LOSS:
+ images.append(img["stream"].get_data())
+ else:
+ warnings.warn("Unknown PDF Filter!")
+
+ return extract_from_images_with_rapidocr(images)
+
+
+class AmazonTextractPDFParser(BaseBlobParser):
+ """Send `PDF` files to `Amazon Textract` and parse them.
+
+ For parsing multi-page PDFs, they have to reside on S3.
+
+ The AmazonTextractPDFLoader calls the
+ [Amazon Textract Service](https://aws.amazon.com/textract/)
+ to convert PDFs into a Document structure.
+ Single and multi-page documents are supported with up to 3000 pages
+ and 512 MB of size.
+
+ For the call to be successful an AWS account is required,
+ similar to the
+ [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html)
+ requirements.
+
+ Besides the AWS configuration, it is very similar to the other PDF
+ loaders, while also supporting JPEG, PNG and TIFF and non-native
+ PDF formats.
+
+ ```python
+ from langchain_community.document_loaders import AmazonTextractPDFLoader
+ loader=AmazonTextractPDFLoader("example_data/alejandro_rosalez_sample-small.jpeg")
+ documents = loader.load()
+ ```
+
+ One feature is the linearization of the output.
+ When using the features LAYOUT, FORMS or TABLES together with Textract
+
+ ```python
+ from langchain_community.document_loaders import AmazonTextractPDFLoader
+ # you can mix and match each of the features
+ loader=AmazonTextractPDFLoader(
+ "example_data/alejandro_rosalez_sample-small.jpeg",
+ textract_features=["TABLES", "LAYOUT"])
+ documents = loader.load()
+ ```
+
+    it will generate output that formats the text in reading order and
+    tries to output the information in a tabular structure or
+    output the key/value pairs with a colon (key: value).
+ This helps most LLMs to achieve better accuracy when
+ processing these texts.
+
+ """
+
+ def __init__(
+ self,
+ textract_features: Optional[Sequence[int]] = None,
+ client: Optional[Any] = None,
+ ) -> None:
+ """Initializes the parser.
+
+ Args:
+ textract_features: Features to be used for extraction, each feature
+ should be passed as an int that conforms to the enum
+ `Textract_Features`, see `amazon-textract-caller` pkg
+ client: boto3 textract client
+ """
+
+ try:
+ import textractcaller as tc
+ import textractor.entities.document as textractor
+
+ self.tc = tc
+ self.textractor = textractor
+
+ if textract_features is not None:
+ self.textract_features = [
+ tc.Textract_Features(f) for f in textract_features
+ ]
+ else:
+ self.textract_features = []
+ except ImportError:
+ raise ImportError(
+ "Could not import amazon-textract-caller or "
+ "amazon-textract-textractor python package. Please install it "
+ "with `pip install amazon-textract-caller` & "
+ "`pip install amazon-textract-textractor`."
+ )
+
+ if not client:
+ try:
+ import boto3
+
+ self.boto3_textract_client = boto3.client("textract")
+ except ImportError:
+ raise ImportError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ else:
+ self.boto3_textract_client = client
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Iterates over the Blob pages and returns an Iterator with a Document
+ for each page, like the other parsers If multi-page document, blob.path
+ has to be set to the S3 URI and for single page docs
+ the blob.data is taken
+ """
+
+ url_parse_result = urlparse(str(blob.path)) if blob.path else None
+ # Either call with S3 path (multi-page) or with bytes (single-page)
+ if (
+ url_parse_result
+ and url_parse_result.scheme == "s3"
+ and url_parse_result.netloc
+ ):
+ textract_response_json = self.tc.call_textract(
+ input_document=str(blob.path),
+ features=self.textract_features,
+ boto3_textract_client=self.boto3_textract_client,
+ )
+ else:
+ textract_response_json = self.tc.call_textract(
+ input_document=blob.as_bytes(),
+ features=self.textract_features,
+ call_mode=self.tc.Textract_Call_Mode.FORCE_SYNC,
+ boto3_textract_client=self.boto3_textract_client,
+ )
+
+ document = self.textractor.Document.open(textract_response_json)
+
+ linearizer_config = self.textractor.TextLinearizationConfig(
+ hide_figure_layout=True,
+ title_prefix="# ",
+ section_header_prefix="## ",
+ list_element_prefix="*",
+ )
+ for idx, page in enumerate(document.pages):
+ yield Document(
+ page_content=page.get_text(config=linearizer_config),
+ metadata={"source": blob.source, "page": idx + 1},
+ )
+
+
+class DocumentIntelligenceParser(BaseBlobParser):
+ """Loads a PDF with Azure Document Intelligence
+ (formerly Form Recognizer) and chunks at character level."""
+
+ def __init__(self, client: Any, model: str):
+ self.client = client
+ self.model = model
+
+ def _generate_docs(self, blob: Blob, result: Any) -> Iterator[Document]:
+ for p in result.pages:
+ content = " ".join([line.content for line in p.lines])
+
+ d = Document(
+ page_content=content,
+ metadata={
+ "source": blob.source,
+ "page": p.page_number,
+ },
+ )
+ yield d
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+
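+ # A brief note on the flow below (assuming the Azure Document Analysis client
+ # API): begin_analyze_document starts a long-running analysis operation and
+ # poller.result() blocks until the service returns the parsed result.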
+ with blob.as_bytes_io() as file_obj:
+ poller = self.client.begin_analyze_document(self.model, file_obj)
+ result = poller.result()
+
+ docs = self._generate_docs(blob, result)
+
+ yield from docs
diff --git a/libs/community/langchain_community/document_loaders/parsers/registry.py b/libs/community/langchain_community/document_loaders/parsers/registry.py
new file mode 100644
index 00000000000..7756862238a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/registry.py
@@ -0,0 +1,35 @@
+"""Module includes a registry of default parser configurations."""
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
+from langchain_community.document_loaders.parsers.msword import MsWordParser
+from langchain_community.document_loaders.parsers.pdf import PyMuPDFParser
+from langchain_community.document_loaders.parsers.txt import TextParser
+
+
+def _get_default_parser() -> BaseBlobParser:
+ """Get default mime-type based parser."""
+ return MimeTypeBasedParser(
+ handlers={
+ "application/pdf": PyMuPDFParser(),
+ "text/plain": TextParser(),
+ "application/msword": MsWordParser(),
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document": (
+ MsWordParser()
+ ),
+ },
+ fallback_parser=None,
+ )
+
+
+_REGISTRY = {
+ "default": _get_default_parser,
+}
+
+# PUBLIC API
+
+
+def get_parser(parser_name: str) -> BaseBlobParser:
+ """Get a parser by parser name."""
+ if parser_name not in _REGISTRY:
+ raise ValueError(f"Unknown parser combination: {parser_name}")
+ return _REGISTRY[parser_name]()
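+
+
+# A minimal usage sketch (illustrative only; "example.pdf" is an assumed local
+# file): the default parser dispatches on the blob's mime type, so a PDF is
+# routed to PyMuPDFParser and a .docx file to MsWordParser.
+#
+#   from langchain_community.document_loaders.blob_loaders import Blob
+#
+#   parser = get_parser("default")
+#   docs = parser.parse(Blob.from_path("example.pdf"))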
diff --git a/libs/community/langchain_community/document_loaders/parsers/txt.py b/libs/community/langchain_community/document_loaders/parsers/txt.py
new file mode 100644
index 00000000000..abdfed8de5a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/txt.py
@@ -0,0 +1,15 @@
+"""Module for parsing text files.."""
+from typing import Iterator
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+
+
+class TextParser(BaseBlobParser):
+ """Parser for text blobs."""
+
+ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+ """Lazily parse the blob."""
+ yield Document(page_content=blob.as_string(), metadata={"source": blob.source})
diff --git a/libs/community/langchain_community/document_loaders/pdf.py b/libs/community/langchain_community/document_loaders/pdf.py
new file mode 100644
index 00000000000..dc488834a6b
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/pdf.py
@@ -0,0 +1,750 @@
+import json
+import logging
+import os
+import tempfile
+import time
+from abc import ABC
+from io import StringIO
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Union
+from urllib.parse import urlparse
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.blob_loaders import Blob
+from langchain_community.document_loaders.parsers.pdf import (
+ AmazonTextractPDFParser,
+ DocumentIntelligenceParser,
+ PDFMinerParser,
+ PDFPlumberParser,
+ PyMuPDFParser,
+ PyPDFium2Parser,
+ PyPDFParser,
+)
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+logger = logging.getLogger(__file__)
+
+
+class UnstructuredPDFLoader(UnstructuredFileLoader):
+ """Load `PDF` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredPDFLoader
+
+ loader = UnstructuredPDFLoader(
+ "example.pdf", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-pdf
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.pdf import partition_pdf
+
+ return partition_pdf(filename=self.file_path, **self.unstructured_kwargs)
+
+
+class BasePDFLoader(BaseLoader, ABC):
+ """Base Loader class for `PDF` files.
+
+ If the file is a web path, it will download it to a temporary file, use it, then
+ clean up the temporary file after completion.
+ """
+
+ def __init__(self, file_path: str, *, headers: Optional[Dict] = None):
+ """Initialize with a file path.
+
+ Args:
+ file_path: Either a local, S3 or web path to a PDF file.
+ headers: Headers to use for GET request to download a file from a web path.
+ """
+ self.file_path = file_path
+ self.web_path = None
+ self.headers = headers
+ if "~" in self.file_path:
+ self.file_path = os.path.expanduser(self.file_path)
+
+ # If the file is a web path or S3, download it to a temporary file, and use that
+ if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
+ self.temp_dir = tempfile.TemporaryDirectory()
+ _, suffix = os.path.splitext(self.file_path)
+ temp_pdf = os.path.join(self.temp_dir.name, f"tmp{suffix}")
+ self.web_path = self.file_path
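+ # S3 URIs are not downloaded here; they are kept as-is so that downstream
+ # consumers (e.g. Amazon Textract) can read the object directly from S3.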
+ if not self._is_s3_url(self.file_path):
+ r = requests.get(self.file_path, headers=self.headers)
+ if r.status_code != 200:
+ raise ValueError(
+ "Check the url of your file; returned status code %s"
+ % r.status_code
+ )
+
+ with open(temp_pdf, mode="wb") as f:
+ f.write(r.content)
+ self.file_path = str(temp_pdf)
+ elif not os.path.isfile(self.file_path):
+ raise ValueError("File path %s is not a valid file or url" % self.file_path)
+
+ def __del__(self) -> None:
+ if hasattr(self, "temp_dir"):
+ self.temp_dir.cleanup()
+
+ @staticmethod
+ def _is_valid_url(url: str) -> bool:
+ """Check if the url is valid."""
+ parsed = urlparse(url)
+ return bool(parsed.netloc) and bool(parsed.scheme)
+
+ @staticmethod
+ def _is_s3_url(url: str) -> bool:
+ """check if the url is S3"""
+ try:
+ result = urlparse(url)
+ if result.scheme == "s3" and result.netloc:
+ return True
+ return False
+ except ValueError:
+ return False
+
+ @property
+ def source(self) -> str:
+ return self.web_path if self.web_path is not None else self.file_path
+
+
+class OnlinePDFLoader(BasePDFLoader):
+ """Load online `PDF`."""
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ loader = UnstructuredPDFLoader(str(self.file_path))
+ return loader.load()
+
+
+class PyPDFLoader(BasePDFLoader):
+ """Load PDF using pypdf into list of documents.
+
+ Loader chunks by page and stores page numbers in metadata.
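+
+ A minimal usage sketch (assuming a local file named "example.pdf" exists):
+
+ .. code-block:: python
+
+     from langchain_community.document_loaders import PyPDFLoader
+
+     loader = PyPDFLoader("example.pdf")
+     pages = loader.load()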
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ password: Optional[Union[str, bytes]] = None,
+ headers: Optional[Dict] = None,
+ extract_images: bool = False,
+ ) -> None:
+ """Initialize with a file path."""
+ try:
+ import pypdf # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "pypdf package not found, please install it with " "`pip install pypdf`"
+ )
+ super().__init__(file_path, headers=headers)
+ self.parser = PyPDFParser(password=password, extract_images=extract_images)
+
+ def load(self) -> List[Document]:
+ """Load given path as pages."""
+ return list(self.lazy_load())
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazy load given path as pages."""
+ if self.web_path:
+ blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ yield from self.parser.parse(blob)
+
+
+class PyPDFium2Loader(BasePDFLoader):
+ """Load `PDF` using `pypdfium2` and chunks at character level."""
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ headers: Optional[Dict] = None,
+ extract_images: bool = False,
+ ):
+ """Initialize with a file path."""
+ super().__init__(file_path, headers=headers)
+ self.parser = PyPDFium2Parser(extract_images=extract_images)
+
+ def load(self) -> List[Document]:
+ """Load given path as pages."""
+ return list(self.lazy_load())
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazy load given path as pages."""
+ if self.web_path:
+ blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ yield from self.parser.parse(blob)
+
+
+class PyPDFDirectoryLoader(BaseLoader):
+ """Load a directory with `PDF` files using `pypdf` and chunks at character level.
+
+ Loader also stores page numbers in metadata.
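+
+ A minimal usage sketch (assuming a directory "example_data/" containing PDFs):
+
+ .. code-block:: python
+
+     from langchain_community.document_loaders import PyPDFDirectoryLoader
+
+     loader = PyPDFDirectoryLoader("example_data/")
+     docs = loader.load()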
+ """
+
+ def __init__(
+ self,
+ path: str,
+ glob: str = "**/[!.]*.pdf",
+ silent_errors: bool = False,
+ load_hidden: bool = False,
+ recursive: bool = False,
+ extract_images: bool = False,
+ ):
+ self.path = path
+ self.glob = glob
+ self.load_hidden = load_hidden
+ self.recursive = recursive
+ self.silent_errors = silent_errors
+ self.extract_images = extract_images
+
+ @staticmethod
+ def _is_visible(path: Path) -> bool:
+ return not any(part.startswith(".") for part in path.parts)
+
+ def load(self) -> List[Document]:
+ p = Path(self.path)
+ docs = []
+ items = p.rglob(self.glob) if self.recursive else p.glob(self.glob)
+ for i in items:
+ if i.is_file():
+ if self._is_visible(i.relative_to(p)) or self.load_hidden:
+ try:
+ loader = PyPDFLoader(str(i), extract_images=self.extract_images)
+ sub_docs = loader.load()
+ for doc in sub_docs:
+ doc.metadata["source"] = str(i)
+ docs.extend(sub_docs)
+ except Exception as e:
+ if self.silent_errors:
+ logger.warning(e)
+ else:
+ raise e
+ return docs
+
+
+class PDFMinerLoader(BasePDFLoader):
+ """Load `PDF` files using `PDFMiner`."""
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ headers: Optional[Dict] = None,
+ extract_images: bool = False,
+ concatenate_pages: bool = True,
+ ) -> None:
+ """Initialize with file path.
+
+ Args:
+ extract_images: Whether to extract images from PDF.
+ concatenate_pages: If True, concatenate all PDF pages into a single
+ document. Otherwise, return one document per page.
+ """
+ try:
+ from pdfminer.high_level import extract_text # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "`pdfminer` package not found, please install it with "
+ "`pip install pdfminer.six`"
+ )
+
+ super().__init__(file_path, headers=headers)
+ self.parser = PDFMinerParser(
+ extract_images=extract_images, concatenate_pages=concatenate_pages
+ )
+
+ def load(self) -> List[Document]:
+ """Eagerly load the content."""
+ return list(self.lazy_load())
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazily load documents."""
+ if self.web_path:
+ blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ yield from self.parser.parse(blob)
+
+
+class PDFMinerPDFasHTMLLoader(BasePDFLoader):
+ """Load `PDF` files as HTML content using `PDFMiner`."""
+
+ def __init__(self, file_path: str, *, headers: Optional[Dict] = None):
+ """Initialize with a file path."""
+ try:
+ from pdfminer.high_level import extract_text_to_fp # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "`pdfminer` package not found, please install it with "
+ "`pip install pdfminer.six`"
+ )
+
+ super().__init__(file_path, headers=headers)
+
+ def load(self) -> List[Document]:
+ """Load file."""
+ from pdfminer.high_level import extract_text_to_fp
+ from pdfminer.layout import LAParams
+ from pdfminer.utils import open_filename
+
+ output_string = StringIO()
+ with open_filename(self.file_path, "rb") as fp:
+ extract_text_to_fp(
+ fp, # type: ignore[arg-type]
+ output_string,
+ codec="",
+ laparams=LAParams(),
+ output_type="html",
+ )
+ metadata = {
+ "source": self.file_path if self.web_path is None else self.web_path
+ }
+ return [Document(page_content=output_string.getvalue(), metadata=metadata)]
+
+
+class PyMuPDFLoader(BasePDFLoader):
+ """Load `PDF` files using `PyMuPDF`."""
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ headers: Optional[Dict] = None,
+ extract_images: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize with a file path."""
+ try:
+ import fitz # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "`PyMuPDF` package not found, please install it with "
+ "`pip install pymupdf`"
+ )
+ super().__init__(file_path, headers=headers)
+ self.extract_images = extract_images
+ self.text_kwargs = kwargs
+
+ def load(self, **kwargs: Any) -> List[Document]:
+ """Load file."""
+ if kwargs:
+ logger.warning(
+ f"Received runtime arguments {kwargs}. Passing runtime args to `load`"
+ f" is deprecated. Please pass arguments during initialization instead."
+ )
+
+ text_kwargs = {**self.text_kwargs, **kwargs}
+ parser = PyMuPDFParser(
+ text_kwargs=text_kwargs, extract_images=self.extract_images
+ )
+ if self.web_path:
+ blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ return parser.parse(blob)
+
+
+# MathpixPDFLoader implementation taken largely from Daniel Gross's:
+# https://gist.github.com/danielgross/3ab4104e14faccc12b49200843adab21
+class MathpixPDFLoader(BasePDFLoader):
+ """Load `PDF` files using `Mathpix` service."""
+
+ def __init__(
+ self,
+ file_path: str,
+ processed_file_format: str = "md",
+ max_wait_time_seconds: int = 500,
+ should_clean_pdf: bool = False,
+ extra_request_data: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize with a file path.
+
+ Args:
+ file_path: a file for loading.
+ processed_file_format: a format of the processed file. Default is "md".
+ max_wait_time_seconds: a maximum time to wait for the response from
+ the server. Default is 500.
+ should_clean_pdf: a flag to clean the PDF file. Default is False.
+ extra_request_data: Additional request data.
+ **kwargs: additional keyword arguments.
+ """
+ self.mathpix_api_key = get_from_dict_or_env(
+ kwargs, "mathpix_api_key", "MATHPIX_API_KEY"
+ )
+ self.mathpix_api_id = get_from_dict_or_env(
+ kwargs, "mathpix_api_id", "MATHPIX_API_ID"
+ )
+
+ # The base class isn't expecting these and doesn't collect **kwargs
+ kwargs.pop("mathpix_api_key", None)
+ kwargs.pop("mathpix_api_id", None)
+
+ super().__init__(file_path, **kwargs)
+ self.processed_file_format = processed_file_format
+ self.extra_request_data = (
+ extra_request_data if extra_request_data is not None else {}
+ )
+ self.max_wait_time_seconds = max_wait_time_seconds
+ self.should_clean_pdf = should_clean_pdf
+
+ @property
+ def _mathpix_headers(self) -> Dict[str, str]:
+ return {"app_id": self.mathpix_api_id, "app_key": self.mathpix_api_key}
+
+ @property
+ def url(self) -> str:
+ return "https://api.mathpix.com/v3/pdf"
+
+ @property
+ def data(self) -> dict:
+ options = {
+ "conversion_formats": {self.processed_file_format: True},
+ **self.extra_request_data,
+ }
+ return {"options_json": json.dumps(options)}
+
+ def send_pdf(self) -> str:
+ with open(self.file_path, "rb") as f:
+ files = {"file": f}
+ response = requests.post(
+ self.url, headers=self._mathpix_headers, files=files, data=self.data
+ )
+ response_data = response.json()
+ if "error" in response_data:
+ raise ValueError(f"Mathpix request failed: {response_data['error']}")
+ if "pdf_id" in response_data:
+ pdf_id = response_data["pdf_id"]
+ return pdf_id
+ else:
+ raise ValueError("Unable to send PDF to Mathpix.")
+
+ def wait_for_processing(self, pdf_id: str) -> None:
+ """Wait for processing to complete.
+
+ Args:
+ pdf_id: a PDF id.
+
+ Returns: None
+ """
+ url = self.url + "/" + pdf_id
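+ # Poll the conversion status endpoint roughly every 5 seconds, up to
+ # max_wait_time_seconds, raising TimeoutError if processing never completes.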
+ for _ in range(0, self.max_wait_time_seconds, 5):
+ response = requests.get(url, headers=self._mathpix_headers)
+ response_data = response.json()
+
+ # This indicates an error with the request (e.g. auth problems)
+ error = response_data.get("error", None)
+
+ if error is not None:
+ raise ValueError(f"Unable to retrieve PDF from Mathpix: {error}")
+
+ status = response_data.get("status", None)
+
+ if status == "completed":
+ return
+ elif status == "error":
+ # This indicates an error with the PDF processing
+ raise ValueError("Unable to retrieve PDF from Mathpix")
+ else:
+ print(f"Status: {status}, waiting for processing to complete")
+ time.sleep(5)
+ raise TimeoutError
+
+ def get_processed_pdf(self, pdf_id: str) -> str:
+ self.wait_for_processing(pdf_id)
+ url = f"{self.url}/{pdf_id}.{self.processed_file_format}"
+ response = requests.get(url, headers=self._mathpix_headers)
+ return response.content.decode("utf-8")
+
+ def clean_pdf(self, contents: str) -> str:
+ """Clean the PDF file.
+
+ Args:
+ contents: the PDF file contents.
+
+ Returns:
+ The cleaned PDF file contents.
+ """
+ contents = "\n".join(
+ [line for line in contents.split("\n") if not line.startswith("![]")]
+ )
+ # replace \section{Title} with # Title
+ contents = contents.replace("\\section{", "# ").replace("}", "")
+ # replace the "\" slash that Mathpix adds to escape $, %, (, etc.
+ contents = (
+ contents.replace(r"\$", "$")
+ .replace(r"\%", "%")
+ .replace(r"\(", "(")
+ .replace(r"\)", ")")
+ )
+ return contents
+
+ def load(self) -> List[Document]:
+ pdf_id = self.send_pdf()
+ contents = self.get_processed_pdf(pdf_id)
+ if self.should_clean_pdf:
+ contents = self.clean_pdf(contents)
+ metadata = {"source": self.source, "file_path": self.source}
+ return [Document(page_content=contents, metadata=metadata)]
+
+
+class PDFPlumberLoader(BasePDFLoader):
+ """Load `PDF` files using `pdfplumber`."""
+
+ def __init__(
+ self,
+ file_path: str,
+ text_kwargs: Optional[Mapping[str, Any]] = None,
+ dedupe: bool = False,
+ headers: Optional[Dict] = None,
+ extract_images: bool = False,
+ ) -> None:
+ """Initialize with a file path."""
+ try:
+ import pdfplumber # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "pdfplumber package not found, please install it with "
+ "`pip install pdfplumber`"
+ )
+
+ super().__init__(file_path, headers=headers)
+ self.text_kwargs = text_kwargs or {}
+ self.dedupe = dedupe
+ self.extract_images = extract_images
+
+ def load(self) -> List[Document]:
+ """Load file."""
+
+ parser = PDFPlumberParser(
+ text_kwargs=self.text_kwargs,
+ dedupe=self.dedupe,
+ extract_images=self.extract_images,
+ )
+ if self.web_path:
+ blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ return parser.parse(blob)
+
+
+class AmazonTextractPDFLoader(BasePDFLoader):
+ """Load `PDF` files from a local file system, HTTP or S3.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+ access the Amazon Textract service.
+
+ Example:
+ .. code-block:: python
+ from langchain_community.document_loaders import AmazonTextractPDFLoader
+ loader = AmazonTextractPDFLoader(
+ file_path="s3://pdfs/myfile.pdf"
+ )
+ document = loader.load()
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ textract_features: Optional[Sequence[str]] = None,
+ client: Optional[Any] = None,
+ credentials_profile_name: Optional[str] = None,
+ region_name: Optional[str] = None,
+ endpoint_url: Optional[str] = None,
+ headers: Optional[Dict] = None,
+ ) -> None:
+ """Initialize the loader.
+
+ Args:
+ file_path: A file, url or s3 path for input file
+ textract_features: Features to be used for extraction, each feature
+ should be passed as a str that conforms to the enum
+ `Textract_Features`, see `amazon-textract-caller` pkg
+ client: boto3 textract client (Optional)
+ credentials_profile_name: AWS profile name, if not default (Optional)
+ region_name: AWS region, eg us-east-1 (Optional)
+ endpoint_url: endpoint url for the textract service (Optional)
+
+ """
+ super().__init__(file_path, headers=headers)
+
+ try:
+ import textractcaller as tc # noqa: F401
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import amazon-textract-caller python package. "
+ "Please install it with `pip install amazon-textract-caller`."
+ )
+ if textract_features:
+ features = [tc.Textract_Features[x] for x in textract_features]
+ else:
+ features = []
+
+ if credentials_profile_name or region_name or endpoint_url:
+ try:
+ import boto3
+
+ if credentials_profile_name is not None:
+ session = boto3.Session(profile_name=credentials_profile_name)
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {}
+ if region_name:
+ client_params["region_name"] = region_name
+ if endpoint_url:
+ client_params["endpoint_url"] = endpoint_url
+
+ client = session.client("textract", **client_params)
+
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+ self.parser = AmazonTextractPDFParser(textract_features=features, client=client)
+
+ def load(self) -> List[Document]:
+ """Load given path as pages."""
+ return list(self.lazy_load())
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazy load documents"""
+ # the self.file_path is local, but the blob has to include
+ # the S3 location if the file originated from S3 for multi-page documents
+ # raises ValueError when multi-page and not on S3"""
+
+ if self.web_path and self._is_s3_url(self.web_path):
+ blob = Blob(path=self.web_path)
+ else:
+ blob = Blob.from_path(self.file_path)
+ if AmazonTextractPDFLoader._get_number_of_pages(blob) > 1:
+ raise ValueError(
+ f"the file {blob.path} is a multi-page document, "
+ "but not stored on S3. "
+ "Textract requires multi-page documents to be on S3."
+ )
+
+ yield from self.parser.parse(blob)
+
+ @staticmethod
+ def _get_number_of_pages(blob: Blob) -> int:
+ try:
+ import pypdf
+ from PIL import Image, ImageSequence
+
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import pypdf or Pilloe python package. "
+ "Please install it with `pip install pypdf Pillow`."
+ )
+ if blob.mimetype == "application/pdf":
+ with blob.as_bytes_io() as input_pdf_file:
+ pdf_reader = pypdf.PdfReader(input_pdf_file)
+ return len(pdf_reader.pages)
+ elif blob.mimetype == "image/tiff":
+ num_pages = 0
+ img = Image.open(blob.as_bytes())
+ for _, _ in enumerate(ImageSequence.Iterator(img)):
+ num_pages += 1
+ return num_pages
+ elif blob.mimetype in ["image/png", "image/jpeg"]:
+ return 1
+ else:
+ raise ValueError(f"unsupported mime type: {blob.mimetype}")
+
+
+class DocumentIntelligenceLoader(BasePDFLoader):
+ """Loads a PDF with Azure Document Intelligence"""
+
+ def __init__(
+ self,
+ file_path: str,
+ client: Any,
+ model: str = "prebuilt-document",
+ headers: Optional[Dict] = None,
+ ) -> None:
+ """
+ Initialize the object for file processing with Azure Document Intelligence
+ (formerly Form Recognizer).
+
+ This constructor initializes a DocumentIntelligenceParser object to be used
+ for parsing files using the Azure Document Intelligence API. The load method
+ generates a Document node including metadata (source blob and page number)
+ for each page.
+
+ Parameters:
+ -----------
+ file_path : str
+ The path to the file that needs to be parsed.
+ client: Any
+ A DocumentAnalysisClient to perform the analysis of the blob
+ model : str
+ The model name or ID to be used for form recognition in Azure.
+
+ Examples:
+ ---------
+ >>> obj = DocumentIntelligenceLoader(
+ ... file_path="path/to/file",
+ ... client=client,
+ ... model="prebuilt-document"
+ ... )
+ """
+
+ self.parser = DocumentIntelligenceParser(client=client, model=model)
+ super().__init__(file_path, headers=headers)
+
+ def load(self) -> List[Document]:
+ """Load given path as pages."""
+ return list(self.lazy_load())
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazy load given path as pages."""
+ blob = Blob.from_path(self.file_path)
+ yield from self.parser.parse(blob)
diff --git a/libs/community/langchain_community/document_loaders/polars_dataframe.py b/libs/community/langchain_community/document_loaders/polars_dataframe.py
new file mode 100644
index 00000000000..bb523df647f
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/polars_dataframe.py
@@ -0,0 +1,33 @@
+from typing import Any, Iterator
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.dataframe import BaseDataFrameLoader
+
+
+class PolarsDataFrameLoader(BaseDataFrameLoader):
+ """Load `Polars` DataFrame."""
+
+ def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Args:
+ data_frame: Polars DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
+ """
+ import polars as pl
+
+ if not isinstance(data_frame, pl.DataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}"
+ )
+ super().__init__(data_frame, page_content_column=page_content_column)
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records from dataframe."""
+
+ for row in self.data_frame.iter_rows(named=True):
+ text = row[self.page_content_column]
+ row.pop(self.page_content_column)
+ yield Document(page_content=text, metadata=row)
diff --git a/libs/community/langchain_community/document_loaders/powerpoint.py b/libs/community/langchain_community/document_loaders/powerpoint.py
new file mode 100644
index 00000000000..6f559881551
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/powerpoint.py
@@ -0,0 +1,64 @@
+import os
+from typing import List
+
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class UnstructuredPowerPointLoader(UnstructuredFileLoader):
+ """Load `Microsoft PowerPoint` files using `Unstructured`.
+
+ Works with both .ppt and .pptx files.
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredPowerPointLoader
+
+ loader = UnstructuredPowerPointLoader(
+ "example.pptx", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-pptx
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.__version__ import __version__ as __unstructured_version__
+ from unstructured.file_utils.filetype import FileType, detect_filetype
+
+ unstructured_version = tuple(
+ [int(x) for x in __unstructured_version__.split(".")]
+ )
+ # NOTE(MthwRobinson) - magic will raise an import error if the libmagic
+ # system dependency isn't installed. If it's not installed, we'll just
+ # check the file extension
+ try:
+ import magic # noqa: F401
+
+ is_ppt = detect_filetype(self.file_path) == FileType.PPT
+ except ImportError:
+ _, extension = os.path.splitext(str(self.file_path))
+ is_ppt = extension == ".ppt"
+
+ if is_ppt and unstructured_version < (0, 4, 11):
+ raise ValueError(
+ f"You are on unstructured version {__unstructured_version__}. "
+ "Partitioning .ppt files is only supported in unstructured>=0.4.11. "
+ "Please upgrade the unstructured package and try again."
+ )
+
+ if is_ppt:
+ from unstructured.partition.ppt import partition_ppt
+
+ return partition_ppt(filename=self.file_path, **self.unstructured_kwargs)
+ else:
+ from unstructured.partition.pptx import partition_pptx
+
+ return partition_pptx(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/psychic.py b/libs/community/langchain_community/document_loaders/psychic.py
new file mode 100644
index 00000000000..aa73be801d3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/psychic.py
@@ -0,0 +1,44 @@
+from typing import List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class PsychicLoader(BaseLoader):
+ """Load from `Psychic.dev`."""
+
+ def __init__(
+ self, api_key: str, account_id: str, connector_id: Optional[str] = None
+ ):
+ """Initialize with API key, connector id, and account id.
+
+ Args:
+ api_key: The Psychic API key.
+ account_id: The Psychic account id.
+ connector_id: The Psychic connector id.
+ """
+
+ try:
+ from psychicapi import ConnectorId, Psychic # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "`psychicapi` package not found, please run `pip install psychicapi`"
+ )
+ self.psychic = Psychic(secret_key=api_key)
+ self.connector_id = ConnectorId(connector_id)
+ self.account_id = account_id
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+
+ psychic_docs = self.psychic.get_documents(
+ connector_id=self.connector_id, account_id=self.account_id
+ )
+ return [
+ Document(
+ page_content=doc["content"],
+ metadata={"title": doc["title"], "source": doc["uri"]},
+ )
+ for doc in psychic_docs.documents
+ ]
diff --git a/libs/community/langchain_community/document_loaders/pubmed.py b/libs/community/langchain_community/document_loaders/pubmed.py
new file mode 100644
index 00000000000..de47787d7bb
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/pubmed.py
@@ -0,0 +1,40 @@
+from typing import Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.pubmed import PubMedAPIWrapper
+
+
+class PubMedLoader(BaseLoader):
+ """Load from the `PubMed` biomedical library.
+
+ Attributes:
+ query: The query to be passed to the PubMed API.
+ load_max_docs: The maximum number of documents to load.
+ """
+
+ def __init__(
+ self,
+ query: str,
+ load_max_docs: Optional[int] = 3,
+ ):
+ """Initialize the PubMedLoader.
+
+ Args:
+ query: The query to be passed to the PubMed API.
+ load_max_docs: The maximum number of documents to load.
+ Defaults to 3.
+ """
+ self.query = query
+ self.load_max_docs = load_max_docs
+ self._client = PubMedAPIWrapper(
+ top_k_results=load_max_docs,
+ )
+
+ def load(self) -> List[Document]:
+ return list(self._client.lazy_load_docs(self.query))
+
+ def lazy_load(self) -> Iterator[Document]:
+ for doc in self._client.lazy_load_docs(self.query):
+ yield doc
diff --git a/libs/community/langchain_community/document_loaders/pyspark_dataframe.py b/libs/community/langchain_community/document_loaders/pyspark_dataframe.py
new file mode 100644
index 00000000000..410b7d07afc
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/pyspark_dataframe.py
@@ -0,0 +1,92 @@
+import itertools
+import logging
+import sys
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__file__)
+
+if TYPE_CHECKING:
+ from pyspark.sql import SparkSession
+
+
+class PySparkDataFrameLoader(BaseLoader):
+ """Load `PySpark` DataFrames."""
+
+ def __init__(
+ self,
+ spark_session: Optional["SparkSession"] = None,
+ df: Optional[Any] = None,
+ page_content_column: str = "text",
+ fraction_of_memory: float = 0.1,
+ ):
+ """Initialize with a Spark DataFrame object.
+
+ Args:
+ spark_session: The SparkSession object.
+ df: The Spark DataFrame object.
+ page_content_column: The name of the column containing the page content.
+ Defaults to "text".
+ fraction_of_memory: The fraction of memory to use. Defaults to 0.1.
+ """
+ try:
+ from pyspark.sql import DataFrame, SparkSession
+ except ImportError:
+ raise ImportError(
+ "pyspark is not installed. "
+ "Please install it with `pip install pyspark`"
+ )
+
+ self.spark = (
+ spark_session if spark_session else SparkSession.builder.getOrCreate()
+ )
+
+ if not isinstance(df, DataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
+ )
+ self.df = df
+ self.page_content_column = page_content_column
+ self.fraction_of_memory = fraction_of_memory
+ self.num_rows, self.max_num_rows = self.get_num_rows()
+ self.rdd_df = self.df.rdd.map(list)
+ self.column_names = self.df.columns
+
+ def get_num_rows(self) -> Tuple[int, int]:
+ """Gets the number of "feasible" rows for the DataFrame"""
+ try:
+ import psutil
+ except ImportError as e:
+ raise ImportError(
+ "psutil not installed. Please install it with `pip install psutil`."
+ ) from e
+ row = self.df.limit(1).collect()[0]
+ estimated_row_size = sys.getsizeof(row)
+ mem_info = psutil.virtual_memory()
+ available_memory = mem_info.available
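+ # Rough heuristic: treat the size of one collected Row as the per-row cost
+ # and cap the row count at the configured fraction of available memory.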
+ max_num_rows = int(
+ (available_memory / estimated_row_size) * self.fraction_of_memory
+ )
+ return min(max_num_rows, self.df.count()), max_num_rows
+
+ def lazy_load(self) -> Iterator[Document]:
+ """A lazy loader for document content."""
+ for row in self.rdd_df.toLocalIterator():
+ metadata = {self.column_names[i]: row[i] for i in range(len(row))}
+ text = metadata[self.page_content_column]
+ metadata.pop(self.page_content_column)
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load from the dataframe."""
+ if self.df.count() > self.max_num_rows:
+ logger.warning(
+ f"The number of DataFrame rows is {self.df.count()}, "
+ f"but we will only include the amount "
+ f"of rows that can reasonably fit in memory: {self.num_rows}."
+ )
+ lazy_load_iterator = self.lazy_load()
+ return list(itertools.islice(lazy_load_iterator, self.num_rows))
diff --git a/libs/community/langchain_community/document_loaders/python.py b/libs/community/langchain_community/document_loaders/python.py
new file mode 100644
index 00000000000..9afbbd30f78
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/python.py
@@ -0,0 +1,17 @@
+import tokenize
+
+from langchain_community.document_loaders.text import TextLoader
+
+
+class PythonLoader(TextLoader):
+ """Load `Python` files, respecting any non-default encoding if specified."""
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path.
+
+ Args:
+ file_path: The path to the file to load.
+ """
+ with open(file_path, "rb") as f:
+ encoding, _ = tokenize.detect_encoding(f.readline)
+ super().__init__(file_path=file_path, encoding=encoding)
diff --git a/libs/community/langchain_community/document_loaders/quip.py b/libs/community/langchain_community/document_loaders/quip.py
new file mode 100644
index 00000000000..0d9c4474fb0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/quip.py
@@ -0,0 +1,233 @@
+import logging
+import re
+import xml.etree.cElementTree
+import xml.sax.saxutils
+from io import BytesIO
+from typing import List, Optional, Sequence
+from xml.etree.ElementTree import ElementTree
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+_MAXIMUM_TITLE_LENGTH = 64
+
+
+class QuipLoader(BaseLoader):
+ """Load `Quip` pages.
+
+ Port of https://github.com/quip/quip-api/tree/master/samples/baqup
+ """
+
+ def __init__(
+ self, api_url: str, access_token: str, request_timeout: Optional[int] = 60
+ ):
+ """
+ Args:
+ api_url: https://platform.quip.com
+ access_token: token to access the Quip API. Please refer to:
+ https://quip.com/dev/automation/documentation/current#section/Authentication/Get-Access-to-Quip's-APIs
+ request_timeout: timeout of request, default 60s.
+ """
+ try:
+ from quip_api.quip import QuipClient
+ except ImportError:
+ raise ImportError(
+ "`quip_api` package not found, please run " "`pip install quip_api`"
+ )
+
+ self.quip_client = QuipClient(
+ access_token=access_token, base_url=api_url, request_timeout=request_timeout
+ )
+
+ def load(
+ self,
+ folder_ids: Optional[List[str]] = None,
+ thread_ids: Optional[List[str]] = None,
+ max_docs: Optional[int] = 1000,
+ include_all_folders: bool = False,
+ include_comments: bool = False,
+ include_images: bool = False,
+ ) -> List[Document]:
+ """
+ Args:
+ :param folder_ids: List of specific folder IDs to load, defaults to None
+ :param thread_ids: List of specific thread IDs to load, defaults to None
+ :param max_docs: Maximum number of docs to retrieve in total, defaults to 1000
+ :param include_all_folders: Include all folders that your access_token
+ can access, but doesn't include your private folder
+ :param include_comments: Include comments, defaults to False
+ :param include_images: Include images, defaults to False
+ """
+ if not folder_ids and not thread_ids and not include_all_folders:
+ raise ValueError(
+ "Must specify at least one among `folder_ids`, `thread_ids` "
+ "or set `include_all`_folders as True"
+ )
+
+ thread_ids = thread_ids or []
+
+ if folder_ids:
+ for folder_id in folder_ids:
+ self.get_thread_ids_by_folder_id(folder_id, 0, thread_ids)
+
+ if include_all_folders:
+ user = self.quip_client.get_authenticated_user()
+ if "group_folder_ids" in user:
+ self.get_thread_ids_by_folder_id(
+ user["group_folder_ids"], 0, thread_ids
+ )
+ if "shared_folder_ids" in user:
+ self.get_thread_ids_by_folder_id(
+ user["shared_folder_ids"], 0, thread_ids
+ )
+
+ thread_ids = list(set(thread_ids[:max_docs]))
+ return self.process_threads(thread_ids, include_images, include_comments)
+
+ def get_thread_ids_by_folder_id(
+ self, folder_id: str, depth: int, thread_ids: List[str]
+ ) -> None:
+ """Get thread ids by folder id and update in thread_ids"""
+ from quip_api.quip import HTTPError, QuipError
+
+ try:
+ folder = self.quip_client.get_folder(folder_id)
+ except QuipError as e:
+ if e.code == 403:
+ logging.warning(
+ f"depth {depth}, Skipped over restricted folder {folder_id}, {e}"
+ )
+ else:
+ logging.warning(
+ f"depth {depth}, Skipped over folder {folder_id} "
+ f"due to unknown error {e.code}"
+ )
+ return
+ except HTTPError as e:
+ logging.warning(
+ f"depth {depth}, Skipped over folder {folder_id} "
+ f"due to HTTP error {e.code}"
+ )
+ return
+
+ title = folder["folder"].get("title", "Folder %s" % folder_id)
+
+ logging.info(f"depth {depth}, Processing folder {title}")
+ for child in folder["children"]:
+ if "folder_id" in child:
+ self.get_thread_ids_by_folder_id(
+ child["folder_id"], depth + 1, thread_ids
+ )
+ elif "thread_id" in child:
+ thread_ids.append(child["thread_id"])
+
+ def process_threads(
+ self, thread_ids: Sequence[str], include_images: bool, include_messages: bool
+ ) -> List[Document]:
+ """Process a list of thread into a list of documents."""
+ docs = []
+ for thread_id in thread_ids:
+ doc = self.process_thread(thread_id, include_images, include_messages)
+ if doc is not None:
+ docs.append(doc)
+ return docs
+
+ def process_thread(
+ self, thread_id: str, include_images: bool, include_messages: bool
+ ) -> Optional[Document]:
+ thread = self.quip_client.get_thread(thread_id)
+ thread_id = thread["thread"]["id"]
+ title = thread["thread"]["title"]
+ link = thread["thread"]["link"]
+ update_ts = thread["thread"]["updated_usec"]
+ sanitized_title = QuipLoader._sanitize_title(title)
+
+ logger.info(
+ f"processing thread {thread_id} title {sanitized_title} "
+ f"link {link} update_ts {update_ts}"
+ )
+
+ if "html" in thread:
+ # Parse the document
+ try:
+ tree = self.quip_client.parse_document_html(thread["html"])
+ except xml.etree.cElementTree.ParseError as e:
+ logger.error(f"Error parsing thread {title} {thread_id}, skipping, {e}")
+ return None
+
+ metadata = {
+ "title": sanitized_title,
+ "update_ts": update_ts,
+ "id": thread_id,
+ "source": link,
+ }
+
+ # Download each image and replace with the new URL
+ text = ""
+ if include_images:
+ text = self.process_thread_images(tree)
+
+ if include_messages:
+ text = text + "/n" + self.process_thread_messages(thread_id)
+
+ return Document(
+ page_content=thread["html"] + text,
+ metadata=metadata,
+ )
+ return None
+
+ def process_thread_images(self, tree: ElementTree) -> str:
+ text = ""
+
+ try:
+ from PIL import Image
+ from pytesseract import pytesseract
+ except ImportError:
+ raise ImportError(
+ "`Pillow or pytesseract` package not found, "
+ "please run "
+ "`pip install Pillow` or `pip install pytesseract`"
+ )
+
+ for img in tree.iter("img"):
+ src = img.get("src")
+ if not src or not src.startswith("/blob"):
+ continue
+ _, _, thread_id, blob_id = src.split("/")
+ blob_response = self.quip_client.get_blob(thread_id, blob_id)
+ try:
+ image = Image.open(BytesIO(blob_response.read()))
+ text = text + "\n" + pytesseract.image_to_string(image)
+ except OSError as e:
+ logger.error(f"failed to convert image to text, {e}")
+ raise e
+ return text
+
+ def process_thread_messages(self, thread_id: str) -> str:
+ max_created_usec = None
+ messages = []
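+ # Page backwards through the thread's messages 100 at a time, using the
+ # timestamp of the last message in each chunk (minus 1) as the cursor for
+ # the next request, then restore chronological order at the end.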
+ while True:
+ chunk = self.quip_client.get_messages(
+ thread_id, max_created_usec=max_created_usec, count=100
+ )
+ messages.extend(chunk)
+ if chunk:
+ max_created_usec = chunk[-1]["created_usec"] - 1
+ else:
+ break
+ messages.reverse()
+
+ texts = [message["text"] for message in messages]
+
+ return "\n".join(texts)
+
+ @staticmethod
+ def _sanitize_title(title: str) -> str:
+ sanitized_title = re.sub(r"\s", " ", title)
+ sanitized_title = re.sub(r"(?u)[^- \w.]", "", sanitized_title)
+ if len(sanitized_title) > _MAXIMUM_TITLE_LENGTH:
+ sanitized_title = sanitized_title[:_MAXIMUM_TITLE_LENGTH]
+ return sanitized_title
diff --git a/libs/community/langchain_community/document_loaders/readthedocs.py b/libs/community/langchain_community/document_loaders/readthedocs.py
new file mode 100644
index 00000000000..aee16bab6d7
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/readthedocs.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Tuple, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ from bs4 import NavigableString
+ from bs4.element import Comment, Tag
+
+
+class ReadTheDocsLoader(BaseLoader):
+ """Load `ReadTheDocs` documentation directory."""
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ encoding: Optional[str] = None,
+ errors: Optional[str] = None,
+ custom_html_tag: Optional[Tuple[str, dict]] = None,
+ patterns: Sequence[str] = ("*.htm", "*.html"),
+ exclude_links_ratio: float = 1.0,
+ **kwargs: Optional[Any],
+ ):
+ """
+ Initialize ReadTheDocsLoader
+
+ The loader loops over all files under `path` and extracts the actual content of
+ the files by retrieving main html tags. Default main html tags include
+ `<div role="main">`, and `<main id="main-content">`. You
+ can also define your own html tag by passing custom_html_tag, e.g.
+ `("div", "class=main")`. The loader iterates html tags in the order of
+ the custom html tag (if it exists) and the default html tags. If any of the
+ tags is not empty, the loop will break and retrieve the content out of that tag.
+
+ Args:
+ path: The location of pulled readthedocs folder.
+ encoding: The encoding with which to open the documents.
+ errors: Specify how encoding and decoding errors are to be handled. This
+ cannot be used in binary mode.
+ custom_html_tag: Optional custom html tag to retrieve the content from
+ files.
+ patterns: The file patterns to load, passed to `glob.rglob`.
+ exclude_links_ratio: The ratio of links:content to exclude pages from.
+ This is to reduce the frequency at which index pages make their
+ way into retrieved results. Recommended: 0.5
+ kwargs: named arguments passed to `bs4.BeautifulSoup`.
+ """
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ raise ImportError(
+ "Could not import python packages. "
+ "Please install it with `pip install beautifulsoup4`. "
+ )
+
+ try:
+ _ = BeautifulSoup(
+ "Parser builder library test.",
+ "html.parser",
+ **kwargs,
+ )
+ except Exception as e:
+ raise ValueError("Parsing kwargs do not appear valid") from e
+
+ self.file_path = Path(path)
+ self.encoding = encoding
+ self.errors = errors
+ self.custom_html_tag = custom_html_tag
+ self.patterns = patterns
+ self.bs_kwargs = kwargs
+ self.exclude_links_ratio = exclude_links_ratio
+
+ def lazy_load(self) -> Iterator[Document]:
+ """A lazy loader for Documents."""
+ for file_pattern in self.patterns:
+ for p in self.file_path.rglob(file_pattern):
+ if p.is_dir():
+ continue
+ with open(p, encoding=self.encoding, errors=self.errors) as f:
+ text = self._clean_data(f.read())
+ yield Document(page_content=text, metadata={"source": str(p)})
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ return list(self.lazy_load())
+
+ def _clean_data(self, data: str) -> str:
+ from bs4 import BeautifulSoup
+
+ soup = BeautifulSoup(data, "html.parser", **self.bs_kwargs)
+
+ # default tags
+ html_tags = [
+ ("div", {"role": "main"}),
+ ("main", {"id": "main-content"}),
+ ]
+
+ if self.custom_html_tag is not None:
+ html_tags.append(self.custom_html_tag)
+
+ element = None
+
+ # reversed order. check the custom one first
+ for tag, attrs in html_tags[::-1]:
+ element = soup.find(tag, attrs)
+ # if found, break
+ if element is not None:
+ break
+
+ if element is not None and _get_link_ratio(element) <= self.exclude_links_ratio:
+ text = _get_clean_text(element)
+ else:
+ text = ""
+ # trim empty lines
+ return "\n".join([t for t in text.split("\n") if t])
+
+
+def _get_clean_text(element: Tag) -> str:
+ """Returns cleaned text with newlines preserved and irrelevant elements removed."""
+ elements_to_skip = [
+ "script",
+ "noscript",
+ "canvas",
+ "meta",
+ "svg",
+ "map",
+ "area",
+ "audio",
+ "source",
+ "track",
+ "video",
+ "embed",
+ "object",
+ "param",
+ "picture",
+ "iframe",
+ "frame",
+ "frameset",
+ "noframes",
+ "applet",
+ "form",
+ "button",
+ "select",
+ "base",
+ "style",
+ "img",
+ ]
+
+ newline_elements = [
+ "p",
+ "div",
+ "ul",
+ "ol",
+ "li",
+ "h1",
+ "h2",
+ "h3",
+ "h4",
+ "h5",
+ "h6",
+ "pre",
+ "table",
+ "tr",
+ ]
+
+ text = _process_element(element, elements_to_skip, newline_elements)
+ return text.strip()
+
+
+def _get_link_ratio(section: Tag) -> float:
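+ # Ratio of link text to all visible text in the section; _clean_data drops
+ # sections whose ratio exceeds exclude_links_ratio (e.g. index pages).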
+ links = section.find_all("a")
+ total_text = "".join(str(s) for s in section.stripped_strings)
+ if len(total_text) == 0:
+ return 0
+
+ link_text = "".join(
+ str(string.string.strip())
+ for link in links
+ for string in link.strings
+ if string
+ )
+ return len(link_text) / len(total_text)
+
+
+def _process_element(
+ element: Union[Tag, NavigableString, Comment],
+ elements_to_skip: List[str],
+ newline_elements: List[str],
+) -> str:
+ """
+ Traverse through HTML tree recursively to preserve newline and skip
+ unwanted (code/binary) elements
+ """
+ from bs4 import NavigableString
+ from bs4.element import Comment, Tag
+
+ tag_name = getattr(element, "name", None)
+ if isinstance(element, Comment) or tag_name in elements_to_skip:
+ return ""
+ elif isinstance(element, NavigableString):
+ return element
+ elif tag_name == "br":
+ return "\n"
+ elif tag_name in newline_elements:
+ return (
+ "".join(
+ _process_element(child, elements_to_skip, newline_elements)
+ for child in element.children
+ if isinstance(child, (Tag, NavigableString, Comment))
+ )
+ + "\n"
+ )
+ else:
+ return "".join(
+ _process_element(child, elements_to_skip, newline_elements)
+ for child in element.children
+ if isinstance(child, (Tag, NavigableString, Comment))
+ )
diff --git a/libs/community/langchain_community/document_loaders/recursive_url_loader.py b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
new file mode 100644
index 00000000000..6fca4edf86a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
@@ -0,0 +1,303 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import re
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Union,
+)
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.utils.html import extract_sub_links
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import aiohttp
+
+logger = logging.getLogger(__name__)
+
+
+def _metadata_extractor(raw_html: str, url: str) -> dict:
+ """Extract metadata from raw html using BeautifulSoup."""
+ metadata = {"source": url}
+
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ logger.warning(
+ "The bs4 package is required for default metadata extraction. "
+ "Please install it with `pip install bs4`."
+ )
+ return metadata
+ soup = BeautifulSoup(raw_html, "html.parser")
+ if title := soup.find("title"):
+ metadata["title"] = title.get_text()
+ if description := soup.find("meta", attrs={"name": "description"}):
+ metadata["description"] = description.get("content", None)
+ if html := soup.find("html"):
+ metadata["language"] = html.get("lang", None)
+ return metadata
+
+
+class RecursiveUrlLoader(BaseLoader):
+ """Load all child links from a URL page.
+
+ **Security Note**: This loader is a crawler that will start crawling
+ at a given URL and then expand to crawl child links recursively.
+
+ Web crawlers should generally NOT be deployed with network access
+ to any internal servers.
+
+ Control access to who can submit crawling requests and what network access
+ the crawler has.
+
+ While crawling, the crawler may encounter malicious URLs that would lead to a
+ server-side request forgery (SSRF) attack.
+
+ To mitigate risks, the crawler by default will only load URLs from the same
+ domain as the start URL (controlled via prevent_outside named argument).
+
+ This will mitigate the risk of SSRF attacks, but will not eliminate it.
+
+ For example, if crawling a host which hosts several sites:
+
+ https://some_host/alice_site/
+ https://some_host/bob_site/
+
+ A malicious URL on Alice's site could cause the crawler to make a malicious
+ GET request to an endpoint on Bob's site. Both sites are hosted on the
+ same host, so such a request would not be prevented by default.
+
+ See https://python.langchain.com/docs/security
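+
+ A minimal usage sketch (the URL below is only an illustrative example):
+
+ .. code-block:: python
+
+     from langchain_community.document_loaders import RecursiveUrlLoader
+
+     loader = RecursiveUrlLoader("https://docs.python.org/3.9/", max_depth=2)
+     docs = loader.load()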
+ """
+
+ def __init__(
+ self,
+ url: str,
+ max_depth: Optional[int] = 2,
+ use_async: Optional[bool] = None,
+ extractor: Optional[Callable[[str], str]] = None,
+ metadata_extractor: Optional[Callable[[str, str], dict]] = None,
+ exclude_dirs: Optional[Sequence[str]] = (),
+ timeout: Optional[int] = 10,
+ prevent_outside: bool = True,
+ link_regex: Union[str, re.Pattern, None] = None,
+ headers: Optional[dict] = None,
+ check_response_status: bool = False,
+ ) -> None:
+ """Initialize with URL to crawl and any subdirectories to exclude.
+
+ Args:
+ url: The URL to crawl.
+ max_depth: The max depth of the recursive loading.
+ use_async: Whether to use asynchronous loading.
+ If True, this function will not be lazy, but it will still work in the
+ expected way, just not lazy.
+ extractor: A function to extract document contents from raw html.
+ When extract function returns an empty string, the document is
+ ignored.
+ metadata_extractor: A function to extract metadata from raw html and the
+ source url (args in that order). Default extractor will attempt
+ to use BeautifulSoup4 to extract the title, description and language
+ of the page.
+ exclude_dirs: A list of subdirectories to exclude.
+ timeout: The timeout for the requests, in the unit of seconds. If None then
+ connection will not timeout.
+ prevent_outside: If True, prevent loading from urls which are not children
+ of the root url.
+ link_regex: Regex for extracting sub-links from the raw html of a web page.
+ check_response_status: If True, check HTTP response status and skip
+ URLs with error responses (400-599).
+ """
+
+ self.url = url
+ self.max_depth = max_depth if max_depth is not None else 2
+ self.use_async = use_async if use_async is not None else False
+ self.extractor = extractor if extractor is not None else lambda x: x
+ self.metadata_extractor = (
+ metadata_extractor
+ if metadata_extractor is not None
+ else _metadata_extractor
+ )
+ self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()
+
+ if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
+ raise ValueError(
+ f"Base url is included in exclude_dirs. Received base_url: {url} and "
+ f"exclude_dirs: {self.exclude_dirs}"
+ )
+
+ self.timeout = timeout
+ self.prevent_outside = prevent_outside if prevent_outside is not None else True
+ self.link_regex = link_regex
+ self._lock = asyncio.Lock() if self.use_async else None
+ self.headers = headers
+ self.check_response_status = check_response_status
+
+ def _get_child_links_recursive(
+ self, url: str, visited: Set[str], *, depth: int = 0
+ ) -> Iterator[Document]:
+ """Recursively get all child links starting with the path of the input URL.
+
+ Args:
+ url: The URL to crawl.
+ visited: A set of visited URLs.
+ depth: Current depth of recursion. Stop when depth >= max_depth.
+ """
+
+ if depth >= self.max_depth:
+ return
+
+ # Get all links that can be accessed from the current URL
+ visited.add(url)
+ try:
+ response = requests.get(url, timeout=self.timeout, headers=self.headers)
+ if self.check_response_status and 400 <= response.status_code <= 599:
+ raise ValueError(f"Received HTTP status {response.status_code}")
+ except Exception as e:
+ logger.warning(
+ f"Unable to load from {url}. Received error {e} of type "
+ f"{e.__class__.__name__}"
+ )
+ return
+ content = self.extractor(response.text)
+ if content:
+ yield Document(
+ page_content=content,
+ metadata=self.metadata_extractor(response.text, url),
+ )
+
+ # Store the visited links and recursively visit the children
+ sub_links = extract_sub_links(
+ response.text,
+ url,
+ base_url=self.url,
+ pattern=self.link_regex,
+ prevent_outside=self.prevent_outside,
+ exclude_prefixes=self.exclude_dirs,
+ )
+ for link in sub_links:
+ # Check all unvisited links
+ if link not in visited:
+ yield from self._get_child_links_recursive(
+ link, visited, depth=depth + 1
+ )
+
+ async def _async_get_child_links_recursive(
+ self,
+ url: str,
+ visited: Set[str],
+ *,
+ session: Optional[aiohttp.ClientSession] = None,
+ depth: int = 0,
+ ) -> List[Document]:
+ """Recursively get all child links starting with the path of the input URL.
+
+ Args:
+ url: The URL to crawl.
+ visited: A set of visited URLs.
+ depth: To reach the current url, how many pages have been visited.
+ """
+ try:
+ import aiohttp
+ except ImportError:
+ raise ImportError(
+ "The aiohttp package is required for the RecursiveUrlLoader. "
+ "Please install it with `pip install aiohttp`."
+ )
+ if depth >= self.max_depth:
+ return []
+
+ # Disable SSL verification: some websites have invalid SSL certificates,
+ # and the loader only reads page content, so this is an accepted trade-off.
+ close_session = session is None
+ session = (
+ session
+ if session is not None
+ else aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl=False),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
+ headers=self.headers,
+ )
+ )
+ async with self._lock: # type: ignore
+ visited.add(url)
+ try:
+ async with session.get(url) as response:
+ text = await response.text()
+ if self.check_response_status and 400 <= response.status <= 599:
+ raise ValueError(f"Received HTTP status {response.status}")
+ except Exception as e:  # also covers aiohttp.client_exceptions.InvalidURL
+ logger.warning(
+ f"Unable to load {url}. Received error {e} of type "
+ f"{e.__class__.__name__}"
+ )
+ if close_session:
+ await session.close()
+ return []
+ results = []
+ content = self.extractor(text)
+ if content:
+ results.append(
+ Document(
+ page_content=content,
+ metadata=self.metadata_extractor(text, url),
+ )
+ )
+ if depth < self.max_depth - 1:
+ sub_links = extract_sub_links(
+ text,
+ url,
+ base_url=self.url,
+ pattern=self.link_regex,
+ prevent_outside=self.prevent_outside,
+ exclude_prefixes=self.exclude_dirs,
+ )
+
+ # Recursively call the function to get the children of the children
+ sub_tasks = []
+ async with self._lock: # type: ignore
+ to_visit = set(sub_links).difference(visited)
+ for link in to_visit:
+ sub_tasks.append(
+ self._async_get_child_links_recursive(
+ link, visited, session=session, depth=depth + 1
+ )
+ )
+ next_results = await asyncio.gather(*sub_tasks)
+ for sub_result in next_results:
+ if isinstance(sub_result, Exception) or sub_result is None:
+ # We don't want to stop the whole process, so just ignore it
+ # Not standard html format or invalid url or 404 may cause this.
+ continue
+ # locking not fully working, temporary hack to ensure deduplication
+ results += [r for r in sub_result if r not in results]
+ if close_session:
+ await session.close()
+ return results
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load web pages.
+ When use_async is True, this function will not be lazy,
+ but it will still work in the expected way, just not lazy."""
+ visited: Set[str] = set()
+ if self.use_async:
+ results = asyncio.run(
+ self._async_get_child_links_recursive(self.url, visited)
+ )
+ return iter(results or [])
+ else:
+ return self._get_child_links_recursive(self.url, visited)
+
+ def load(self) -> List[Document]:
+ """Load web pages."""
+ return list(self.lazy_load())
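For reference, a minimal usage sketch of the loader above. The docs URL, the timeout value, and the BeautifulSoup-based extractor are illustrative assumptions, and the module import path is assumed rather than shown in this hunk.

```python
from bs4 import BeautifulSoup  # only needed for the custom extractor below

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# Hypothetical root URL; by default crawling stays within children of this URL.
loader = RecursiveUrlLoader(
    url="https://docs.example.com/",
    max_depth=2,
    # Custom extractor: reduce raw HTML to visible text.
    extractor=lambda html: BeautifulSoup(html, "html.parser").get_text(),
    timeout=10,
    check_response_status=True,
)

for doc in loader.lazy_load():
    print(len(doc.page_content), doc.metadata)
```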
diff --git a/libs/community/langchain_community/document_loaders/reddit.py b/libs/community/langchain_community/document_loaders/reddit.py
new file mode 100644
index 00000000000..47c46570d8e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/reddit.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import praw
+
+
+def _dependable_praw_import() -> praw:
+ try:
+ import praw
+ except ImportError:
+ raise ImportError(
+ "praw package not found, please install it with `pip install praw`"
+ )
+ return praw
+
+
+class RedditPostsLoader(BaseLoader):
+ """Load `Reddit` posts.
+
+ Read posts from a subreddit or a Reddit user. To obtain API credentials,
+ first go to https://www.reddit.com/prefs/apps/ and create your application.
+ """
+
+ def __init__(
+ self,
+ client_id: str,
+ client_secret: str,
+ user_agent: str,
+ search_queries: Sequence[str],
+ mode: str,
+ categories: Sequence[str] = ["new"],
+ number_posts: Optional[int] = 10,
+ ):
+ """
+ Initialize with client_id, client_secret, user_agent, search_queries, mode,
+ categories, number_posts.
+ Example: https://www.reddit.com/r/learnpython/
+
+ Args:
+ client_id: Reddit client id.
+ client_secret: Reddit client secret.
+ user_agent: Reddit user agent.
+ search_queries: The subreddit names or usernames to fetch posts from.
+ mode: The search mode; either "subreddit" or "username".
+ categories: The post categories to fetch, e.g. "new", "hot", "top".
+ Default: ["new"]
+ number_posts: The maximum number of posts to fetch per query. Default: 10
+ """
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.user_agent = user_agent
+ self.search_queries = search_queries
+ self.mode = mode
+ self.categories = categories
+ self.number_posts = number_posts
+
+ def load(self) -> List[Document]:
+ """Load reddits."""
+ praw = _dependable_praw_import()
+
+ reddit = praw.Reddit(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ user_agent=self.user_agent,
+ )
+
+ results: List[Document] = []
+
+ if self.mode == "subreddit":
+ for search_query in self.search_queries:
+ for category in self.categories:
+ docs = self._subreddit_posts_loader(
+ search_query=search_query, category=category, reddit=reddit
+ )
+ results.extend(docs)
+
+ elif self.mode == "username":
+ for search_query in self.search_queries:
+ for category in self.categories:
+ docs = self._user_posts_loader(
+ search_query=search_query, category=category, reddit=reddit
+ )
+ results.extend(docs)
+
+ else:
+ raise ValueError(
+ "mode not correct, please enter 'username' or 'subreddit' as mode"
+ )
+
+ return results
+
+ def _subreddit_posts_loader(
+ self, search_query: str, category: str, reddit: praw.reddit.Reddit
+ ) -> Iterable[Document]:
+ subreddit = reddit.subreddit(search_query)
+ method = getattr(subreddit, category)
+ cat_posts = method(limit=self.number_posts)
+
+ """Format reddit posts into a string."""
+ for post in cat_posts:
+ metadata = {
+ "post_subreddit": post.subreddit_name_prefixed,
+ "post_category": category,
+ "post_title": post.title,
+ "post_score": post.score,
+ "post_id": post.id,
+ "post_url": post.url,
+ "post_author": post.author,
+ }
+ yield Document(
+ page_content=post.selftext,
+ metadata=metadata,
+ )
+
+ def _user_posts_loader(
+ self, search_query: str, category: str, reddit: praw.reddit.Reddit
+ ) -> Iterable[Document]:
+ user = reddit.redditor(search_query)
+ method = getattr(user.submissions, category)
+ cat_posts = method(limit=self.number_posts)
+
+ """Format reddit posts into a string."""
+ for post in cat_posts:
+ metadata = {
+ "post_subreddit": post.subreddit_name_prefixed,
+ "post_category": category,
+ "post_title": post.title,
+ "post_score": post.score,
+ "post_id": post.id,
+ "post_url": post.url,
+ "post_author": post.author,
+ }
+ yield Document(
+ page_content=post.selftext,
+ metadata=metadata,
+ )
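A minimal usage sketch of `RedditPostsLoader` as defined above; the credentials, user agent, and subreddit name are placeholders, and the chosen categories assume the standard praw listing methods ("new", "hot").

```python
from langchain_community.document_loaders.reddit import RedditPostsLoader

# Placeholder credentials; create an app at https://www.reddit.com/prefs/apps/
loader = RedditPostsLoader(
    client_id="YOUR_CLIENT_ID",
    client_secret="YOUR_CLIENT_SECRET",
    user_agent="extractor by u/your_username",
    search_queries=["learnpython"],      # subreddit names in "subreddit" mode
    mode="subreddit",
    categories=["new", "hot"],           # mapped onto praw listing methods
    number_posts=5,
)
docs = loader.load()
```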
diff --git a/libs/community/langchain_community/document_loaders/roam.py b/libs/community/langchain_community/document_loaders/roam.py
new file mode 100644
index 00000000000..a21b827a1d5
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/roam.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class RoamLoader(BaseLoader):
+ """Load `Roam` files from a directory."""
+
+ def __init__(self, path: str):
+ """Initialize with a path."""
+ self.file_path = path
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ ps = list(Path(self.file_path).glob("**/*.md"))
+ docs = []
+ for p in ps:
+ with open(p) as f:
+ text = f.read()
+ metadata = {"source": str(p)}
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
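A short usage sketch for `RoamLoader`; the export directory path is a placeholder.

```python
from langchain_community.document_loaders.roam import RoamLoader

# Placeholder path to a directory containing an exported Roam graph (.md files).
loader = RoamLoader("path/to/roam_export/")
docs = loader.load()
```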
diff --git a/libs/community/langchain_community/document_loaders/rocksetdb.py b/libs/community/langchain_community/document_loaders/rocksetdb.py
new file mode 100644
index 00000000000..a1783bc552d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/rocksetdb.py
@@ -0,0 +1,125 @@
+from typing import Any, Callable, Iterator, List, Optional, Tuple
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def default_joiner(docs: List[Tuple[str, Any]]) -> str:
+ """Default joiner for content columns."""
+ return "\n".join([doc[1] for doc in docs])
+
+
+class ColumnNotFoundError(Exception):
+ """Column not found error."""
+
+ def __init__(self, missing_key: str, query: str):
+ super().__init__(f'Column "{missing_key}" not selected in query:\n{query}')
+
+
+class RocksetLoader(BaseLoader):
+ """Load from a `Rockset` database.
+
+ To use, you should have the `rockset` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ # This code will load 3 records from the "langchain_demo"
+ # collection as Documents, with the `text` column used as
+ # the content
+
+ from langchain_community.document_loaders import RocksetLoader
+ from rockset import RocksetClient, Regions, models
+
+ loader = RocksetLoader(
+ RocksetClient(Regions.usw2a1, ""),
+ models.QueryRequestSql(
+ query="select * from langchain_demo limit 3"
+ ),
+ ["text"]
+ )
+ """
+
+ def __init__(
+ self,
+ client: Any,
+ query: Any,
+ content_keys: List[str],
+ metadata_keys: Optional[List[str]] = None,
+ content_columns_joiner: Callable[[List[Tuple[str, Any]]], str] = default_joiner,
+ ):
+ """Initialize with Rockset client.
+
+ Args:
+ client: Rockset client object.
+ query: Rockset query object.
+ content_keys: The collection columns to be written into the `page_content`
+ of the Documents.
+ metadata_keys: The collection columns to be written into the `metadata` of
+ the Documents. By default, this is all the keys in the document.
+ content_columns_joiner: Method that joins the content_keys and their values
+ into a string. It takes a List[Tuple[str, Any]], representing a list of
+ (column name, column value) tuples.
+ By default, this is a method that joins each column value with a new
+ line. This method is only relevant if there are multiple content_keys.
+ """
+ try:
+ from rockset import QueryPaginator, RocksetClient
+ from rockset.models import QueryRequestSql
+ except ImportError:
+ raise ImportError(
+ "Could not import rockset client python package. "
+ "Please install it with `pip install rockset`."
+ )
+
+ if not isinstance(client, RocksetClient):
+ raise ValueError(
+ f"client should be an instance of rockset.RocksetClient, "
+ f"got {type(client)}"
+ )
+
+ if not isinstance(query, QueryRequestSql):
+ raise ValueError(
+ f"query should be an instance of rockset.model.QueryRequestSql, "
+ f"got {type(query)}"
+ )
+
+ self.client = client
+ self.query = query
+ self.content_keys = content_keys
+ self.content_columns_joiner = content_columns_joiner
+ self.metadata_keys = metadata_keys
+ self.paginator = QueryPaginator
+ self.request_model = QueryRequestSql
+
+ try:
+ self.client.set_application("langchain")
+ except AttributeError:
+ # ignore
+ pass
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ query_results = self.client.Queries.query(
+ sql=self.query
+ ).results # execute the SQL query
+ for doc in query_results: # for each doc in the response
+ try:
+ yield Document(
+ page_content=self.content_columns_joiner(
+ [(col, doc[col]) for col in self.content_keys]
+ ),
+ metadata={col: doc[col] for col in self.metadata_keys}
+ if self.metadata_keys is not None
+ else doc,
+ ) # try to yield the Document
+ except (
+ KeyError
+ ) as e: # either content_columns or metadata_columns is invalid
+ raise ColumnNotFoundError(
+ e.args[0], self.query
+ ) # raise that the column isn't in the db schema
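Beyond the basic docstring example, a sketch of a custom `content_columns_joiner` with explicit metadata columns; the region, API key, collection, and column names are placeholders.

```python
from rockset import Regions, RocksetClient, models

from langchain_community.document_loaders.rocksetdb import RocksetLoader


def labeled_joiner(columns):
    # Join multiple content columns as "column: value" lines instead of bare values.
    return "\n".join(f"{name}: {value}" for name, value in columns)


loader = RocksetLoader(
    RocksetClient(Regions.usw2a1, "YOUR_API_KEY"),  # placeholder region and key
    models.QueryRequestSql(query="SELECT id, title, body FROM langchain_demo LIMIT 3"),
    content_keys=["title", "body"],
    metadata_keys=["id"],
    content_columns_joiner=labeled_joiner,
)
docs = loader.load()
```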
diff --git a/libs/community/langchain_community/document_loaders/rspace.py b/libs/community/langchain_community/document_loaders/rspace.py
new file mode 100644
index 00000000000..a4cbcaabd41
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/rspace.py
@@ -0,0 +1,131 @@
+import os
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class RSpaceLoader(BaseLoader):
+ """
+ Loads content from RSpace notebooks, folders, documents or PDF Gallery files into
+ Langchain documents.
+
+ Each RSpace document maps to one Langchain Document. PDFs are imported using PyPDF.
+
+ Requirements are rspace_client (`pip install rspace_client`) and PyPDF if importing
+ PDF docs (`pip install pypdf`).
+
+ """
+
+ def __init__(
+ self, global_id: str, api_key: Optional[str] = None, url: Optional[str] = None
+ ):
+ """api_key: RSpace API key - can also be supplied as environment variable
+ 'RSPACE_API_KEY'
+ url: str
+ The URL of your RSpace instance - can also be supplied as environment
+ variable 'RSPACE_URL'
+ global_id: str
+ The global ID of the resource to load,
+ e.g. 'SD12344' (a single document); 'GL12345'(A PDF file in the gallery);
+ 'NB4567' (a notebook); 'FL12244' (a folder)
+ """
+ args: Dict[str, Optional[str]] = {
+ "api_key": api_key,
+ "url": url,
+ "global_id": global_id,
+ }
+ verified_args: Dict[str, str] = RSpaceLoader.validate_environment(args)
+ self.api_key = verified_args["api_key"]
+ self.url = verified_args["url"]
+ self.global_id: str = verified_args["global_id"]
+
+ @classmethod
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that API key and URL exists in environment."""
+ values["api_key"] = get_from_dict_or_env(values, "api_key", "RSPACE_API_KEY")
+ values["url"] = get_from_dict_or_env(values, "url", "RSPACE_URL")
+ if "global_id" not in values or values["global_id"] is None:
+ raise ValueError(
+ "No value supplied for global_id. Please supply an RSpace global ID"
+ )
+ return values
+
+ def _create_rspace_client(self) -> Any:
+ """Create a RSpace client."""
+ try:
+ from rspace_client.eln import eln, field_content
+
+ except ImportError:
+ raise ImportError("You must run " "`pip install rspace_client`")
+
+ try:
+ eln = eln.ELNClient(self.url, self.api_key)
+ eln.get_status()
+
+ except Exception:
+ raise Exception(
+ f"Unable to initialise client - is url {self.url} or "
+ f"api key correct?"
+ )
+
+ return eln, field_content.FieldContent
+
+ def _get_doc(self, cli: Any, field_content: Any, d_id: Union[str, int]) -> Document:
+ content = ""
+ doc = cli.get_document(d_id)
+ content += f"
{doc['name']}"
+ for f in doc["fields"]:
+ content += f"{f['name']}\n"
+ fc = field_content(f["content"])
+ content += fc.get_text()
+ content += "\n"
+ return Document(
+ metadata={"source": f"rspace: {doc['name']}-{doc['globalId']}"},
+ page_content=content,
+ )
+
+ def _load_structured_doc(self) -> Iterator[Document]:
+ cli, field_content = self._create_rspace_client()
+ yield self._get_doc(cli, field_content, self.global_id)
+
+ def _load_folder_tree(self) -> Iterator[Document]:
+ cli, field_content = self._create_rspace_client()
+ if self.global_id:
+ docs_in_folder = cli.list_folder_tree(
+ folder_id=self.global_id[2:], typesToInclude=["document"]
+ )
+ doc_ids: List[int] = [d["id"] for d in docs_in_folder["records"]]
+ for doc_id in doc_ids:
+ yield self._get_doc(cli, field_content, doc_id)
+
+ def _load_pdf(self) -> Iterator[Document]:
+ cli, field_content = self._create_rspace_client()
+ file_info = cli.get_file_info(self.global_id)
+ _, ext = os.path.splitext(file_info["name"])
+ if ext.lower() == ".pdf":
+ outfile = f"{self.global_id}.pdf"
+ cli.download_file(self.global_id, outfile)
+ pdf_loader = PyPDFLoader(outfile)
+ for pdf in pdf_loader.lazy_load():
+ pdf.metadata["rspace_src"] = self.global_id
+ yield pdf
+
+ def lazy_load(self) -> Iterator[Document]:
+ if self.global_id and "GL" in self.global_id:
+ for d in self._load_pdf():
+ yield d
+ elif self.global_id and "SD" in self.global_id:
+ for d in self._load_structured_doc():
+ yield d
+ elif self.global_id and self.global_id[0:2] in ["FL", "NB"]:
+ for d in self._load_folder_tree():
+ yield d
+ else:
+ raise ValueError("Unknown global ID type")
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
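A usage sketch for `RSpaceLoader`; the global ID, API key, and instance URL are placeholders, and both credentials may instead come from the environment variables named above.

```python
from langchain_community.document_loaders.rspace import RSpaceLoader

loader = RSpaceLoader(
    global_id="NB4567",                              # placeholder notebook ID; SD/GL/FL IDs also work
    api_key="YOUR_RSPACE_API_KEY",                   # or set RSPACE_API_KEY
    url="https://your-rspace-instance.example.com",  # or set RSPACE_URL
)
for doc in loader.lazy_load():
    print(doc.metadata["source"])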
diff --git a/libs/community/langchain_community/document_loaders/rss.py b/libs/community/langchain_community/document_loaders/rss.py
new file mode 100644
index 00000000000..b245e214606
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/rss.py
@@ -0,0 +1,133 @@
+import logging
+from typing import Any, Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.news import NewsURLLoader
+
+logger = logging.getLogger(__name__)
+
+
+class RSSFeedLoader(BaseLoader):
+ """Load news articles from `RSS` feeds using `Unstructured`.
+
+ Args:
+ urls: URLs for RSS feeds to load. Each article in the feed is loaded into its own document.
+ opml: OPML file to load feed urls from. Only one of urls or opml should be provided. The value
+ can be a URL string, or OPML markup contents as bytes or a string.
+ continue_on_failure: If True, continue loading documents even if
+ loading fails for a particular URL.
+ show_progress_bar: If True, use tqdm to show a loading progress bar. Requires
+ tqdm to be installed, ``pip install tqdm``.
+ **newsloader_kwargs: Any additional named arguments to pass to
+ NewsURLLoader.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import RSSFeedLoader
+
+ loader = RSSFeedLoader(
+ urls=["", ""],
+ )
+ docs = loader.load()
+
+ The loader uses feedparser to parse RSS feeds. The feedparser library is not installed by default so you should
+ install it if using this loader:
+ https://pythonhosted.org/feedparser/
+
+ If you use OPML, you should also install listparser:
+ https://pythonhosted.org/listparser/
+
+ Finally, newspaper is used to process each article:
+ https://newspaper.readthedocs.io/en/latest/
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ urls: Optional[Sequence[str]] = None,
+ opml: Optional[str] = None,
+ continue_on_failure: bool = True,
+ show_progress_bar: bool = False,
+ **newsloader_kwargs: Any,
+ ) -> None:
+ """Initialize with urls or OPML."""
+ if (urls is None) == (
+ opml is None
+ ): # This is True if both are None or neither is None
+ raise ValueError(
+ "Provide either the urls or the opml argument, but not both."
+ )
+ self.urls = urls
+ self.opml = opml
+ self.continue_on_failure = continue_on_failure
+ self.show_progress_bar = show_progress_bar
+ self.newsloader_kwargs = newsloader_kwargs
+
+ def load(self) -> List[Document]:
+ iter = self.lazy_load()
+ if self.show_progress_bar:
+ try:
+ from tqdm import tqdm
+ except ImportError as e:
+ raise ImportError(
+ "Package tqdm must be installed if show_progress_bar=True. "
+ "Please install with 'pip install tqdm' or set "
+ "show_progress_bar=False."
+ ) from e
+ iter = tqdm(iter)
+ return list(iter)
+
+ @property
+ def _get_urls(self) -> Sequence[str]:
+ if self.urls:
+ return self.urls
+ try:
+ import listparser
+ except ImportError as e:
+ raise ImportError(
+ "Package listparser must be installed if the opml arg is used. "
+ "Please install with 'pip install listparser' or use the "
+ "urls arg instead."
+ ) from e
+ rss = listparser.parse(self.opml)
+ return [feed.url for feed in rss.feeds]
+
+ def lazy_load(self) -> Iterator[Document]:
+ try:
+ import feedparser # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "feedparser package not found, please install it with "
+ "`pip install feedparser`"
+ )
+
+ for url in self._get_urls:
+ try:
+ feed = feedparser.parse(url)
+ if getattr(feed, "bozo", False):
+ raise ValueError(
+ f"Error fetching {url}, exception: {feed.bozo_exception}"
+ )
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error fetching {url}, exception: {e}")
+ continue
+ else:
+ raise e
+ try:
+ for entry in feed.entries:
+ loader = NewsURLLoader(
+ urls=[entry.link],
+ **self.newsloader_kwargs,
+ )
+ article = loader.load()[0]
+ article.metadata["feed"] = url
+ yield article
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error processing entry {entry.link}, exception: {e}")
+ continue
+ else:
+ raise e
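A minimal usage sketch of `RSSFeedLoader`; the feed URL is a placeholder, and feedparser plus the newspaper library used by `NewsURLLoader` are assumed to be installed.

```python
from langchain_community.document_loaders.rss import RSSFeedLoader

loader = RSSFeedLoader(
    urls=["https://news.example.com/rss.xml"],  # hypothetical feed URL
    continue_on_failure=True,
)
docs = loader.load()
for doc in docs:
    # Each article carries the originating feed URL in its metadata.
    print(doc.metadata["feed"])
```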
diff --git a/libs/community/langchain_community/document_loaders/rst.py b/libs/community/langchain_community/document_loaders/rst.py
new file mode 100644
index 00000000000..103b24414e0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/rst.py
@@ -0,0 +1,53 @@
+"""Loads RST files."""
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredRSTLoader(UnstructuredFileLoader):
+ """Load `RST` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredRSTLoader
+
+ loader = UnstructuredRSTLoader(
+ "example.rst", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-rst
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+ Initialize with a file path.
+
+ Args:
+ file_path: The path to the file to load.
+ mode: The mode to use for partitioning. See unstructured for details.
+ Defaults to "single".
+ **unstructured_kwargs: Additional keyword arguments to pass
+ to unstructured.
+ """
+ validate_unstructured_version(min_unstructured_version="0.7.5")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.rst import partition_rst
+
+ return partition_rst(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/rtf.py b/libs/community/langchain_community/document_loaders/rtf.py
new file mode 100644
index 00000000000..3fe2731684c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/rtf.py
@@ -0,0 +1,59 @@
+"""Loads rich text files."""
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ satisfies_min_unstructured_version,
+)
+
+
+class UnstructuredRTFLoader(UnstructuredFileLoader):
+ """Load `RTF` files using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredRTFLoader
+
+ loader = UnstructuredRTFLoader(
+ "example.rtf", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-rtf
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ """
+ Initialize with a file path.
+
+ Args:
+ file_path: The path to the file to load.
+ mode: The mode to use for partitioning. See unstructured for details.
+ Defaults to "single".
+ **unstructured_kwargs: Additional keyword arguments to pass
+ to unstructured.
+ """
+ min_unstructured_version = "0.5.12"
+ if not satisfies_min_unstructured_version(min_unstructured_version):
+ raise ValueError(
+ "Partitioning rtf files is only supported in "
+ f"unstructured>={min_unstructured_version}."
+ )
+
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.rtf import partition_rtf
+
+ return partition_rtf(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/s3_directory.py b/libs/community/langchain_community/document_loaders/s3_directory.py
new file mode 100644
index 00000000000..9885418ec8a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/s3_directory.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.s3_file import S3FileLoader
+
+if TYPE_CHECKING:
+ import botocore
+
+
+class S3DirectoryLoader(BaseLoader):
+ """Load from `Amazon AWS S3` directory."""
+
+ def __init__(
+ self,
+ bucket: str,
+ prefix: str = "",
+ *,
+ region_name: Optional[str] = None,
+ api_version: Optional[str] = None,
+ use_ssl: Optional[bool] = True,
+ verify: Union[str, bool, None] = None,
+ endpoint_url: Optional[str] = None,
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_session_token: Optional[str] = None,
+ boto_config: Optional[botocore.client.Config] = None,
+ ):
+ """Initialize with bucket and key name.
+
+ :param bucket: The name of the S3 bucket.
+ :param prefix: The prefix of the S3 key. Defaults to "".
+
+ :param region_name: The name of the region associated with the client.
+ A client is associated with a single region.
+
+ :param api_version: The API version to use. By default, botocore will
+ use the latest API version when creating a client. You only need
+ to specify this parameter if you want to use a previous API version
+ of the client.
+
+ :param use_ssl: Whether to use SSL. By default, SSL is used.
+ Note that not all services support non-ssl connections.
+
+ :param verify: Whether to verify SSL certificates.
+ By default SSL certificates are verified. You can provide the
+ following values:
+
+ * False - do not validate SSL certificates. SSL will still be
+ used (unless use_ssl is False), but SSL certificates
+ will not be verified.
+ * path/to/cert/bundle.pem - A filename of the CA cert bundle to
+ use. You can specify this argument if you want to use a
+ different CA cert bundle than the one used by botocore.
+
+ :param endpoint_url: The complete URL to use for the constructed
+ client. Normally, botocore will automatically construct the
+ appropriate URL to use when communicating with a service. You can
+ specify a complete URL (including the "http/https" scheme) to
+ override this behavior. If this value is provided, then
+ ``use_ssl`` is ignored.
+
+ :param aws_access_key_id: The access key to use when creating
+ the client. This is entirely optional, and if not provided,
+ the credentials configured for the session will automatically
+ be used. You only need to provide this argument if you want
+ to override the credentials used for this specific client.
+
+ :param aws_secret_access_key: The secret key to use when creating
+ the client. Same semantics as aws_access_key_id above.
+
+ :param aws_session_token: The session token to use when creating
+ the client. Same semantics as aws_access_key_id above.
+
+ :type boto_config: botocore.client.Config
+ :param boto_config: Advanced boto3 client configuration options. If a value
+ is specified in the client config, its value will take precedence
+ over environment variables and configuration values, but not over
+ a value passed explicitly to the method. If a default config
+ object is set on the session, the config object used when creating
+ the client will be the result of calling ``merge()`` on the
+ default config with the config provided to this call.
+ """
+ self.bucket = bucket
+ self.prefix = prefix
+ self.region_name = region_name
+ self.api_version = api_version
+ self.use_ssl = use_ssl
+ self.verify = verify
+ self.endpoint_url = endpoint_url
+ self.aws_access_key_id = aws_access_key_id
+ self.aws_secret_access_key = aws_secret_access_key
+ self.aws_session_token = aws_session_token
+ self.boto_config = boto_config
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ import boto3
+ except ImportError:
+ raise ImportError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ s3 = boto3.resource(
+ "s3",
+ region_name=self.region_name,
+ api_version=self.api_version,
+ use_ssl=self.use_ssl,
+ verify=self.verify,
+ endpoint_url=self.endpoint_url,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
+ aws_session_token=self.aws_session_token,
+ config=self.boto_config,
+ )
+ bucket = s3.Bucket(self.bucket)
+ docs = []
+ for obj in bucket.objects.filter(Prefix=self.prefix):
+ loader = S3FileLoader(
+ self.bucket,
+ obj.key,
+ region_name=self.region_name,
+ api_version=self.api_version,
+ use_ssl=self.use_ssl,
+ verify=self.verify,
+ endpoint_url=self.endpoint_url,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
+ aws_session_token=self.aws_session_token,
+ boto_config=self.boto_config,
+ )
+ docs.extend(loader.load())
+ return docs
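A usage sketch for `S3DirectoryLoader`; the bucket, prefix, and region are placeholders, and credentials are assumed to come from the standard boto3 lookup chain (environment variables, shared credentials file, instance profile) when not passed explicitly.

```python
from langchain_community.document_loaders.s3_directory import S3DirectoryLoader

loader = S3DirectoryLoader(
    "my-example-bucket",        # placeholder bucket name
    prefix="reports/2023/",     # only keys under this prefix are loaded
    region_name="us-east-1",
)
docs = loader.load()
```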
diff --git a/libs/community/langchain_community/document_loaders/s3_file.py b/libs/community/langchain_community/document_loaders/s3_file.py
new file mode 100644
index 00000000000..eaca0761f2b
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/s3_file.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import os
+import tempfile
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from langchain_community.document_loaders.unstructured import UnstructuredBaseLoader
+
+if TYPE_CHECKING:
+ import botocore
+
+
+class S3FileLoader(UnstructuredBaseLoader):
+ """Load from `Amazon AWS S3` file."""
+
+ def __init__(
+ self,
+ bucket: str,
+ key: str,
+ *,
+ region_name: Optional[str] = None,
+ api_version: Optional[str] = None,
+ use_ssl: Optional[bool] = True,
+ verify: Union[str, bool, None] = None,
+ endpoint_url: Optional[str] = None,
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_session_token: Optional[str] = None,
+ boto_config: Optional[botocore.client.Config] = None,
+ ):
+ """Initialize with bucket and key name.
+
+ :param bucket: The name of the S3 bucket.
+ :param key: The key of the S3 object.
+
+ :param region_name: The name of the region associated with the client.
+ A client is associated with a single region.
+
+ :param api_version: The API version to use. By default, botocore will
+ use the latest API version when creating a client. You only need
+ to specify this parameter if you want to use a previous API version
+ of the client.
+
+ :param use_ssl: Whether or not to use SSL. By default, SSL is used.
+ Note that not all services support non-ssl connections.
+
+ :param verify: Whether or not to verify SSL certificates.
+ By default SSL certificates are verified. You can provide the
+ following values:
+
+ * False - do not validate SSL certificates. SSL will still be
+ used (unless use_ssl is False), but SSL certificates
+ will not be verified.
+ * path/to/cert/bundle.pem - A filename of the CA cert bundle to
+ use. You can specify this argument if you want to use a
+ different CA cert bundle than the one used by botocore.
+
+ :param endpoint_url: The complete URL to use for the constructed
+ client. Normally, botocore will automatically construct the
+ appropriate URL to use when communicating with a service. You can
+ specify a complete URL (including the "http/https" scheme) to
+ override this behavior. If this value is provided, then
+ ``use_ssl`` is ignored.
+
+ :param aws_access_key_id: The access key to use when creating
+ the client. This is entirely optional, and if not provided,
+ the credentials configured for the session will automatically
+ be used. You only need to provide this argument if you want
+ to override the credentials used for this specific client.
+
+ :param aws_secret_access_key: The secret key to use when creating
+ the client. Same semantics as aws_access_key_id above.
+
+ :param aws_session_token: The session token to use when creating
+ the client. Same semantics as aws_access_key_id above.
+
+ :type boto_config: botocore.client.Config
+ :param boto_config: Advanced boto3 client configuration options. If a value
+ is specified in the client config, its value will take precedence
+ over environment variables and configuration values, but not over
+ a value passed explicitly to the method. If a default config
+ object is set on the session, the config object used when creating
+ the client will be the result of calling ``merge()`` on the
+ default config with the config provided to this call.
+ """
+ super().__init__()
+ self.bucket = bucket
+ self.key = key
+ self.region_name = region_name
+ self.api_version = api_version
+ self.use_ssl = use_ssl
+ self.verify = verify
+ self.endpoint_url = endpoint_url
+ self.aws_access_key_id = aws_access_key_id
+ self.aws_secret_access_key = aws_secret_access_key
+ self.aws_session_token = aws_session_token
+ self.boto_config = boto_config
+
+ def _get_elements(self) -> List:
+ """Get elements."""
+ from unstructured.partition.auto import partition
+
+ try:
+ import boto3
+ except ImportError:
+ raise ImportError(
+ "Could not import `boto3` python package. "
+ "Please install it with `pip install boto3`."
+ )
+ s3 = boto3.client(
+ "s3",
+ region_name=self.region_name,
+ api_version=self.api_version,
+ use_ssl=self.use_ssl,
+ verify=self.verify,
+ endpoint_url=self.endpoint_url,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
+ aws_session_token=self.aws_session_token,
+ config=self.boto_config,
+ )
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.key}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ s3.download_file(self.bucket, self.key, file_path)
+ return partition(filename=file_path)
+
+ def _get_metadata(self) -> dict:
+ return {"source": f"s3://{self.bucket}/{self.key}"}
diff --git a/libs/community/langchain_community/document_loaders/sharepoint.py b/libs/community/langchain_community/document_loaders/sharepoint.py
new file mode 100644
index 00000000000..211760fe3f8
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/sharepoint.py
@@ -0,0 +1,60 @@
+"""Loader that loads data from Sharepoint Document Library"""
+from __future__ import annotations
+
+from typing import Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.document_loaders.base_o365 import (
+ O365BaseLoader,
+ _FileType,
+)
+from langchain_community.document_loaders.parsers.registry import get_parser
+
+
+class SharePointLoader(O365BaseLoader):
+ """Load from `SharePoint`."""
+
+ document_library_id: str = Field(...)
+ """ The ID of the SharePoint document library to load data from."""
+ folder_path: Optional[str] = None
+ """ The path to the folder to load data from."""
+ object_ids: Optional[List[str]] = None
+ """ The IDs of the objects to load data from."""
+
+ @property
+ def _file_types(self) -> Sequence[_FileType]:
+ """Return supported file types."""
+ return _FileType.DOC, _FileType.DOCX, _FileType.PDF
+
+ @property
+ def _scopes(self) -> List[str]:
+ """Return required scopes."""
+ return ["sharepoint", "basic"]
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents lazily. Use this when working at a large scale."""
+ try:
+ from O365.drive import Drive, Folder
+ except ImportError:
+ raise ImportError(
+ "O365 package not found, please install it with `pip install o365`"
+ )
+ drive = self._auth().storage().get_drive(self.document_library_id)
+ if not isinstance(drive, Drive):
+ raise ValueError(f"There isn't a Drive with id {self.document_library_id}.")
+ blob_parser = get_parser("default")
+ if self.folder_path:
+ target_folder = drive.get_item_by_path(self.folder_path)
+ if not isinstance(target_folder, Folder):
+ raise ValueError(f"There isn't a folder with path {self.folder_path}.")
+ for blob in self._load_from_folder(target_folder):
+ yield from blob_parser.lazy_parse(blob)
+ if self.object_ids:
+ for blob in self._load_from_object_ids(drive, self.object_ids):
+ yield from blob_parser.lazy_parse(blob)
+
+ def load(self) -> List[Document]:
+ """Load all documents."""
+ return list(self.lazy_load())
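A usage sketch for `SharePointLoader`; the document library ID and folder path are placeholders, and authentication is assumed to be configured through the shared O365 base loader (typically via environment variables).

```python
from langchain_community.document_loaders.sharepoint import SharePointLoader

loader = SharePointLoader(
    document_library_id="b!xxxxxxxxxxxxxxxxxxxx",  # hypothetical library ID
    folder_path="/Shared Documents/contracts",     # folder to load from
)
docs = loader.load()
```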
diff --git a/libs/community/langchain_community/document_loaders/sitemap.py b/libs/community/langchain_community/document_loaders/sitemap.py
new file mode 100644
index 00000000000..fe08acb9bb1
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/sitemap.py
@@ -0,0 +1,220 @@
+import itertools
+import re
+from typing import Any, Callable, Generator, Iterable, List, Optional, Tuple
+from urllib.parse import urlparse
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+def _default_parsing_function(content: Any) -> str:
+ return str(content.get_text())
+
+
+def _default_meta_function(meta: dict, _content: Any) -> dict:
+ return {"source": meta["loc"], **meta}
+
+
+def _batch_block(iterable: Iterable, size: int) -> Generator[List[dict], None, None]:
+ it = iter(iterable)
+ while item := list(itertools.islice(it, size)):
+ yield item
+
+
+def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
+ """Extract the scheme + domain from a given URL.
+
+ Args:
+ url (str): The input URL.
+
+ Returns:
+ A 2-tuple of (scheme, domain).
+ """
+ parsed_uri = urlparse(url)
+ return parsed_uri.scheme, parsed_uri.netloc
+
+
+class SitemapLoader(WebBaseLoader):
+ """Load a sitemap and its URLs.
+
+ **Security Note**: This loader can be used to load all URLs specified in a sitemap.
+ If a malicious actor gets access to the sitemap, they could force
+ the server to load URLs from other domains by modifying the sitemap.
+ This could lead to server-side request forgery (SSRF) attacks; e.g.,
+ with the attacker forcing the server to load URLs from internal
+ service endpoints that are not publicly accessible. While the attacker
+ may not immediately gain access to this data, this data could leak
+ into downstream systems (e.g., data loader is used to load data for indexing).
+
+ This loader is a crawler and web crawlers should generally NOT be deployed
+ with network access to any internal servers.
+
+ Control access to who can submit crawling requests and what network access
+ the crawler has.
+
+ By default, the loader will only load URLs from the same domain as the sitemap
+ if the site map is not a local file. This can be disabled by setting
+ restrict_to_same_domain to False (not recommended).
+
+ If the site map is a local file, no such risk mitigation is applied by default.
+
+ Use the filter URLs argument to limit which URLs can be loaded.
+
+ See https://python.langchain.com/docs/security
+ """
+
+ def __init__(
+ self,
+ web_path: str,
+ filter_urls: Optional[List[str]] = None,
+ parsing_function: Optional[Callable] = None,
+ blocksize: Optional[int] = None,
+ blocknum: int = 0,
+ meta_function: Optional[Callable] = None,
+ is_local: bool = False,
+ continue_on_failure: bool = False,
+ restrict_to_same_domain: bool = True,
+ **kwargs: Any,
+ ):
+ """Initialize with webpage path and optional filter URLs.
+
+ Args:
+ web_path: URL of the sitemap. Can also be a local path.
+ filter_urls: a list of regexes. If specified, only
+ URLS that match one of the filter URLs will be loaded.
+ *WARNING* The filter URLs are interpreted as regular expressions.
+ Remember to escape special characters if you do not want them to be
+ interpreted as regular expression syntax. For example, `.` appears
+ frequently in URLs and should be escaped if you want to match a literal
+ `.` rather than any character.
+ restrict_to_same_domain takes precedence over filter_urls when
+ restrict_to_same_domain is True and the sitemap is not a local file.
+ parsing_function: Function to parse bs4.Soup output
+ blocksize: number of sitemap locations per block
+ blocknum: the number of the block that should be loaded - zero indexed.
+ Default: 0
+ meta_function: Function to parse bs4.Soup output for metadata
+ remember when setting this method to also copy metadata["loc"]
+ to metadata["source"] if you are using this field
+ is_local: whether the sitemap is a local file. Default: False
+ continue_on_failure: whether to continue loading the sitemap if an error
+ occurs loading a url, emitting a warning instead of raising an
+ exception. Setting this to True makes the loader more robust, but also
+ may result in missing data. Default: False
+ restrict_to_same_domain: whether to restrict loading to URLs to the same
+ domain as the sitemap. Attention: This is only applied if the sitemap
+ is not a local file!
+ """
+
+ if blocksize is not None and blocksize < 1:
+ raise ValueError("Sitemap blocksize should be at least 1")
+
+ if blocknum < 0:
+ raise ValueError("Sitemap blocknum can not be lower then 0")
+
+ try:
+ import lxml # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "lxml package not found, please install it with `pip install lxml`"
+ )
+
+ super().__init__(web_paths=[web_path], **kwargs)
+
+ # Define a list of URL patterns (interpreted as regular expressions) that
+ # will be allowed to be loaded.
+ # restrict_to_same_domain takes precedence over filter_urls when
+ # restrict_to_same_domain is True and the sitemap is not a local file.
+ self.allow_url_patterns = filter_urls
+ self.restrict_to_same_domain = restrict_to_same_domain
+ self.parsing_function = parsing_function or _default_parsing_function
+ self.meta_function = meta_function or _default_meta_function
+ self.blocksize = blocksize
+ self.blocknum = blocknum
+ self.is_local = is_local
+ self.continue_on_failure = continue_on_failure
+
+ def parse_sitemap(self, soup: Any) -> List[dict]:
+ """Parse sitemap xml and load into a list of dicts.
+
+ Args:
+ soup: BeautifulSoup object.
+
+ Returns:
+ List of dicts.
+ """
+ els = []
+ for url in soup.find_all("url"):
+ loc = url.find("loc")
+ if not loc:
+ continue
+
+ # Strip leading and trailing whitespace and newlines
+ loc_text = loc.text.strip()
+
+ if self.restrict_to_same_domain and not self.is_local:
+ if _extract_scheme_and_domain(loc_text) != _extract_scheme_and_domain(
+ self.web_path
+ ):
+ continue
+
+ if self.allow_url_patterns and not any(
+ re.match(regexp_pattern, loc_text)
+ for regexp_pattern in self.allow_url_patterns
+ ):
+ continue
+
+ els.append(
+ {
+ tag: prop.text
+ for tag in ["loc", "lastmod", "changefreq", "priority"]
+ if (prop := url.find(tag))
+ }
+ )
+
+ for sitemap in soup.find_all("sitemap"):
+ loc = sitemap.find("loc")
+ if not loc:
+ continue
+ soup_child = self.scrape_all([loc.text], "xml")[0]
+
+ els.extend(self.parse_sitemap(soup_child))
+ return els
+
+ def load(self) -> List[Document]:
+ """Load sitemap."""
+ if self.is_local:
+ try:
+ import bs4
+ except ImportError:
+ raise ImportError(
+ "beautifulsoup4 package not found, please install it"
+ " with `pip install beautifulsoup4`"
+ )
+ with open(self.web_path) as fp:
+ soup = bs4.BeautifulSoup(fp, "xml")
+ else:
+ soup = self._scrape(self.web_path, parser="xml")
+
+ els = self.parse_sitemap(soup)
+
+ if self.blocksize is not None:
+ elblocks = list(_batch_block(els, self.blocksize))
+ blockcount = len(elblocks)
+ if blockcount - 1 < self.blocknum:
+ raise ValueError(
+ "Selected sitemap does not contain enough blocks for given blocknum"
+ )
+ else:
+ els = elblocks[self.blocknum]
+
+ results = self.scrape_all([el["loc"].strip() for el in els if "loc" in el])
+
+ return [
+ Document(
+ page_content=self.parsing_function(results[i]),
+ metadata=self.meta_function(els[i], results[i]),
+ )
+ for i in range(len(results))
+ ]
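A usage sketch for `SitemapLoader`; the sitemap URL and filter pattern are placeholders, and note that `filter_urls` entries are regular expressions, so literal dots should be escaped.

```python
from langchain_community.document_loaders.sitemap import SitemapLoader

loader = SitemapLoader(
    web_path="https://docs.example.com/sitemap.xml",          # hypothetical sitemap
    filter_urls=[r"https://docs\.example\.com/en/latest/"],   # regex, not a prefix
    continue_on_failure=True,
)
docs = loader.load()
```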
diff --git a/libs/community/langchain_community/document_loaders/slack_directory.py b/libs/community/langchain_community/document_loaders/slack_directory.py
new file mode 100644
index 00000000000..7c11535809f
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/slack_directory.py
@@ -0,0 +1,112 @@
+import json
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class SlackDirectoryLoader(BaseLoader):
+ """Load from a `Slack` directory dump."""
+
+ def __init__(self, zip_path: str, workspace_url: Optional[str] = None):
+ """Initialize the SlackDirectoryLoader.
+
+ Args:
+ zip_path (str): The path to the Slack directory dump zip file.
+ workspace_url (Optional[str]): The Slack workspace URL.
+ Including the URL will turn
+ sources into links. Defaults to None.
+ """
+ self.zip_path = Path(zip_path)
+ self.workspace_url = workspace_url
+ self.channel_id_map = self._get_channel_id_map(self.zip_path)
+
+ @staticmethod
+ def _get_channel_id_map(zip_path: Path) -> Dict[str, str]:
+ """Get a dictionary mapping channel names to their respective IDs."""
+ with zipfile.ZipFile(zip_path, "r") as zip_file:
+ try:
+ with zip_file.open("channels.json", "r") as f:
+ channels = json.load(f)
+ return {channel["name"]: channel["id"] for channel in channels}
+ except KeyError:
+ return {}
+
+ def load(self) -> List[Document]:
+ """Load and return documents from the Slack directory dump."""
+ docs = []
+ with zipfile.ZipFile(self.zip_path, "r") as zip_file:
+ for channel_path in zip_file.namelist():
+ channel_name = Path(channel_path).parent.name
+ if not channel_name:
+ continue
+ if channel_path.endswith(".json"):
+ messages = self._read_json(zip_file, channel_path)
+ for message in messages:
+ document = self._convert_message_to_document(
+ message, channel_name
+ )
+ docs.append(document)
+ return docs
+
+ def _read_json(self, zip_file: zipfile.ZipFile, file_path: str) -> List[dict]:
+ """Read JSON data from a zip subfile."""
+ with zip_file.open(file_path, "r") as f:
+ data = json.load(f)
+ return data
+
+ def _convert_message_to_document(
+ self, message: dict, channel_name: str
+ ) -> Document:
+ """
+ Convert a message to a Document object.
+
+ Args:
+ message (dict): A message in the form of a dictionary.
+ channel_name (str): The name of the channel the message belongs to.
+
+ Returns:
+ Document: A Document object representing the message.
+ """
+ text = message.get("text", "")
+ metadata = self._get_message_metadata(message, channel_name)
+ return Document(
+ page_content=text,
+ metadata=metadata,
+ )
+
+ def _get_message_metadata(self, message: dict, channel_name: str) -> dict:
+ """Create and return metadata for a given message and channel."""
+ timestamp = message.get("ts", "")
+ user = message.get("user", "")
+ source = self._get_message_source(channel_name, user, timestamp)
+ return {
+ "source": source,
+ "channel": channel_name,
+ "timestamp": timestamp,
+ "user": user,
+ }
+
+ def _get_message_source(self, channel_name: str, user: str, timestamp: str) -> str:
+ """
+ Get the message source as a string.
+
+ Args:
+ channel_name (str): The name of the channel the message belongs to.
+ user (str): The user ID who sent the message.
+ timestamp (str): The timestamp of the message.
+
+ Returns:
+ str: The message source.
+ """
+ if self.workspace_url:
+ channel_id = self.channel_id_map.get(channel_name, "")
+ return (
+ f"{self.workspace_url}/archives/{channel_id}"
+ + f"/p{timestamp.replace('.', '')}"
+ )
+ else:
+ return f"{channel_name} - {user} - {timestamp}"
diff --git a/libs/community/langchain_community/document_loaders/snowflake_loader.py b/libs/community/langchain_community/document_loaders/snowflake_loader.py
new file mode 100644
index 00000000000..c0e479f6105
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/snowflake_loader.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class SnowflakeLoader(BaseLoader):
+ """Load from `Snowflake` API.
+
+ Each document represents one row of the result. The `page_content_columns`
+ are written into the `page_content` of the document. The `metadata_columns`
+ are written into the `metadata` of the document. By default, all columns
+ are written into the `page_content` and none into the `metadata`.
+
+ """
+
+ def __init__(
+ self,
+ query: str,
+ user: str,
+ password: str,
+ account: str,
+ warehouse: str,
+ role: str,
+ database: str,
+ schema: str,
+ parameters: Optional[Dict[str, Any]] = None,
+ page_content_columns: Optional[List[str]] = None,
+ metadata_columns: Optional[List[str]] = None,
+ ):
+ """Initialize Snowflake document loader.
+
+ Args:
+ query: The query to run in Snowflake.
+ user: Snowflake user.
+ password: Snowflake password.
+ account: Snowflake account.
+ warehouse: Snowflake warehouse.
+ role: Snowflake role.
+ database: Snowflake database.
+ schema: Snowflake schema.
+ parameters: Optional. Parameters to pass to the query.
+ page_content_columns: Optional. Columns written to Document `page_content`.
+ metadata_columns: Optional. Columns written to Document `metadata`.
+ """
+ self.query = query
+ self.user = user
+ self.password = password
+ self.account = account
+ self.warehouse = warehouse
+ self.role = role
+ self.database = database
+ self.schema = schema
+ self.parameters = parameters
+ self.page_content_columns = (
+ page_content_columns if page_content_columns is not None else ["*"]
+ )
+ self.metadata_columns = metadata_columns if metadata_columns is not None else []
+
+ def _execute_query(self) -> List[Dict[str, Any]]:
+ try:
+ import snowflake.connector
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import snowflake-connector-python package. "
+ "Please install it with `pip install snowflake-connector-python`."
+ ) from ex
+
+ conn = snowflake.connector.connect(
+ user=self.user,
+ password=self.password,
+ account=self.account,
+ warehouse=self.warehouse,
+ role=self.role,
+ database=self.database,
+ schema=self.schema,
+ parameters=self.parameters,
+ )
+ try:
+ cur = conn.cursor()
+ cur.execute("USE DATABASE " + self.database)
+ cur.execute("USE SCHEMA " + self.schema)
+ cur.execute(self.query, self.parameters)
+ query_result = cur.fetchall()
+ column_names = [column[0] for column in cur.description]
+ query_result = [dict(zip(column_names, row)) for row in query_result]
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ query_result = []
+ finally:
+ cur.close()
+ return query_result
+
+ def _get_columns(
+ self, query_result: List[Dict[str, Any]]
+ ) -> Tuple[List[str], List[str]]:
+ page_content_columns = (
+ self.page_content_columns if self.page_content_columns else []
+ )
+ metadata_columns = self.metadata_columns if self.metadata_columns else []
+ if page_content_columns is None and query_result:
+ page_content_columns = list(query_result[0].keys())
+ if metadata_columns is None:
+ metadata_columns = []
+ return page_content_columns or [], metadata_columns
+
+ def lazy_load(self) -> Iterator[Document]:
+ query_result = self._execute_query()
+ if isinstance(query_result, Exception):
+ print(f"An error occurred during the query: {query_result}")
+ return []
+ page_content_columns, metadata_columns = self._get_columns(query_result)
+ if "*" in page_content_columns:
+ page_content_columns = list(query_result[0].keys())
+ for row in query_result:
+ page_content = "\n".join(
+ f"{k}: {v}" for k, v in row.items() if k in page_content_columns
+ )
+ metadata = {k: v for k, v in row.items() if k in metadata_columns}
+ doc = Document(page_content=page_content, metadata=metadata)
+ yield doc
+
+ def load(self) -> List[Document]:
+ """Load data into document objects."""
+ return list(self.lazy_load())
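A usage sketch for `SnowflakeLoader`; all connection values and the table/column names are placeholders. The uppercase column names assume Snowflake's default behavior of returning unquoted identifiers in upper case.

```python
from langchain_community.document_loaders.snowflake_loader import SnowflakeLoader

loader = SnowflakeLoader(
    query="SELECT title, body, url FROM articles LIMIT 10",
    user="YOUR_USER",
    password="YOUR_PASSWORD",
    account="YOUR_ACCOUNT",
    warehouse="YOUR_WAREHOUSE",
    role="YOUR_ROLE",
    database="YOUR_DATABASE",
    schema="YOUR_SCHEMA",
    page_content_columns=["TITLE", "BODY"],  # unquoted identifiers come back uppercase
    metadata_columns=["URL"],
)
docs = loader.load()
```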
diff --git a/libs/community/langchain_community/document_loaders/spreedly.py b/libs/community/langchain_community/document_loaders/spreedly.py
new file mode 100644
index 00000000000..a5af492255c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/spreedly.py
@@ -0,0 +1,55 @@
+import json
+import urllib.request
+from typing import List
+
+from langchain_core.documents import Document
+from langchain_core.utils import stringify_dict
+
+from langchain_community.document_loaders.base import BaseLoader
+
+SPREEDLY_ENDPOINTS = {
+ "gateways_options": "https://core.spreedly.com/v1/gateways_options.json",
+ "gateways": "https://core.spreedly.com/v1/gateways.json",
+ "receivers_options": "https://core.spreedly.com/v1/receivers_options.json",
+ "receivers": "https://core.spreedly.com/v1/receivers.json",
+ "payment_methods": "https://core.spreedly.com/v1/payment_methods.json",
+ "certificates": "https://core.spreedly.com/v1/certificates.json",
+ "transactions": "https://core.spreedly.com/v1/transactions.json",
+ "environments": "https://core.spreedly.com/v1/environments.json",
+}
+
+
+class SpreedlyLoader(BaseLoader):
+ """Load from `Spreedly` API."""
+
+ def __init__(self, access_token: str, resource: str) -> None:
+ """Initialize with an access token and a resource.
+
+ Args:
+ access_token: The access token.
+ resource: The resource.
+ """
+ self.access_token = access_token
+ self.resource = resource
+ self.headers = {
+ "Authorization": f"Bearer {self.access_token}",
+ "Accept": "application/json",
+ }
+
+ def _make_request(self, url: str) -> List[Document]:
+ request = urllib.request.Request(url, headers=self.headers)
+
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ text = stringify_dict(json_data)
+ metadata = {"source": url}
+ return [Document(page_content=text, metadata=metadata)]
+
+ def _get_resource(self) -> List[Document]:
+ endpoint = SPREEDLY_ENDPOINTS.get(self.resource)
+ if endpoint is None:
+ return []
+ return self._make_request(endpoint)
+
+ def load(self) -> List[Document]:
+ return self._get_resource()
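A usage sketch for `SpreedlyLoader`; the access token is a placeholder, and `resource` must be one of the keys defined in `SPREEDLY_ENDPOINTS` above.

```python
from langchain_community.document_loaders.spreedly import SpreedlyLoader

loader = SpreedlyLoader(
    access_token="YOUR_SPREEDLY_TOKEN",  # placeholder token
    resource="gateways_options",
)
docs = loader.load()
```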
diff --git a/libs/community/langchain_community/document_loaders/srt.py b/libs/community/langchain_community/document_loaders/srt.py
new file mode 100644
index 00000000000..32acecc41cd
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/srt.py
@@ -0,0 +1,28 @@
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class SRTLoader(BaseLoader):
+ """Load `.srt` (subtitle) files."""
+
+ def __init__(self, file_path: str):
+ """Initialize with a file path."""
+ try:
+ import pysrt # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "package `pysrt` not found, please install it with `pip install pysrt`"
+ )
+ self.file_path = file_path
+
+ def load(self) -> List[Document]:
+ """Load using pysrt file."""
+ import pysrt
+
+ parsed_info = pysrt.open(self.file_path)
+ text = " ".join([t.text for t in parsed_info])
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
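A short usage sketch for `SRTLoader`; the file path is a placeholder and `pysrt` must be installed.

```python
from langchain_community.document_loaders.srt import SRTLoader

loader = SRTLoader("path/to/subtitles.srt")  # placeholder path
docs = loader.load()
```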
diff --git a/libs/community/langchain_community/document_loaders/stripe.py b/libs/community/langchain_community/document_loaders/stripe.py
new file mode 100644
index 00000000000..51bd04962ed
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/stripe.py
@@ -0,0 +1,52 @@
+import json
+import urllib.request
+from typing import List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_env, stringify_dict
+
+from langchain_community.document_loaders.base import BaseLoader
+
+STRIPE_ENDPOINTS = {
+ "balance_transactions": "https://api.stripe.com/v1/balance_transactions",
+ "charges": "https://api.stripe.com/v1/charges",
+ "customers": "https://api.stripe.com/v1/customers",
+ "events": "https://api.stripe.com/v1/events",
+ "refunds": "https://api.stripe.com/v1/refunds",
+ "disputes": "https://api.stripe.com/v1/disputes",
+}
+
+
+class StripeLoader(BaseLoader):
+ """Load from `Stripe` API."""
+
+ def __init__(self, resource: str, access_token: Optional[str] = None) -> None:
+ """Initialize with a resource and an access token.
+
+ Args:
+ resource: The resource.
+ access_token: The access token.
+ """
+ self.resource = resource
+ access_token = access_token or get_from_env(
+ "access_token", "STRIPE_ACCESS_TOKEN"
+ )
+ self.headers = {"Authorization": f"Bearer {access_token}"}
+
+ def _make_request(self, url: str) -> List[Document]:
+ request = urllib.request.Request(url, headers=self.headers)
+
+ with urllib.request.urlopen(request) as response:
+ json_data = json.loads(response.read().decode())
+ text = stringify_dict(json_data)
+ metadata = {"source": url}
+ return [Document(page_content=text, metadata=metadata)]
+
+ def _get_resource(self) -> List[Document]:
+ endpoint = STRIPE_ENDPOINTS.get(self.resource)
+ if endpoint is None:
+ return []
+ return self._make_request(endpoint)
+
+ def load(self) -> List[Document]:
+ return self._get_resource()
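+
+
+# Illustrative usage sketch: the resource must be a key of STRIPE_ENDPOINTS, and the
+# access token is read from the STRIPE_ACCESS_TOKEN environment variable when it is
+# not passed explicitly.
+if __name__ == "__main__":
+    loader = StripeLoader("charges")
+    print(len(loader.load()))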
diff --git a/libs/community/langchain_community/document_loaders/telegram.py b/libs/community/langchain_community/document_loaders/telegram.py
new file mode 100644
index 00000000000..5a9e3b06fd0
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/telegram.py
@@ -0,0 +1,263 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import pandas as pd
+ from telethon.hints import EntityLike
+
+
+def concatenate_rows(row: dict) -> str:
+ """Combine message information in a readable format ready to be used."""
+ date = row["date"]
+ sender = row["from"]
+ text = row["text"]
+ return f"{sender} on {date}: {text}\n\n"
+
+
+class TelegramChatFileLoader(BaseLoader):
+ """Load from `Telegram chat` dump."""
+
+ def __init__(self, path: str):
+ """Initialize with a path."""
+ self.file_path = path
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ p = Path(self.file_path)
+
+ with open(p, encoding="utf8") as f:
+ d = json.load(f)
+
+ text = "".join(
+ concatenate_rows(message)
+ for message in d["messages"]
+ if message["type"] == "message" and isinstance(message["text"], str)
+ )
+ metadata = {"source": str(p)}
+
+ return [Document(page_content=text, metadata=metadata)]
+
+
+def text_to_docs(text: Union[str, List[str]]) -> List[Document]:
+ """Convert a string or list of strings to a list of Documents with metadata."""
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ if isinstance(text, str):
+ # Take a single string as one page
+ text = [text]
+ page_docs = [Document(page_content=page) for page in text]
+
+ # Add page numbers as metadata
+ for i, doc in enumerate(page_docs):
+ doc.metadata["page"] = i + 1
+
+ # Split pages into chunks
+ doc_chunks = []
+
+ for doc in page_docs:
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=800,
+ separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
+ chunk_overlap=20,
+ )
+ chunks = text_splitter.split_text(doc.page_content)
+ for i, chunk in enumerate(chunks):
+ doc = Document(
+ page_content=chunk, metadata={"page": doc.metadata["page"], "chunk": i}
+ )
+            # Add sources as metadata
+ doc.metadata["source"] = f"{doc.metadata['page']}-{doc.metadata['chunk']}"
+ doc_chunks.append(doc)
+ return doc_chunks
+
+
+class TelegramChatApiLoader(BaseLoader):
+ """Load `Telegram` chat json directory dump."""
+
+ def __init__(
+ self,
+ chat_entity: Optional[EntityLike] = None,
+ api_id: Optional[int] = None,
+ api_hash: Optional[str] = None,
+ username: Optional[str] = None,
+ file_path: str = "telegram_data.json",
+ ):
+ """Initialize with API parameters.
+
+ Args:
+ chat_entity: The chat entity to fetch data from.
+ api_id: The API ID.
+ api_hash: The API hash.
+ username: The username.
+ file_path: The file path to save the data to. Defaults to
+ "telegram_data.json".
+ """
+ self.chat_entity = chat_entity
+ self.api_id = api_id
+ self.api_hash = api_hash
+ self.username = username
+ self.file_path = file_path
+
+ async def fetch_data_from_telegram(self) -> None:
+ """Fetch data from Telegram API and save it as a JSON file."""
+ from telethon.sync import TelegramClient
+
+ data = []
+ async with TelegramClient(self.username, self.api_id, self.api_hash) as client:
+ async for message in client.iter_messages(self.chat_entity):
+ is_reply = message.reply_to is not None
+ reply_to_id = message.reply_to.reply_to_msg_id if is_reply else None
+ data.append(
+ {
+ "sender_id": message.sender_id,
+ "text": message.text,
+ "date": message.date.isoformat(),
+ "message.id": message.id,
+ "is_reply": is_reply,
+ "reply_to_id": reply_to_id,
+ }
+ )
+
+ with open(self.file_path, "w", encoding="utf-8") as f:
+ json.dump(data, f, ensure_ascii=False, indent=4)
+
+ def _get_message_threads(self, data: pd.DataFrame) -> dict:
+ """Create a dictionary of message threads from the given data.
+
+ Args:
+ data (pd.DataFrame): A DataFrame containing the conversation \
+ data with columns:
+ - message.sender_id
+ - text
+ - date
+ - message.id
+ - is_reply
+ - reply_to_id
+
+ Returns:
+ dict: A dictionary where the key is the parent message ID and \
+ the value is a list of message IDs in ascending order.
+ """
+
+ def find_replies(parent_id: int, reply_data: pd.DataFrame) -> List[int]:
+ """
+ Recursively find all replies to a given parent message ID.
+
+ Args:
+ parent_id (int): The parent message ID.
+ reply_data (pd.DataFrame): A DataFrame containing reply messages.
+
+ Returns:
+ list: A list of message IDs that are replies to the parent message ID.
+ """
+ # Find direct replies to the parent message ID
+ direct_replies = reply_data[reply_data["reply_to_id"] == parent_id][
+ "message.id"
+ ].tolist()
+
+ # Recursively find replies to the direct replies
+ all_replies = []
+ for reply_id in direct_replies:
+ all_replies += [reply_id] + find_replies(reply_id, reply_data)
+
+ return all_replies
+
+ # Filter out parent messages
+ parent_messages = data[~data["is_reply"]]
+
+ # Filter out reply messages and drop rows with NaN in 'reply_to_id'
+ reply_messages = data[data["is_reply"]].dropna(subset=["reply_to_id"])
+
+ # Convert 'reply_to_id' to integer
+ reply_messages["reply_to_id"] = reply_messages["reply_to_id"].astype(int)
+
+        # Create a dictionary of message threads with parent message IDs as keys
+        # and lists of reply message IDs as values
+ message_threads = {
+ parent_id: [parent_id] + find_replies(parent_id, reply_messages)
+ for parent_id in parent_messages["message.id"]
+ }
+
+ return message_threads
+
+ def _combine_message_texts(
+ self, message_threads: Dict[int, List[int]], data: pd.DataFrame
+ ) -> str:
+ """
+ Combine the message texts for each parent message ID based \
+ on the list of message threads.
+
+ Args:
+ message_threads (dict): A dictionary where the key is the parent message \
+ ID and the value is a list of message IDs in ascending order.
+ data (pd.DataFrame): A DataFrame containing the conversation data:
+ - message.sender_id
+ - text
+ - date
+ - message.id
+ - is_reply
+ - reply_to_id
+
+ Returns:
+ str: A combined string of message texts sorted by date.
+ """
+ combined_text = ""
+
+ # Iterate through sorted parent message IDs
+ for parent_id, message_ids in message_threads.items():
+ # Get the message texts for the message IDs and sort them by date
+ message_texts = (
+ data[data["message.id"].isin(message_ids)]
+ .sort_values(by="date")["text"]
+ .tolist()
+ )
+ message_texts = [str(elem) for elem in message_texts]
+
+ # Combine the message texts
+ combined_text += " ".join(message_texts) + ".\n"
+
+ return combined_text.strip()
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+
+ if self.chat_entity is not None:
+ try:
+ import nest_asyncio
+
+ nest_asyncio.apply()
+ asyncio.run(self.fetch_data_from_telegram())
+            except ImportError:
+                raise ImportError(
+                    "`nest_asyncio` package not found, please install it with "
+                    "`pip install nest_asyncio`"
+                )
+
+ p = Path(self.file_path)
+
+ with open(p, encoding="utf8") as f:
+ d = json.load(f)
+ try:
+ import pandas as pd
+        except ImportError:
+            raise ImportError(
+                "`pandas` package not found, please install it with "
+                "`pip install pandas`"
+            )
+ normalized_messages = pd.json_normalize(d)
+ df = pd.DataFrame(normalized_messages)
+
+ message_threads = self._get_message_threads(df)
+ combined_texts = self._combine_message_texts(message_threads, df)
+
+ return text_to_docs(combined_texts)
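+
+
+# Illustrative usage sketches (placeholder values). TelegramChatFileLoader reads an
+# exported chat JSON file; TelegramChatApiLoader fetches a chat via Telethon using
+# your own api_id/api_hash and then loads it from the saved JSON file.
+if __name__ == "__main__":
+    file_loader = TelegramChatFileLoader("telegram_chat_export.json")
+    print(len(file_loader.load()))
+
+    api_loader = TelegramChatApiLoader(
+        chat_entity="<chat username or id>",  # placeholder
+        api_id=12345,  # placeholder
+        api_hash="YOUR_API_HASH",  # placeholder
+        username="session_name",  # placeholder
+    )
+    print(len(api_loader.load()))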
diff --git a/libs/community/langchain_community/document_loaders/tencent_cos_directory.py b/libs/community/langchain_community/document_loaders/tencent_cos_directory.py
new file mode 100644
index 00000000000..b58bfb674ff
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/tencent_cos_directory.py
@@ -0,0 +1,50 @@
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.tencent_cos_file import TencentCOSFileLoader
+
+
+class TencentCOSDirectoryLoader(BaseLoader):
+ """Load from `Tencent Cloud COS` directory."""
+
+ def __init__(self, conf: Any, bucket: str, prefix: str = ""):
+ """Initialize with COS config, bucket and prefix.
+ :param conf(CosConfig): COS config.
+ :param bucket(str): COS bucket.
+ :param prefix(str): prefix.
+ """
+ self.conf = conf
+ self.bucket = bucket
+ self.prefix = prefix
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents."""
+ try:
+ from qcloud_cos import CosS3Client
+ except ImportError:
+ raise ImportError(
+ "Could not import cos-python-sdk-v5 python package. "
+ "Please install it with `pip install cos-python-sdk-v5`."
+ )
+ client = CosS3Client(self.conf)
+ contents = []
+ marker = ""
+ while True:
+ response = client.list_objects(
+ Bucket=self.bucket, Prefix=self.prefix, Marker=marker, MaxKeys=1000
+ )
+ if "Contents" in response:
+ contents.extend(response["Contents"])
+ if response["IsTruncated"] == "false":
+ break
+ marker = response["NextMarker"]
+ for content in contents:
+ if content["Key"].endswith("/"):
+ continue
+ loader = TencentCOSFileLoader(self.conf, self.bucket, content["Key"])
+ yield loader.load()[0]
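+
+
+# Illustrative usage sketch (placeholder credentials, bucket and prefix); CosConfig
+# comes from the cos-python-sdk-v5 package this loader depends on.
+if __name__ == "__main__":
+    from qcloud_cos import CosConfig
+
+    conf = CosConfig(
+        Region="ap-guangzhou", SecretId="SECRET_ID", SecretKey="SECRET_KEY"
+    )
+    loader = TencentCOSDirectoryLoader(conf=conf, bucket="my-bucket", prefix="docs/")
+    for doc in loader.lazy_load():
+        print(doc.metadata)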
diff --git a/libs/community/langchain_community/document_loaders/tencent_cos_file.py b/libs/community/langchain_community/document_loaders/tencent_cos_file.py
new file mode 100644
index 00000000000..8a8c622dd74
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/tencent_cos_file.py
@@ -0,0 +1,48 @@
+import os
+import tempfile
+from typing import Any, Iterator, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class TencentCOSFileLoader(BaseLoader):
+ """Load from `Tencent Cloud COS` file."""
+
+ def __init__(self, conf: Any, bucket: str, key: str):
+ """Initialize with COS config, bucket and key name.
+ :param conf(CosConfig): COS config.
+ :param bucket(str): COS bucket.
+ :param key(str): COS file key.
+ """
+ self.conf = conf
+ self.bucket = bucket
+ self.key = key
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Load documents."""
+ try:
+ from qcloud_cos import CosS3Client
+ except ImportError:
+ raise ImportError(
+ "Could not import cos-python-sdk-v5 python package. "
+ "Please install it with `pip install cos-python-sdk-v5`."
+ )
+
+ # Initialise a client
+ client = CosS3Client(self.conf)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ file_path = f"{temp_dir}/{self.bucket}/{self.key}"
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ # Download the file to a destination
+ client.download_file(
+ Bucket=self.bucket, Key=self.key, DestFilePath=file_path
+ )
+ loader = UnstructuredFileLoader(file_path)
+            # UnstructuredFileLoader does not implement lazy_load yet
+ return iter(loader.load())
diff --git a/libs/community/langchain_community/document_loaders/tensorflow_datasets.py b/libs/community/langchain_community/document_loaders/tensorflow_datasets.py
new file mode 100644
index 00000000000..87dad82e648
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/tensorflow_datasets.py
@@ -0,0 +1,80 @@
+from typing import Callable, Dict, Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.tensorflow_datasets import TensorflowDatasets
+
+
+class TensorflowDatasetLoader(BaseLoader):
+ """Load from `TensorFlow Dataset`.
+
+ Attributes:
+ dataset_name: the name of the dataset to load
+ split_name: the name of the split to load.
+ load_max_docs: a limit to the number of loaded documents. Defaults to 100.
+ sample_to_document_function: a function that converts a dataset sample
+ into a Document
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import TensorflowDatasetLoader
+
+ def mlqaen_example_to_document(example: dict) -> Document:
+ return Document(
+ page_content=decode_to_str(example["context"]),
+ metadata={
+ "id": decode_to_str(example["id"]),
+ "title": decode_to_str(example["title"]),
+ "question": decode_to_str(example["question"]),
+ "answer": decode_to_str(example["answers"]["text"][0]),
+ },
+ )
+
+ tsds_client = TensorflowDatasetLoader(
+ dataset_name="mlqa/en",
+ split_name="test",
+ load_max_docs=100,
+ sample_to_document_function=mlqaen_example_to_document,
+ )
+
+ """
+
+ def __init__(
+ self,
+ dataset_name: str,
+ split_name: str,
+ load_max_docs: Optional[int] = 100,
+ sample_to_document_function: Optional[Callable[[Dict], Document]] = None,
+ ):
+ """Initialize the TensorflowDatasetLoader.
+
+ Args:
+ dataset_name: the name of the dataset to load
+ split_name: the name of the split to load.
+ load_max_docs: a limit to the number of loaded documents. Defaults to 100.
+ sample_to_document_function: a function that converts a dataset sample
+ into a Document.
+ """
+ self.dataset_name: str = dataset_name
+ self.split_name: str = split_name
+ self.load_max_docs = load_max_docs
+ """The maximum number of documents to load."""
+ self.sample_to_document_function: Optional[
+ Callable[[Dict], Document]
+ ] = sample_to_document_function
+ """Custom function that transform a dataset sample into a Document."""
+
+ self._tfds_client = TensorflowDatasets(
+ dataset_name=self.dataset_name,
+ split_name=self.split_name,
+ load_max_docs=self.load_max_docs,
+ sample_to_document_function=self.sample_to_document_function,
+ )
+
+ def lazy_load(self) -> Iterator[Document]:
+ yield from self._tfds_client.lazy_load()
+
+ def load(self) -> List[Document]:
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/text.py b/libs/community/langchain_community/document_loaders/text.py
new file mode 100644
index 00000000000..f497b16295c
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/text.py
@@ -0,0 +1,60 @@
+import logging
+from typing import List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.helpers import detect_file_encodings
+
+logger = logging.getLogger(__name__)
+
+
+class TextLoader(BaseLoader):
+ """Load text file.
+
+ Args:
+ file_path: Path to the file to load.
+
+ encoding: File encoding to use. If `None`, the file will be loaded
+ with the default system encoding.
+
+ autodetect_encoding: Whether to try to autodetect the file encoding
+ if the specified encoding fails.
+ """
+
+ def __init__(
+ self,
+ file_path: str,
+ encoding: Optional[str] = None,
+ autodetect_encoding: bool = False,
+ ):
+ """Initialize with file path."""
+ self.file_path = file_path
+ self.encoding = encoding
+ self.autodetect_encoding = autodetect_encoding
+
+ def load(self) -> List[Document]:
+ """Load from file path."""
+ text = ""
+ try:
+ with open(self.file_path, encoding=self.encoding) as f:
+ text = f.read()
+ except UnicodeDecodeError as e:
+ if self.autodetect_encoding:
+ detected_encodings = detect_file_encodings(self.file_path)
+ for encoding in detected_encodings:
+ logger.debug(f"Trying encoding: {encoding.encoding}")
+ try:
+ with open(self.file_path, encoding=encoding.encoding) as f:
+ text = f.read()
+ break
+ except UnicodeDecodeError:
+ continue
+ else:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+ except Exception as e:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
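+
+
+# Illustrative usage sketch (placeholder path): enabling autodetect_encoding lets the
+# loader fall back to detected encodings if decoding with the default encoding fails.
+if __name__ == "__main__":
+    loader = TextLoader("example.txt", autodetect_encoding=True)
+    print(loader.load()[0].metadata)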
diff --git a/libs/community/langchain_community/document_loaders/tomarkdown.py b/libs/community/langchain_community/document_loaders/tomarkdown.py
new file mode 100644
index 00000000000..ee384f81327
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/tomarkdown.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from typing import Iterator, List
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class ToMarkdownLoader(BaseLoader):
+ """Load `HTML` using `2markdown API`."""
+
+ def __init__(self, url: str, api_key: str):
+ """Initialize with url and api key."""
+ self.url = url
+ self.api_key = api_key
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazily load the file."""
+ response = requests.post(
+ "https://2markdown.com/api/2md",
+ headers={"X-Api-Key": self.api_key},
+ json={"url": self.url},
+ )
+ text = response.json()["article"]
+ metadata = {"source": self.url}
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load file."""
+ return list(self.lazy_load())
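+
+
+# Illustrative usage sketch (placeholder URL and API key): converts the page at
+# `url` to Markdown via the 2markdown API and yields a single Document.
+if __name__ == "__main__":
+    loader = ToMarkdownLoader(
+        url="https://python.langchain.com/", api_key="MY_2MARKDOWN_KEY"
+    )
+    for doc in loader.lazy_load():
+        print(doc.page_content[:200])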
diff --git a/libs/community/langchain_community/document_loaders/toml.py b/libs/community/langchain_community/document_loaders/toml.py
new file mode 100644
index 00000000000..dcb59d41f43
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/toml.py
@@ -0,0 +1,47 @@
+import json
+from pathlib import Path
+from typing import Iterator, List, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class TomlLoader(BaseLoader):
+ """Load `TOML` files.
+
+ It can load a single source file or several files in a single
+ directory.
+ """
+
+ def __init__(self, source: Union[str, Path]):
+ """Initialize the TomlLoader with a source file or directory."""
+ self.source = Path(source)
+
+ def load(self) -> List[Document]:
+ """Load and return all documents."""
+ return list(self.lazy_load())
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazily load the TOML documents from the source file or directory."""
+ import tomli
+
+ if self.source.is_file() and self.source.suffix == ".toml":
+ files = [self.source]
+ elif self.source.is_dir():
+ files = list(self.source.glob("**/*.toml"))
+ else:
+ raise ValueError("Invalid source path or file type")
+
+ for file_path in files:
+ with file_path.open("r", encoding="utf-8") as file:
+ content = file.read()
+ try:
+ data = tomli.loads(content)
+ doc = Document(
+ page_content=json.dumps(data),
+ metadata={"source": str(file_path)},
+ )
+ yield doc
+ except tomli.TOMLDecodeError as e:
+ print(f"Error parsing TOML file {file_path}: {e}")
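+
+
+# Illustrative usage sketch (placeholder path): the source may be a single .toml file
+# or a directory that is searched recursively for .toml files.
+if __name__ == "__main__":
+    loader = TomlLoader("pyproject.toml")
+    for doc in loader.lazy_load():
+        print(doc.metadata["source"])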
diff --git a/libs/community/langchain_community/document_loaders/trello.py b/libs/community/langchain_community/document_loaders/trello.py
new file mode 100644
index 00000000000..4fb7a47426a
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/trello.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.utils import get_from_env
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ from trello import Board, Card, TrelloClient
+
+
+class TrelloLoader(BaseLoader):
+ """Load cards from a `Trello` board."""
+
+ def __init__(
+ self,
+ client: TrelloClient,
+ board_name: str,
+ *,
+ include_card_name: bool = True,
+ include_comments: bool = True,
+ include_checklist: bool = True,
+ card_filter: Literal["closed", "open", "all"] = "all",
+ extra_metadata: Tuple[str, ...] = ("due_date", "labels", "list", "closed"),
+ ):
+ """Initialize Trello loader.
+
+ Args:
+ client: Trello API client.
+ board_name: The name of the Trello board.
+ include_card_name: Whether to include the name of the card in the document.
+ include_comments: Whether to include the comments on the card in the
+ document.
+ include_checklist: Whether to include the checklist on the card in the
+ document.
+ card_filter: Filter on card status. Valid values are "closed", "open",
+ "all".
+ extra_metadata: List of additional metadata fields to include as document
+                metadata. Valid values are "due_date", "labels", "list", "closed".
+
+ """
+ self.client = client
+ self.board_name = board_name
+ self.include_card_name = include_card_name
+ self.include_comments = include_comments
+ self.include_checklist = include_checklist
+ self.extra_metadata = extra_metadata
+ self.card_filter = card_filter
+
+ @classmethod
+ def from_credentials(
+ cls,
+ board_name: str,
+ *,
+ api_key: Optional[str] = None,
+ token: Optional[str] = None,
+ **kwargs: Any,
+ ) -> TrelloLoader:
+ """Convenience constructor that builds TrelloClient init param for you.
+
+ Args:
+ board_name: The name of the Trello board.
+ api_key: Trello API key. Can also be specified as environment variable
+ TRELLO_API_KEY.
+ token: Trello token. Can also be specified as environment variable
+ TRELLO_TOKEN.
+ include_card_name: Whether to include the name of the card in the document.
+ include_comments: Whether to include the comments on the card in the
+ document.
+ include_checklist: Whether to include the checklist on the card in the
+ document.
+ card_filter: Filter on card status. Valid values are "closed", "open",
+ "all".
+ extra_metadata: List of additional metadata fields to include as document
+                metadata. Valid values are "due_date", "labels", "list", "closed".
+ """
+
+ try:
+ from trello import TrelloClient # type: ignore
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import trello python package. "
+ "Please install it with `pip install py-trello`."
+ ) from ex
+ api_key = api_key or get_from_env("api_key", "TRELLO_API_KEY")
+ token = token or get_from_env("token", "TRELLO_TOKEN")
+ client = TrelloClient(api_key=api_key, token=token)
+ return cls(client, board_name, **kwargs)
+
+ def load(self) -> List[Document]:
+ """Loads all cards from the specified Trello board.
+
+ You can filter the cards, metadata and text included by using the optional
+ parameters.
+
+ Returns:
+ A list of documents, one for each card in the board.
+ """
+ try:
+ from bs4 import BeautifulSoup # noqa: F401
+ except ImportError as ex:
+ raise ImportError(
+ "`beautifulsoup4` package not found, please run"
+ " `pip install beautifulsoup4`"
+ ) from ex
+
+ board = self._get_board()
+ # Create a dictionary with the list IDs as keys and the list names as values
+ list_dict = {list_item.id: list_item.name for list_item in board.list_lists()}
+ # Get Cards on the board
+ cards = board.get_cards(card_filter=self.card_filter)
+ return [self._card_to_doc(card, list_dict) for card in cards]
+
+ def _get_board(self) -> Board:
+ # Find the first board with a matching name
+ board = next(
+ (b for b in self.client.list_boards() if b.name == self.board_name), None
+ )
+ if not board:
+ raise ValueError(f"Board `{self.board_name}` not found.")
+ return board
+
+ def _card_to_doc(self, card: Card, list_dict: dict) -> Document:
+ from bs4 import BeautifulSoup # type: ignore
+
+ text_content = ""
+ if self.include_card_name:
+ text_content = card.name + "\n"
+ if card.description.strip():
+ text_content += BeautifulSoup(card.description, "lxml").get_text()
+ if self.include_checklist:
+ # Get all the checklist items on the card
+ for checklist in card.checklists:
+ if checklist.items:
+ items = [
+ f"{item['name']}:{item['state']}" for item in checklist.items
+ ]
+ text_content += f"\n{checklist.name}\n" + "\n".join(items)
+
+ if self.include_comments:
+ # Get all the comments on the card
+ comments = [
+ BeautifulSoup(comment["data"]["text"], "lxml").get_text()
+ for comment in card.comments
+ ]
+ text_content += "Comments:" + "\n".join(comments)
+
+ # Default metadata fields
+ metadata = {
+ "title": card.name,
+ "id": card.id,
+ "url": card.url,
+ }
+
+ # Extra metadata fields. Card object is not subscriptable.
+ if "labels" in self.extra_metadata:
+ metadata["labels"] = [label.name for label in card.labels]
+ if "list" in self.extra_metadata:
+ if card.list_id in list_dict:
+ metadata["list"] = list_dict[card.list_id]
+ if "closed" in self.extra_metadata:
+ metadata["closed"] = card.closed
+ if "due_date" in self.extra_metadata:
+ metadata["due_date"] = card.due_date
+
+ return Document(page_content=text_content, metadata=metadata)
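+
+
+# Illustrative usage sketch (placeholder board name): credentials are read from the
+# TRELLO_API_KEY and TRELLO_TOKEN environment variables when not passed explicitly.
+if __name__ == "__main__":
+    loader = TrelloLoader.from_credentials("My Board", card_filter="open")
+    print(len(loader.load()))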
diff --git a/libs/community/langchain_community/document_loaders/tsv.py b/libs/community/langchain_community/document_loaders/tsv.py
new file mode 100644
index 00000000000..9bd4b4c2ed5
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/tsv.py
@@ -0,0 +1,37 @@
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredTSVLoader(UnstructuredFileLoader):
+ """Load `TSV` files using `Unstructured`.
+
+    Like other
+    Unstructured loaders, UnstructuredTSVLoader can be used in both
+    "single" and "elements" mode. If you use the loader in "elements"
+    mode, the TSV file will be a single Unstructured Table element,
+    and an HTML representation of the table will be available in the
+    "text_as_html" key in the document metadata.
+
+ Examples
+ --------
+ from langchain_community.document_loaders.tsv import UnstructuredTSVLoader
+
+ loader = UnstructuredTSVLoader("stanley-cups.tsv", mode="elements")
+ docs = loader.load()
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ validate_unstructured_version(min_unstructured_version="0.7.6")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.tsv import partition_tsv
+
+ return partition_tsv(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/twitter.py b/libs/community/langchain_community/document_loaders/twitter.py
new file mode 100644
index 00000000000..85ab4e7e396
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/twitter.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ import tweepy
+ from tweepy import OAuth2BearerHandler, OAuthHandler
+
+
+def _dependable_tweepy_import() -> tweepy:
+ try:
+ import tweepy
+ except ImportError:
+ raise ImportError(
+ "tweepy package not found, please install it with `pip install tweepy`"
+ )
+ return tweepy
+
+
+class TwitterTweetLoader(BaseLoader):
+ """Load `Twitter` tweets.
+
+ Read tweets of the user's Twitter handle.
+
+    First you need to go to
+    https://developer.twitter.com/en/docs/twitter-api/getting-started/getting-access-to-the-twitter-api
+    to get your token, and create a v2 version of the app.
+ """
+
+ def __init__(
+ self,
+ auth_handler: Union[OAuthHandler, OAuth2BearerHandler],
+ twitter_users: Sequence[str],
+ number_tweets: Optional[int] = 100,
+ ):
+ self.auth = auth_handler
+ self.twitter_users = twitter_users
+ self.number_tweets = number_tweets
+
+ def load(self) -> List[Document]:
+ """Load tweets."""
+ tweepy = _dependable_tweepy_import()
+ api = tweepy.API(self.auth, parser=tweepy.parsers.JSONParser())
+
+ results: List[Document] = []
+ for username in self.twitter_users:
+ tweets = api.user_timeline(screen_name=username, count=self.number_tweets)
+ user = api.get_user(screen_name=username)
+ docs = self._format_tweets(tweets, user)
+ results.extend(docs)
+ return results
+
+ def _format_tweets(
+ self, tweets: List[Dict[str, Any]], user_info: dict
+ ) -> Iterable[Document]:
+ """Format tweets into a string."""
+ for tweet in tweets:
+ metadata = {
+ "created_at": tweet["created_at"],
+ "user_info": user_info,
+ }
+ yield Document(
+ page_content=tweet["text"],
+ metadata=metadata,
+ )
+
+ @classmethod
+ def from_bearer_token(
+ cls,
+ oauth2_bearer_token: str,
+ twitter_users: Sequence[str],
+ number_tweets: Optional[int] = 100,
+ ) -> TwitterTweetLoader:
+ """Create a TwitterTweetLoader from OAuth2 bearer token."""
+ tweepy = _dependable_tweepy_import()
+ auth = tweepy.OAuth2BearerHandler(oauth2_bearer_token)
+ return cls(
+ auth_handler=auth,
+ twitter_users=twitter_users,
+ number_tweets=number_tweets,
+ )
+
+ @classmethod
+ def from_secrets(
+ cls,
+ access_token: str,
+ access_token_secret: str,
+ consumer_key: str,
+ consumer_secret: str,
+ twitter_users: Sequence[str],
+ number_tweets: Optional[int] = 100,
+ ) -> TwitterTweetLoader:
+ """Create a TwitterTweetLoader from access tokens and secrets."""
+ tweepy = _dependable_tweepy_import()
+ auth = tweepy.OAuthHandler(
+ access_token=access_token,
+ access_token_secret=access_token_secret,
+ consumer_key=consumer_key,
+ consumer_secret=consumer_secret,
+ )
+ return cls(
+ auth_handler=auth,
+ twitter_users=twitter_users,
+ number_tweets=number_tweets,
+ )
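+
+
+# Illustrative usage sketch (placeholder bearer token and handle): builds the loader
+# from an OAuth2 bearer token and fetches up to `number_tweets` tweets per user.
+if __name__ == "__main__":
+    loader = TwitterTweetLoader.from_bearer_token(
+        oauth2_bearer_token="YOUR_BEARER_TOKEN",
+        twitter_users=["hwchase17"],
+        number_tweets=10,
+    )
+    print(len(loader.load()))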
diff --git a/libs/community/langchain_community/document_loaders/unstructured.py b/libs/community/langchain_community/document_loaders/unstructured.py
new file mode 100644
index 00000000000..9d8223ff860
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/unstructured.py
@@ -0,0 +1,382 @@
+"""Loader that uses unstructured to load files."""
+import collections
+from abc import ABC, abstractmethod
+from typing import IO, Any, Callable, Dict, List, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def satisfies_min_unstructured_version(min_version: str) -> bool:
+    """Check if the installed `Unstructured` version meets or exceeds the minimum
+    version required for the feature in question."""
+ from unstructured.__version__ import __version__ as __unstructured_version__
+
+ min_version_tuple = tuple([int(x) for x in min_version.split(".")])
+
+ # NOTE(MthwRobinson) - enables the loader to work when you're using pre-release
+ # versions of unstructured like 0.4.17-dev1
+ _unstructured_version = __unstructured_version__.split("-")[0]
+ unstructured_version_tuple = tuple(
+ [int(x) for x in _unstructured_version.split(".")]
+ )
+
+ return unstructured_version_tuple >= min_version_tuple
+
+
+def validate_unstructured_version(min_unstructured_version: str) -> None:
+    """Raise an error if the installed `Unstructured` version does not meet the
+    specified minimum."""
+ if not satisfies_min_unstructured_version(min_unstructured_version):
+ raise ValueError(
+ f"unstructured>={min_unstructured_version} is required in this loader."
+ )
+
+
+class UnstructuredBaseLoader(BaseLoader, ABC):
+ """Base Loader that uses `Unstructured`."""
+
+ def __init__(
+ self,
+ mode: str = "single",
+ post_processors: Optional[List[Callable]] = None,
+ **unstructured_kwargs: Any,
+ ):
+        """Initialize with mode, post processors and unstructured kwargs."""
+ try:
+ import unstructured # noqa:F401
+ except ImportError:
+ raise ValueError(
+ "unstructured package not found, please install it with "
+ "`pip install unstructured`"
+ )
+ _valid_modes = {"single", "elements", "paged"}
+ if mode not in _valid_modes:
+ raise ValueError(
+ f"Got {mode} for `mode`, but should be one of `{_valid_modes}`"
+ )
+ self.mode = mode
+
+ if not satisfies_min_unstructured_version("0.5.4"):
+ if "strategy" in unstructured_kwargs:
+ unstructured_kwargs.pop("strategy")
+
+ self.unstructured_kwargs = unstructured_kwargs
+ self.post_processors = post_processors or []
+
+ @abstractmethod
+ def _get_elements(self) -> List:
+ """Get elements."""
+
+ @abstractmethod
+ def _get_metadata(self) -> dict:
+ """Get metadata."""
+
+ def _post_process_elements(self, elements: list) -> list:
+        """Apply post processing functions to extracted unstructured elements.
+        Post processing functions are str -> str callables that are passed
+        in using the post_processors kwarg when the loader is instantiated."""
+ for element in elements:
+ for post_processor in self.post_processors:
+ element.apply(post_processor)
+ return elements
+
+ def load(self) -> List[Document]:
+ """Load file."""
+ elements = self._get_elements()
+ self._post_process_elements(elements)
+ if self.mode == "elements":
+ docs: List[Document] = list()
+ for element in elements:
+ metadata = self._get_metadata()
+ # NOTE(MthwRobinson) - the attribute check is for backward compatibility
+                # with unstructured<0.4.9. The metadata attribute was added in 0.4.9.
+ if hasattr(element, "metadata"):
+ metadata.update(element.metadata.to_dict())
+ if hasattr(element, "category"):
+ metadata["category"] = element.category
+ docs.append(Document(page_content=str(element), metadata=metadata))
+ elif self.mode == "paged":
+ text_dict: Dict[int, str] = {}
+ meta_dict: Dict[int, Dict] = {}
+
+ for idx, element in enumerate(elements):
+ metadata = self._get_metadata()
+ if hasattr(element, "metadata"):
+ metadata.update(element.metadata.to_dict())
+ page_number = metadata.get("page_number", 1)
+
+ # Check if this page_number already exists in docs_dict
+ if page_number not in text_dict:
+ # If not, create new entry with initial text and metadata
+ text_dict[page_number] = str(element) + "\n\n"
+ meta_dict[page_number] = metadata
+ else:
+ # If exists, append to text and update the metadata
+ text_dict[page_number] += str(element) + "\n\n"
+ meta_dict[page_number].update(metadata)
+
+ # Convert the dict to a list of Document objects
+ docs = [
+ Document(page_content=text_dict[key], metadata=meta_dict[key])
+ for key in text_dict.keys()
+ ]
+ elif self.mode == "single":
+ metadata = self._get_metadata()
+ text = "\n\n".join([str(el) for el in elements])
+ docs = [Document(page_content=text, metadata=metadata)]
+ else:
+ raise ValueError(f"mode of {self.mode} not supported.")
+ return docs
+
+
+class UnstructuredFileLoader(UnstructuredBaseLoader):
+ """Load files using `Unstructured`.
+
+ The file loader uses the
+ unstructured partition function and will automatically detect the file
+ type. You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredFileLoader
+
+ loader = UnstructuredFileLoader(
+ "example.pdf", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition
+ """
+
+ def __init__(
+ self,
+ file_path: Union[str, List[str]],
+ mode: str = "single",
+ **unstructured_kwargs: Any,
+ ):
+ """Initialize with file path."""
+ self.file_path = file_path
+ super().__init__(mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.auto import partition
+
+ return partition(filename=self.file_path, **self.unstructured_kwargs)
+
+ def _get_metadata(self) -> dict:
+ return {"source": self.file_path}
+
+
+def get_elements_from_api(
+ file_path: Union[str, List[str], None] = None,
+ file: Union[IO, Sequence[IO], None] = None,
+ api_url: str = "https://api.unstructured.io/general/v0/general",
+ api_key: str = "",
+ **unstructured_kwargs: Any,
+) -> List:
+ """Retrieve a list of elements from the `Unstructured API`."""
+ if isinstance(file, collections.abc.Sequence) or isinstance(file_path, list):
+ from unstructured.partition.api import partition_multiple_via_api
+
+ _doc_elements = partition_multiple_via_api(
+ filenames=file_path,
+ files=file,
+ api_key=api_key,
+ api_url=api_url,
+ **unstructured_kwargs,
+ )
+
+ elements = []
+ for _elements in _doc_elements:
+ elements.extend(_elements)
+
+ return elements
+ else:
+ from unstructured.partition.api import partition_via_api
+
+ return partition_via_api(
+ filename=file_path,
+ file=file,
+ api_key=api_key,
+ api_url=api_url,
+ **unstructured_kwargs,
+ )
+
+
+class UnstructuredAPIFileLoader(UnstructuredFileLoader):
+ """Load files using `Unstructured` API.
+
+ By default, the loader makes a call to the hosted Unstructured API.
+ If you are running the unstructured API locally, you can change the
+    API URL by passing in the url parameter when you initialize the loader.
+ The hosted Unstructured API requires an API key. See
+ https://www.unstructured.io/api-key/ if you need to generate a key.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+    Examples
+    --------
+    from langchain_community.document_loaders import UnstructuredAPIFileLoader
+
+    loader = UnstructuredAPIFileLoader(
+        "example.pdf", mode="elements", strategy="fast", api_key="MY_API_KEY",
+    )
+    docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition
+ https://www.unstructured.io/api-key/
+ https://github.com/Unstructured-IO/unstructured-api
+ """
+
+ def __init__(
+ self,
+ file_path: Union[str, List[str]] = "",
+ mode: str = "single",
+ url: str = "https://api.unstructured.io/general/v0/general",
+ api_key: str = "",
+ **unstructured_kwargs: Any,
+ ):
+ """Initialize with file path."""
+
+ validate_unstructured_version(min_unstructured_version="0.10.15")
+
+ self.url = url
+ self.api_key = api_key
+
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_metadata(self) -> dict:
+ return {"source": self.file_path}
+
+ def _get_elements(self) -> List:
+ return get_elements_from_api(
+ file_path=self.file_path,
+ api_key=self.api_key,
+ api_url=self.url,
+ **self.unstructured_kwargs,
+ )
+
+
+class UnstructuredFileIOLoader(UnstructuredBaseLoader):
+ """Load files using `Unstructured`.
+
+ The file loader
+ uses the unstructured partition function and will automatically detect the file
+ type. You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredFileIOLoader
+
+ with open("example.pdf", "rb") as f:
+ loader = UnstructuredFileIOLoader(
+ f, mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition
+ """
+
+ def __init__(
+ self,
+ file: Union[IO, Sequence[IO]],
+ mode: str = "single",
+ **unstructured_kwargs: Any,
+ ):
+        """Initialize with a file object."""
+ self.file = file
+ super().__init__(mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.auto import partition
+
+ return partition(file=self.file, **self.unstructured_kwargs)
+
+ def _get_metadata(self) -> dict:
+ return {}
+
+
+class UnstructuredAPIFileIOLoader(UnstructuredFileIOLoader):
+ """Load files using `Unstructured` API.
+
+ By default, the loader makes a call to the hosted Unstructured API.
+ If you are running the unstructured API locally, you can change the
+    API URL by passing in the url parameter when you initialize the loader.
+ The hosted Unstructured API requires an API key. See
+ https://www.unstructured.io/api-key/ if you need to generate a key.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+    from langchain_community.document_loaders import UnstructuredAPIFileIOLoader
+
+    with open("example.pdf", "rb") as f:
+        loader = UnstructuredAPIFileIOLoader(
+ f, mode="elements", strategy="fast", api_key="MY_API_KEY",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition
+ https://www.unstructured.io/api-key/
+ https://github.com/Unstructured-IO/unstructured-api
+ """
+
+ def __init__(
+ self,
+ file: Union[IO, Sequence[IO]],
+ mode: str = "single",
+ url: str = "https://api.unstructured.io/general/v0/general",
+ api_key: str = "",
+ **unstructured_kwargs: Any,
+ ):
+        """Initialize with a file object."""
+
+ if isinstance(file, collections.abc.Sequence):
+ validate_unstructured_version(min_unstructured_version="0.6.3")
+ if file:
+ validate_unstructured_version(min_unstructured_version="0.6.2")
+
+ self.url = url
+ self.api_key = api_key
+
+ super().__init__(file=file, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ return get_elements_from_api(
+ file=self.file,
+ api_key=self.api_key,
+ api_url=self.url,
+ **self.unstructured_kwargs,
+ )
diff --git a/libs/community/langchain_community/document_loaders/url.py b/libs/community/langchain_community/document_loaders/url.py
new file mode 100644
index 00000000000..922f028ad3e
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/url.py
@@ -0,0 +1,160 @@
+"""Loader that uses unstructured to load HTML files."""
+import logging
+from typing import Any, List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class UnstructuredURLLoader(BaseLoader):
+ """Load files from remote URLs using `Unstructured`.
+
+ Use the unstructured partition function to detect the MIME type
+ and route the file to the appropriate partitioner.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredURLLoader
+
+ loader = UnstructuredURLLoader(
+ urls=["", ""], mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition
+ """
+
+ def __init__(
+ self,
+ urls: List[str],
+ continue_on_failure: bool = True,
+ mode: str = "single",
+ show_progress_bar: bool = False,
+ **unstructured_kwargs: Any,
+ ):
+        """Initialize with URLs."""
+ try:
+ import unstructured # noqa:F401
+ from unstructured.__version__ import __version__ as __unstructured_version__
+
+ self.__version = __unstructured_version__
+ except ImportError:
+ raise ImportError(
+ "unstructured package not found, please install it with "
+ "`pip install unstructured`"
+ )
+
+ self._validate_mode(mode)
+ self.mode = mode
+
+ headers = unstructured_kwargs.pop("headers", {})
+        if headers:
+ warn_about_headers = False
+ if self.__is_non_html_available():
+ warn_about_headers = not self.__is_headers_available_for_non_html()
+ else:
+ warn_about_headers = not self.__is_headers_available_for_html()
+
+ if warn_about_headers:
+ logger.warning(
+ "You are using an old version of unstructured. "
+ "The headers parameter is ignored"
+ )
+
+ self.urls = urls
+ self.continue_on_failure = continue_on_failure
+ self.headers = headers
+ self.unstructured_kwargs = unstructured_kwargs
+ self.show_progress_bar = show_progress_bar
+
+ def _validate_mode(self, mode: str) -> None:
+ _valid_modes = {"single", "elements"}
+ if mode not in _valid_modes:
+ raise ValueError(
+ f"Got {mode} for `mode`, but should be one of `{_valid_modes}`"
+ )
+
+ def __is_headers_available_for_html(self) -> bool:
+ _unstructured_version = self.__version.split("-")[0]
+ unstructured_version = tuple([int(x) for x in _unstructured_version.split(".")])
+
+ return unstructured_version >= (0, 5, 7)
+
+ def __is_headers_available_for_non_html(self) -> bool:
+ _unstructured_version = self.__version.split("-")[0]
+ unstructured_version = tuple([int(x) for x in _unstructured_version.split(".")])
+
+ return unstructured_version >= (0, 5, 13)
+
+ def __is_non_html_available(self) -> bool:
+ _unstructured_version = self.__version.split("-")[0]
+ unstructured_version = tuple([int(x) for x in _unstructured_version.split(".")])
+
+ return unstructured_version >= (0, 5, 12)
+
+ def load(self) -> List[Document]:
+ """Load file."""
+ from unstructured.partition.auto import partition
+ from unstructured.partition.html import partition_html
+
+ docs: List[Document] = list()
+ if self.show_progress_bar:
+ try:
+ from tqdm import tqdm
+ except ImportError as e:
+ raise ImportError(
+ "Package tqdm must be installed if show_progress_bar=True. "
+ "Please install with 'pip install tqdm' or set "
+ "show_progress_bar=False."
+ ) from e
+
+ urls = tqdm(self.urls)
+ else:
+ urls = self.urls
+
+ for url in urls:
+ try:
+ if self.__is_non_html_available():
+ if self.__is_headers_available_for_non_html():
+ elements = partition(
+ url=url, headers=self.headers, **self.unstructured_kwargs
+ )
+ else:
+ elements = partition(url=url, **self.unstructured_kwargs)
+ else:
+ if self.__is_headers_available_for_html():
+ elements = partition_html(
+ url=url, headers=self.headers, **self.unstructured_kwargs
+ )
+ else:
+ elements = partition_html(url=url, **self.unstructured_kwargs)
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error fetching or processing {url}, exception: {e}")
+ continue
+ else:
+ raise e
+
+ if self.mode == "single":
+ text = "\n\n".join([str(el) for el in elements])
+ metadata = {"source": url}
+ docs.append(Document(page_content=text, metadata=metadata))
+ elif self.mode == "elements":
+ for element in elements:
+ metadata = element.metadata.to_dict()
+ metadata["category"] = element.category
+ docs.append(Document(page_content=str(element), metadata=metadata))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/url_playwright.py b/libs/community/langchain_community/document_loaders/url_playwright.py
new file mode 100644
index 00000000000..8071d3717f7
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/url_playwright.py
@@ -0,0 +1,208 @@
+"""Loader that uses Playwright to load a page, then uses unstructured to load the html.
+"""
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+if TYPE_CHECKING:
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.async_api import Page as AsyncPage
+ from playwright.async_api import Response as AsyncResponse
+ from playwright.sync_api import Browser, Page, Response
+
+
+logger = logging.getLogger(__name__)
+
+
+class PlaywrightEvaluator(ABC):
+ """Abstract base class for all evaluators.
+
+ Each evaluator should take a page, a browser instance, and a response
+ object, process the page as necessary, and return the resulting text.
+ """
+
+ @abstractmethod
+ def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
+ """Synchronously process the page and return the resulting text.
+
+ Args:
+ page: The page to process.
+ browser: The browser instance.
+ response: The response from page.goto().
+
+ Returns:
+ text: The text content of the page.
+ """
+ pass
+
+ @abstractmethod
+ async def evaluate_async(
+ self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
+ ) -> str:
+ """Asynchronously process the page and return the resulting text.
+
+ Args:
+ page: The page to process.
+ browser: The browser instance.
+ response: The response from page.goto().
+
+ Returns:
+ text: The text content of the page.
+ """
+ pass
+
+
+class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
+ """Evaluates the page HTML content using the `unstructured` library."""
+
+ def __init__(self, remove_selectors: Optional[List[str]] = None):
+ """Initialize UnstructuredHtmlEvaluator."""
+ try:
+ import unstructured # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "unstructured package not found, please install it with "
+ "`pip install unstructured`"
+ )
+
+ self.remove_selectors = remove_selectors
+
+ def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
+ """Synchronously process the HTML content of the page."""
+ from unstructured.partition.html import partition_html
+
+ for selector in self.remove_selectors or []:
+ elements = page.locator(selector).all()
+ for element in elements:
+ if element.is_visible():
+ element.evaluate("element => element.remove()")
+
+ page_source = page.content()
+ elements = partition_html(text=page_source)
+ return "\n\n".join([str(el) for el in elements])
+
+ async def evaluate_async(
+ self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
+ ) -> str:
+ """Asynchronously process the HTML content of the page."""
+ from unstructured.partition.html import partition_html
+
+ for selector in self.remove_selectors or []:
+ elements = await page.locator(selector).all()
+ for element in elements:
+ if await element.is_visible():
+ await element.evaluate("element => element.remove()")
+
+ page_source = await page.content()
+ elements = partition_html(text=page_source)
+ return "\n\n".join([str(el) for el in elements])
+
+
+class PlaywrightURLLoader(BaseLoader):
+ """Load `HTML` pages with `Playwright` and parse with `Unstructured`.
+
+ This is useful for loading pages that require javascript to render.
+
+ Attributes:
+ urls (List[str]): List of URLs to load.
+ continue_on_failure (bool): If True, continue loading other URLs on failure.
+ headless (bool): If True, the browser will run in headless mode.
+ """
+
+ def __init__(
+ self,
+ urls: List[str],
+ continue_on_failure: bool = True,
+ headless: bool = True,
+ remove_selectors: Optional[List[str]] = None,
+ evaluator: Optional[PlaywrightEvaluator] = None,
+ ):
+ """Load a list of URLs using Playwright."""
+ try:
+ import playwright # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "playwright package not found, please install it with "
+ "`pip install playwright`"
+ )
+
+ self.urls = urls
+ self.continue_on_failure = continue_on_failure
+ self.headless = headless
+
+ if remove_selectors and evaluator:
+ raise ValueError(
+                "Only one of `remove_selectors` and `evaluator` may be provided."
+ )
+
+ # Use the provided evaluator, if any, otherwise, use the default.
+ self.evaluator = evaluator or UnstructuredHtmlEvaluator(remove_selectors)
+
+ def load(self) -> List[Document]:
+ """Load the specified URLs using Playwright and create Document instances.
+
+ Returns:
+ List[Document]: A list of Document instances with loaded content.
+ """
+ from playwright.sync_api import sync_playwright
+
+ docs: List[Document] = list()
+
+ with sync_playwright() as p:
+ browser = p.chromium.launch(headless=self.headless)
+ for url in self.urls:
+ try:
+ page = browser.new_page()
+ response = page.goto(url)
+ if response is None:
+ raise ValueError(f"page.goto() returned None for url {url}")
+
+ text = self.evaluator.evaluate(page, browser, response)
+ metadata = {"source": url}
+ docs.append(Document(page_content=text, metadata=metadata))
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(
+ f"Error fetching or processing {url}, exception: {e}"
+ )
+ else:
+ raise e
+ browser.close()
+ return docs
+
+ async def aload(self) -> List[Document]:
+ """Load the specified URLs with Playwright and create Documents asynchronously.
+        Use this function when running in a Jupyter notebook environment.
+
+ Returns:
+ List[Document]: A list of Document instances with loaded content.
+ """
+ from playwright.async_api import async_playwright
+
+ docs: List[Document] = list()
+
+ async with async_playwright() as p:
+ browser = await p.chromium.launch(headless=self.headless)
+ for url in self.urls:
+ try:
+ page = await browser.new_page()
+ response = await page.goto(url)
+ if response is None:
+ raise ValueError(f"page.goto() returned None for url {url}")
+
+ text = await self.evaluator.evaluate_async(page, browser, response)
+ metadata = {"source": url}
+ docs.append(Document(page_content=text, metadata=metadata))
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(
+ f"Error fetching or processing {url}, exception: {e}"
+ )
+ else:
+ raise e
+ await browser.close()
+ return docs
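+
+
+# Illustrative usage sketch (placeholder URL): remove_selectors strips matching
+# elements (e.g. headers and footers) before the HTML is partitioned by unstructured.
+if __name__ == "__main__":
+    loader = PlaywrightURLLoader(
+        urls=["https://python.langchain.com/"],
+        remove_selectors=["header", "footer"],
+    )
+    print(len(loader.load()))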
diff --git a/libs/community/langchain_community/document_loaders/url_selenium.py b/libs/community/langchain_community/document_loaders/url_selenium.py
new file mode 100644
index 00000000000..bbe74f2b1e7
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/url_selenium.py
@@ -0,0 +1,176 @@
+"""Loader that uses Selenium to load a page, then uses unstructured to load the html.
+"""
+import logging
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
+
+if TYPE_CHECKING:
+ from selenium.webdriver import Chrome, Firefox
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class SeleniumURLLoader(BaseLoader):
+ """Load `HTML` pages with `Selenium` and parse with `Unstructured`.
+
+ This is useful for loading pages that require javascript to render.
+
+ Attributes:
+ urls (List[str]): List of URLs to load.
+ continue_on_failure (bool): If True, continue loading other URLs on failure.
+ browser (str): The browser to use, either 'chrome' or 'firefox'.
+ binary_location (Optional[str]): The location of the browser binary.
+ executable_path (Optional[str]): The path to the browser executable.
+ headless (bool): If True, the browser will run in headless mode.
+        arguments (List[str]): List of arguments to pass to the browser.
+ """
+
+ def __init__(
+ self,
+ urls: List[str],
+ continue_on_failure: bool = True,
+ browser: Literal["chrome", "firefox"] = "chrome",
+ binary_location: Optional[str] = None,
+ executable_path: Optional[str] = None,
+ headless: bool = True,
+ arguments: List[str] = [],
+ ):
+ """Load a list of URLs using Selenium and unstructured."""
+ try:
+ import selenium # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "selenium package not found, please install it with "
+ "`pip install selenium`"
+ )
+
+ try:
+ import unstructured # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "unstructured package not found, please install it with "
+ "`pip install unstructured`"
+ )
+
+ self.urls = urls
+ self.continue_on_failure = continue_on_failure
+ self.browser = browser
+ self.binary_location = binary_location
+ self.executable_path = executable_path
+ self.headless = headless
+ self.arguments = arguments
+
+ def _get_driver(self) -> Union["Chrome", "Firefox"]:
+ """Create and return a WebDriver instance based on the specified browser.
+
+ Raises:
+ ValueError: If an invalid browser is specified.
+
+ Returns:
+ Union[Chrome, Firefox]: A WebDriver instance for the specified browser.
+ """
+ if self.browser.lower() == "chrome":
+ from selenium.webdriver import Chrome
+ from selenium.webdriver.chrome.options import Options as ChromeOptions
+ from selenium.webdriver.chrome.service import Service
+
+ chrome_options = ChromeOptions()
+
+ for arg in self.arguments:
+ chrome_options.add_argument(arg)
+
+ if self.headless:
+ chrome_options.add_argument("--headless")
+ chrome_options.add_argument("--no-sandbox")
+ if self.binary_location is not None:
+ chrome_options.binary_location = self.binary_location
+ if self.executable_path is None:
+ return Chrome(options=chrome_options)
+ return Chrome(
+ options=chrome_options,
+ service=Service(executable_path=self.executable_path),
+ )
+ elif self.browser.lower() == "firefox":
+ from selenium.webdriver import Firefox
+ from selenium.webdriver.firefox.options import Options as FirefoxOptions
+ from selenium.webdriver.firefox.service import Service
+
+ firefox_options = FirefoxOptions()
+
+ for arg in self.arguments:
+ firefox_options.add_argument(arg)
+
+ if self.headless:
+ firefox_options.add_argument("--headless")
+ if self.binary_location is not None:
+ firefox_options.binary_location = self.binary_location
+ if self.executable_path is None:
+ return Firefox(options=firefox_options)
+ return Firefox(
+ options=firefox_options,
+ service=Service(executable_path=self.executable_path),
+ )
+ else:
+ raise ValueError("Invalid browser specified. Use 'chrome' or 'firefox'.")
+
+ def _build_metadata(self, url: str, driver: Union["Chrome", "Firefox"]) -> dict:
+        """Build metadata based on the contents of the webpage."""
+        from selenium.common.exceptions import NoSuchElementException
+        from selenium.webdriver.common.by import By
+
+ metadata = {
+ "source": url,
+ "title": "No title found.",
+ "description": "No description found.",
+ "language": "No language found.",
+ }
+ if title := driver.title:
+ metadata["title"] = title
+ try:
+ if description := driver.find_element(
+ By.XPATH, '//meta[@name="description"]'
+ ):
+ metadata["description"] = (
+ description.get_attribute("content") or "No description found."
+ )
+ except NoSuchElementException:
+ pass
+ try:
+ if html_tag := driver.find_element(By.TAG_NAME, "html"):
+ metadata["language"] = (
+ html_tag.get_attribute("lang") or "No language found."
+ )
+ except NoSuchElementException:
+ pass
+ return metadata
+
+ def load(self) -> List[Document]:
+ """Load the specified URLs using Selenium and create Document instances.
+
+ Returns:
+ List[Document]: A list of Document instances with loaded content.
+ """
+ from unstructured.partition.html import partition_html
+
+ docs: List[Document] = list()
+ driver = self._get_driver()
+
+ for url in self.urls:
+ try:
+ driver.get(url)
+ page_content = driver.page_source
+ elements = partition_html(text=page_content)
+ text = "\n\n".join([str(el) for el in elements])
+ metadata = self._build_metadata(url, driver)
+ docs.append(Document(page_content=text, metadata=metadata))
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.error(f"Error fetching or processing {url}, exception: {e}")
+ else:
+ raise e
+
+ driver.quit()
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/weather.py b/libs/community/langchain_community/document_loaders/weather.py
new file mode 100644
index 00000000000..a2ca0a7d520
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/weather.py
@@ -0,0 +1,51 @@
+"""Simple reader that reads weather data from OpenWeatherMap API"""
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Iterator, List, Optional, Sequence
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper
+
+
+class WeatherDataLoader(BaseLoader):
+ """Load weather data with `Open Weather Map` API.
+
+    Reads the forecast & current weather of any location using OpenWeatherMap's free
+    API. Check out 'https://openweathermap.org/appid' for more on how to generate a
+    free OpenWeatherMap API key.
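+
+    Example (a minimal sketch; requires an OpenWeatherMap API key, and the place name
+    is illustrative):
+        .. code-block:: python
+
+            from langchain_community.document_loaders import WeatherDataLoader
+
+            loader = WeatherDataLoader.from_params(
+                places=["London"], openweathermap_api_key="YOUR_API_KEY"
+            )
+            docs = loader.load()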
+ """
+
+ def __init__(
+ self,
+ client: OpenWeatherMapAPIWrapper,
+ places: Sequence[str],
+ ) -> None:
+ """Initialize with parameters."""
+ super().__init__()
+ self.client = client
+ self.places = places
+
+ @classmethod
+ def from_params(
+ cls, places: Sequence[str], *, openweathermap_api_key: Optional[str] = None
+ ) -> WeatherDataLoader:
+ client = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key)
+ return cls(client, places)
+
+ def lazy_load(
+ self,
+ ) -> Iterator[Document]:
+ """Lazily load weather data for the given locations."""
+ for place in self.places:
+ metadata = {"queried_at": datetime.now()}
+ content = self.client.run(place)
+ yield Document(page_content=content, metadata=metadata)
+
+ def load(
+ self,
+ ) -> List[Document]:
+ """Load weather data for the given locations."""
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/document_loaders/web_base.py b/libs/community/langchain_community/document_loaders/web_base.py
new file mode 100644
index 00000000000..553f9c8e151
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/web_base.py
@@ -0,0 +1,267 @@
+"""Web base loader class."""
+import asyncio
+import logging
+import warnings
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+
+import aiohttp
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+default_header_template = {
+ "User-Agent": "",
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*"
+ ";q=0.8",
+ "Accept-Language": "en-US,en;q=0.5",
+ "Referer": "https://www.google.com/",
+ "DNT": "1",
+ "Connection": "keep-alive",
+ "Upgrade-Insecure-Requests": "1",
+}
+
+
+def _build_metadata(soup: Any, url: str) -> dict:
+ """Build metadata from BeautifulSoup output."""
+ metadata = {"source": url}
+ if title := soup.find("title"):
+ metadata["title"] = title.get_text()
+ if description := soup.find("meta", attrs={"name": "description"}):
+ metadata["description"] = description.get("content", "No description found.")
+ if html := soup.find("html"):
+ metadata["language"] = html.get("lang", "No language found.")
+ return metadata
+
+
+class WebBaseLoader(BaseLoader):
+    """Load HTML pages using `requests` and parse them with `BeautifulSoup`."""
+
+ def __init__(
+ self,
+ web_path: Union[str, Sequence[str]] = "",
+ header_template: Optional[dict] = None,
+ verify_ssl: bool = True,
+ proxies: Optional[dict] = None,
+ continue_on_failure: bool = False,
+ autoset_encoding: bool = True,
+ encoding: Optional[str] = None,
+ web_paths: Sequence[str] = (),
+ requests_per_second: int = 2,
+ default_parser: str = "html.parser",
+ requests_kwargs: Optional[Dict[str, Any]] = None,
+ raise_for_status: bool = False,
+ bs_get_text_kwargs: Optional[Dict[str, Any]] = None,
+ bs_kwargs: Optional[Dict[str, Any]] = None,
+ session: Any = None,
+ ) -> None:
+ """Initialize loader.
+
+ Args:
+ web_paths: Web paths to load from.
+ requests_per_second: Max number of concurrent requests to make.
+ default_parser: Default parser to use for BeautifulSoup.
+ requests_kwargs: kwargs for requests
+ raise_for_status: Raise an exception if http status code denotes an error.
+            bs_get_text_kwargs: kwargs for beautifulsoup4 get_text
+            bs_kwargs: kwargs for beautifulsoup4 web page parsing
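+
+        Example (a minimal usage sketch; requires beautifulsoup4, and the URL is
+        illustrative):
+            .. code-block:: python
+
+                from langchain_community.document_loaders import WebBaseLoader
+
+                loader = WebBaseLoader("https://example.com")
+                docs = loader.load()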
+ """
+ # web_path kept for backwards-compatibility.
+ if web_path and web_paths:
+ raise ValueError(
+ "Received web_path and web_paths. Only one can be specified. "
+ "web_path is deprecated, web_paths should be used."
+ )
+ if web_paths:
+ self.web_paths = list(web_paths)
+ elif isinstance(web_path, str):
+ self.web_paths = [web_path]
+ elif isinstance(web_path, Sequence):
+ self.web_paths = list(web_path)
+ else:
+ raise TypeError(
+ f"web_path must be str or Sequence[str] got ({type(web_path)}) or"
+ f" web_paths must be Sequence[str] got ({type(web_paths)})"
+ )
+ self.requests_per_second = requests_per_second
+ self.default_parser = default_parser
+ self.requests_kwargs = requests_kwargs or {}
+ self.raise_for_status = raise_for_status
+ self.bs_get_text_kwargs = bs_get_text_kwargs or {}
+ self.bs_kwargs = bs_kwargs or {}
+ if session:
+ self.session = session
+ else:
+ session = requests.Session()
+ header_template = header_template or default_header_template.copy()
+ if not header_template.get("User-Agent"):
+ try:
+ from fake_useragent import UserAgent
+
+ header_template["User-Agent"] = UserAgent().random
+ except ImportError:
+                    logger.info(
+                        "fake_useragent not found, using default user agent. "
+ "To get a realistic header for requests, "
+ "`pip install fake_useragent`."
+ )
+ session.headers = dict(header_template)
+ session.verify = verify_ssl
+ if proxies:
+ session.proxies.update(proxies)
+ self.session = session
+ self.continue_on_failure = continue_on_failure
+ self.autoset_encoding = autoset_encoding
+ self.encoding = encoding
+
+ @property
+ def web_path(self) -> str:
+ if len(self.web_paths) > 1:
+ raise ValueError("Multiple webpaths found.")
+ return self.web_paths[0]
+
+ async def _fetch(
+ self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+ ) -> str:
+ async with aiohttp.ClientSession() as session:
+ for i in range(retries):
+ try:
+ async with session.get(
+ url,
+ headers=self.session.headers,
+ ssl=None if self.session.verify else False,
+ ) as response:
+ return await response.text()
+ except aiohttp.ClientConnectionError as e:
+ if i == retries - 1:
+ raise
+ else:
+ logger.warning(
+ f"Error fetching {url} with attempt "
+ f"{i + 1}/{retries}: {e}. Retrying..."
+ )
+ await asyncio.sleep(cooldown * backoff**i)
+ raise ValueError("retry count exceeded")
+
+ async def _fetch_with_rate_limit(
+ self, url: str, semaphore: asyncio.Semaphore
+ ) -> str:
+ async with semaphore:
+ try:
+ return await self._fetch(url)
+ except Exception as e:
+ if self.continue_on_failure:
+ logger.warning(
+ f"Error fetching {url}, skipping due to"
+ f" continue_on_failure=True"
+ )
+ return ""
+ logger.exception(
+ f"Error fetching {url} and aborting, use continue_on_failure=True "
+ "to continue loading urls after encountering an error."
+ )
+ raise e
+
+ async def fetch_all(self, urls: List[str]) -> Any:
+ """Fetch all urls concurrently with rate limiting."""
+ semaphore = asyncio.Semaphore(self.requests_per_second)
+ tasks = []
+ for url in urls:
+ task = asyncio.ensure_future(self._fetch_with_rate_limit(url, semaphore))
+ tasks.append(task)
+ try:
+ from tqdm.asyncio import tqdm_asyncio
+
+ return await tqdm_asyncio.gather(
+ *tasks, desc="Fetching pages", ascii=True, mininterval=1
+ )
+ except ImportError:
+ warnings.warn("For better logging of progress, `pip install tqdm`")
+ return await asyncio.gather(*tasks)
+
+ @staticmethod
+ def _check_parser(parser: str) -> None:
+ """Check that parser is valid for bs4."""
+ valid_parsers = ["html.parser", "lxml", "xml", "lxml-xml", "html5lib"]
+ if parser not in valid_parsers:
+ raise ValueError(
+ "`parser` must be one of " + ", ".join(valid_parsers) + "."
+ )
+
+ def scrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
+ """Fetch all urls, then return soups for all results."""
+ from bs4 import BeautifulSoup
+
+ results = asyncio.run(self.fetch_all(urls))
+ final_results = []
+ for i, result in enumerate(results):
+ url = urls[i]
+ if parser is None:
+ if url.endswith(".xml"):
+ parser = "xml"
+ else:
+ parser = self.default_parser
+ self._check_parser(parser)
+ final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
+
+ return final_results
+
+ def _scrape(
+ self,
+ url: str,
+ parser: Union[str, None] = None,
+ bs_kwargs: Optional[dict] = None,
+ ) -> Any:
+ from bs4 import BeautifulSoup
+
+ if parser is None:
+ if url.endswith(".xml"):
+ parser = "xml"
+ else:
+ parser = self.default_parser
+
+ self._check_parser(parser)
+
+ html_doc = self.session.get(url, **self.requests_kwargs)
+ if self.raise_for_status:
+ html_doc.raise_for_status()
+
+ if self.encoding is not None:
+ html_doc.encoding = self.encoding
+ elif self.autoset_encoding:
+ html_doc.encoding = html_doc.apparent_encoding
+ return BeautifulSoup(html_doc.text, parser, **(bs_kwargs or {}))
+
+ def scrape(self, parser: Union[str, None] = None) -> Any:
+ """Scrape data from webpage and return it in BeautifulSoup format."""
+
+ if parser is None:
+ parser = self.default_parser
+
+ return self._scrape(self.web_path, parser=parser, bs_kwargs=self.bs_kwargs)
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load text from the url(s) in web_path."""
+ for path in self.web_paths:
+ soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
+ text = soup.get_text(**self.bs_get_text_kwargs)
+ metadata = _build_metadata(soup, path)
+ yield Document(page_content=text, metadata=metadata)
+
+ def load(self) -> List[Document]:
+ """Load text from the url(s) in web_path."""
+ return list(self.lazy_load())
+
+ def aload(self) -> List[Document]:
+ """Load text from the urls in web_path async into Documents."""
+
+ results = self.scrape_all(self.web_paths)
+ docs = []
+ for path, soup in zip(self.web_paths, results):
+ text = soup.get_text(**self.bs_get_text_kwargs)
+ metadata = _build_metadata(soup, path)
+ docs.append(Document(page_content=text, metadata=metadata))
+
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/whatsapp_chat.py b/libs/community/langchain_community/document_loaders/whatsapp_chat.py
new file mode 100644
index 00000000000..d8198dee1cf
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/whatsapp_chat.py
@@ -0,0 +1,65 @@
+import re
+from pathlib import Path
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+def concatenate_rows(date: str, sender: str, text: str) -> str:
+ """Combine message information in a readable format ready to be used."""
+ return f"{sender} on {date}: {text}\n\n"
+
+
+class WhatsAppChatLoader(BaseLoader):
+ """Load `WhatsApp` messages text file."""
+
+ def __init__(self, path: str):
+ """Initialize with path."""
+ self.file_path = path
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ p = Path(self.file_path)
+ text_content = ""
+
+ with open(p, encoding="utf8") as f:
+ lines = f.readlines()
+
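+        # The regex below matches exported chat lines such as (illustrative examples):
+        #   [1/23/23, 3:19:07 PM] Jane Doe: Hello!
+        #   18.02.2023, 11:57 - Alice: How are you?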
+ message_line_regex = r"""
+ \[?
+ (
+ \d{1,4}
+ [\/.]
+ \d{1,2}
+ [\/.]
+ \d{1,4}
+ ,\s
+ \d{1,2}
+ :\d{2}
+ (?:
+ :\d{2}
+ )?
+ (?:[\s_](?:AM|PM))?
+ )
+ \]?
+ [\s-]*
+ ([~\w\s]+)
+ [:]+
+ \s
+ (.+)
+ """
+ ignore_lines = ["This message was deleted", ""]
+ for line in lines:
+ result = re.match(
+ message_line_regex, line.strip(), flags=re.VERBOSE | re.IGNORECASE
+ )
+ if result:
+ date, sender, text = result.groups()
+ if text not in ignore_lines:
+ text_content += concatenate_rows(date, sender, text)
+
+ metadata = {"source": str(p)}
+
+ return [Document(page_content=text_content, metadata=metadata)]
diff --git a/libs/community/langchain_community/document_loaders/wikipedia.py b/libs/community/langchain_community/document_loaders/wikipedia.py
new file mode 100644
index 00000000000..ea7bd5887a3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/wikipedia.py
@@ -0,0 +1,60 @@
+from typing import List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+
+
+class WikipediaLoader(BaseLoader):
+ """Load from `Wikipedia`.
+
+ The hard limit on the length of the query is 300 for now.
+
+ Each wiki page represents one Document.
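+
+    Example (a minimal usage sketch; requires the ``wikipedia`` package, and the query
+    is illustrative):
+        .. code-block:: python
+
+            from langchain_community.document_loaders import WikipediaLoader
+
+            docs = WikipediaLoader(query="LangChain", load_max_docs=2).load()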
+ """
+
+ def __init__(
+ self,
+ query: str,
+ lang: str = "en",
+ load_max_docs: Optional[int] = 25,
+ load_all_available_meta: Optional[bool] = False,
+ doc_content_chars_max: Optional[int] = 4000,
+ ):
+ """
+ Initializes a new instance of the WikipediaLoader class.
+
+ Args:
+ query (str): The query string to search on Wikipedia.
+ lang (str, optional): The language code for the Wikipedia language edition.
+ Defaults to "en".
+            load_max_docs (int, optional): The maximum number of documents to load.
+                Defaults to 25.
+ load_all_available_meta (bool, optional): Indicates whether to load all
+ available metadata for each document. Defaults to False.
+ doc_content_chars_max (int, optional): The maximum number of characters
+ for the document content. Defaults to 4000.
+ """
+ self.query = query
+ self.lang = lang
+ self.load_max_docs = load_max_docs
+ self.load_all_available_meta = load_all_available_meta
+ self.doc_content_chars_max = doc_content_chars_max
+
+ def load(self) -> List[Document]:
+ """
+ Loads the query result from Wikipedia into a list of Documents.
+
+ Returns:
+ List[Document]: A list of Document objects representing the loaded
+ Wikipedia pages.
+ """
+ client = WikipediaAPIWrapper(
+ lang=self.lang,
+ top_k_results=self.load_max_docs,
+ load_all_available_meta=self.load_all_available_meta,
+ doc_content_chars_max=self.doc_content_chars_max,
+ )
+ docs = client.load(self.query)
+ return docs
diff --git a/libs/community/langchain_community/document_loaders/word_document.py b/libs/community/langchain_community/document_loaders/word_document.py
new file mode 100644
index 00000000000..efbd12559e3
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/word_document.py
@@ -0,0 +1,124 @@
+"""Loads word documents."""
+import os
+import tempfile
+from abc import ABC
+from typing import List
+from urllib.parse import urlparse
+
+import requests
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class Docx2txtLoader(BaseLoader, ABC):
+ """Load `DOCX` file using `docx2txt` and chunks at character level.
+
+    Defaults to checking for a local file, but if the file is a web path, it will
+    download it to a temporary file, use that, and then clean up the temporary file
+    after completion.
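+
+    Example (a minimal usage sketch; requires docx2txt, and the path is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.document_loaders import Docx2txtLoader
+
+            loader = Docx2txtLoader("example.docx")
+            docs = loader.load()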
+ """
+
+ def __init__(self, file_path: str):
+ """Initialize with file path."""
+ self.file_path = file_path
+ if "~" in self.file_path:
+ self.file_path = os.path.expanduser(self.file_path)
+
+ # If the file is a web path, download it to a temporary file, and use that
+ if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
+ r = requests.get(self.file_path)
+
+ if r.status_code != 200:
+ raise ValueError(
+ "Check the url of your file; returned status code %s"
+ % r.status_code
+ )
+
+ self.web_path = self.file_path
+ self.temp_file = tempfile.NamedTemporaryFile()
+ self.temp_file.write(r.content)
+ self.file_path = self.temp_file.name
+ elif not os.path.isfile(self.file_path):
+ raise ValueError("File path %s is not a valid file or url" % self.file_path)
+
+ def __del__(self) -> None:
+ if hasattr(self, "temp_file"):
+ self.temp_file.close()
+
+ def load(self) -> List[Document]:
+ """Load given path as single page."""
+ import docx2txt
+
+ return [
+ Document(
+ page_content=docx2txt.process(self.file_path),
+ metadata={"source": self.file_path},
+ )
+ ]
+
+ @staticmethod
+ def _is_valid_url(url: str) -> bool:
+ """Check if the url is valid."""
+ parsed = urlparse(url)
+ return bool(parsed.netloc) and bool(parsed.scheme)
+
+
+class UnstructuredWordDocumentLoader(UnstructuredFileLoader):
+ """Load `Microsoft Word` file using `Unstructured`.
+
+ Works with both .docx and .doc files.
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredWordDocumentLoader
+
+ loader = UnstructuredWordDocumentLoader(
+ "example.docx", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-docx
+ """
+
+ def _get_elements(self) -> List:
+ from unstructured.__version__ import __version__ as __unstructured_version__
+ from unstructured.file_utils.filetype import FileType, detect_filetype
+
+ unstructured_version = tuple(
+ [int(x) for x in __unstructured_version__.split(".")]
+ )
+ # NOTE(MthwRobinson) - magic will raise an import error if the libmagic
+ # system dependency isn't installed. If it's not installed, we'll just
+ # check the file extension
+ try:
+ import magic # noqa: F401
+
+ is_doc = detect_filetype(self.file_path) == FileType.DOC
+ except ImportError:
+ _, extension = os.path.splitext(str(self.file_path))
+ is_doc = extension == ".doc"
+
+ if is_doc and unstructured_version < (0, 4, 11):
+ raise ValueError(
+ f"You are on unstructured version {__unstructured_version__}. "
+ "Partitioning .doc files is only supported in unstructured>=0.4.11. "
+ "Please upgrade the unstructured package and try again."
+ )
+
+ if is_doc:
+ from unstructured.partition.doc import partition_doc
+
+ return partition_doc(filename=self.file_path, **self.unstructured_kwargs)
+ else:
+ from unstructured.partition.docx import partition_docx
+
+ return partition_docx(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/xml.py b/libs/community/langchain_community/document_loaders/xml.py
new file mode 100644
index 00000000000..1e1262de03d
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/xml.py
@@ -0,0 +1,43 @@
+"""Loads XML files."""
+from typing import Any, List
+
+from langchain_community.document_loaders.unstructured import (
+ UnstructuredFileLoader,
+ validate_unstructured_version,
+)
+
+
+class UnstructuredXMLLoader(UnstructuredFileLoader):
+ """Load `XML` file using `Unstructured`.
+
+ You can run the loader in one of two modes: "single" and "elements".
+ If you use "single" mode, the document will be returned as a single
+ langchain Document object. If you use "elements" mode, the unstructured
+ library will split the document into elements such as Title and NarrativeText.
+ You can pass in additional unstructured kwargs after mode to apply
+ different unstructured settings.
+
+ Examples
+ --------
+ from langchain_community.document_loaders import UnstructuredXMLLoader
+
+ loader = UnstructuredXMLLoader(
+ "example.xml", mode="elements", strategy="fast",
+ )
+ docs = loader.load()
+
+ References
+ ----------
+ https://unstructured-io.github.io/unstructured/bricks.html#partition-xml
+ """
+
+ def __init__(
+ self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
+ ):
+ validate_unstructured_version(min_unstructured_version="0.6.7")
+ super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
+
+ def _get_elements(self) -> List:
+ from unstructured.partition.xml import partition_xml
+
+ return partition_xml(filename=self.file_path, **self.unstructured_kwargs)
diff --git a/libs/community/langchain_community/document_loaders/xorbits.py b/libs/community/langchain_community/document_loaders/xorbits.py
new file mode 100644
index 00000000000..67c87e80bff
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/xorbits.py
@@ -0,0 +1,32 @@
+from typing import Any
+
+from langchain_community.document_loaders.dataframe import BaseDataFrameLoader
+
+
+class XorbitsLoader(BaseDataFrameLoader):
+ """Load `Xorbits` DataFrame."""
+
+ def __init__(self, data_frame: Any, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Requirements:
+ Must have xorbits installed. You can install with `pip install xorbits`.
+
+ Args:
+ data_frame: Xorbits DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
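+
+        Example (a minimal sketch; the column name and data are illustrative):
+            .. code-block:: python
+
+                import xorbits.pandas as pd
+
+                from langchain_community.document_loaders import XorbitsLoader
+
+                df = pd.DataFrame({"text": ["hello", "world"]})
+                loader = XorbitsLoader(df, page_content_column="text")
+                docs = loader.load()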
+ """
+ try:
+ import xorbits.pandas as pd
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import xorbits, please install with 'pip install xorbits'."
+ ) from e
+
+ if not isinstance(data_frame, pd.DataFrame):
+ raise ValueError(
+                "Expected data_frame to be an xorbits.pandas.DataFrame, "
+                f"got {type(data_frame)}"
+ )
+ super().__init__(data_frame, page_content_column=page_content_column)
diff --git a/libs/community/langchain_community/document_loaders/youtube.py b/libs/community/langchain_community/document_loaders/youtube.py
new file mode 100644
index 00000000000..ea3f08349de
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/youtube.py
@@ -0,0 +1,429 @@
+"""Loads YouTube transcript."""
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Union
+from urllib.parse import parse_qs, urlparse
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1.dataclasses import dataclass
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+SCOPES = ["https://www.googleapis.com/auth/youtube.readonly"]
+
+
+@dataclass
+class GoogleApiClient:
+ """Generic Google API Client.
+
+    To use, you should have the ``google_auth_oauthlib``, ``youtube_transcript_api``
+    and ``google`` python packages installed.
+    As the Google API expects credentials, you need to set up a Google account and
+    register your service: "https://developers.google.com/docs/api/quickstart/python"
+
+
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import GoogleApiClient
+ google_api_client = GoogleApiClient(
+ service_account_path=Path("path_to_your_sec_file.json")
+ )
+
+ """
+
+ credentials_path: Path = Path.home() / ".credentials" / "credentials.json"
+ service_account_path: Path = Path.home() / ".credentials" / "credentials.json"
+ token_path: Path = Path.home() / ".credentials" / "token.json"
+
+ def __post_init__(self) -> None:
+ self.creds = self._load_credentials()
+
+ @root_validator
+ def validate_channel_or_videoIds_is_set(
+ cls, values: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Validate that either credentials_path or service_account_path is set."""
+
+ if not values.get("credentials_path") and not values.get(
+ "service_account_path"
+ ):
+            raise ValueError(
+                "Must specify either credentials_path or service_account_path"
+            )
+ return values
+
+ def _load_credentials(self) -> Any:
+ """Load credentials."""
+ # Adapted from https://developers.google.com/drive/api/v3/quickstart/python
+ try:
+ from google.auth.transport.requests import Request
+ from google.oauth2 import service_account
+ from google.oauth2.credentials import Credentials
+ from google_auth_oauthlib.flow import InstalledAppFlow
+ from youtube_transcript_api import YouTubeTranscriptApi # noqa: F401
+ except ImportError:
+            raise ImportError(
+                "You must run "
+ "`pip install --upgrade "
+ "google-api-python-client google-auth-httplib2 "
+ "google-auth-oauthlib "
+ "youtube-transcript-api` "
+ "to use the Google Drive loader"
+ )
+
+ creds = None
+ if self.service_account_path.exists():
+ return service_account.Credentials.from_service_account_file(
+ str(self.service_account_path)
+ )
+ if self.token_path.exists():
+ creds = Credentials.from_authorized_user_file(str(self.token_path), SCOPES)
+
+ if not creds or not creds.valid:
+ if creds and creds.expired and creds.refresh_token:
+ creds.refresh(Request())
+ else:
+ flow = InstalledAppFlow.from_client_secrets_file(
+ str(self.credentials_path), SCOPES
+ )
+ creds = flow.run_local_server(port=0)
+ with open(self.token_path, "w") as token:
+ token.write(creds.to_json())
+
+ return creds
+
+
+ALLOWED_SCHEMAS = {"http", "https"}
+ALLOWED_NETLOCK = {
+ "youtu.be",
+ "m.youtube.com",
+ "youtube.com",
+ "www.youtube.com",
+ "www.youtube-nocookie.com",
+ "vid.plus",
+}
+
+
+def _parse_video_id(url: str) -> Optional[str]:
+ """Parse a youtube url and return the video id if valid, otherwise None."""
+ parsed_url = urlparse(url)
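+    # Accepts links such as https://www.youtube.com/watch?v=<id> and
+    # https://youtu.be/<id> (illustrative forms; allowed hosts are ALLOWED_NETLOCK).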
+
+ if parsed_url.scheme not in ALLOWED_SCHEMAS:
+ return None
+
+ if parsed_url.netloc not in ALLOWED_NETLOCK:
+ return None
+
+ path = parsed_url.path
+
+ if path.endswith("/watch"):
+ query = parsed_url.query
+ parsed_query = parse_qs(query)
+ if "v" in parsed_query:
+ ids = parsed_query["v"]
+ video_id = ids if isinstance(ids, str) else ids[0]
+ else:
+ return None
+ else:
+ path = parsed_url.path.lstrip("/")
+ video_id = path.split("/")[-1]
+
+ if len(video_id) != 11: # Video IDs are 11 characters long
+ return None
+
+ return video_id
+
+
+class YoutubeLoader(BaseLoader):
+ """Load `YouTube` transcripts."""
+
+ def __init__(
+ self,
+ video_id: str,
+ add_video_info: bool = False,
+ language: Union[str, Sequence[str]] = "en",
+ translation: Optional[str] = None,
+ continue_on_failure: bool = False,
+ ):
+ """Initialize with YouTube video ID."""
+ self.video_id = video_id
+ self.add_video_info = add_video_info
+        if isinstance(language, str):
+            self.language = [language]
+        else:
+            self.language = language
+ self.translation = translation
+ self.continue_on_failure = continue_on_failure
+
+ @staticmethod
+ def extract_video_id(youtube_url: str) -> str:
+ """Extract video id from common YT urls."""
+ video_id = _parse_video_id(youtube_url)
+ if not video_id:
+ raise ValueError(
+ f"Could not determine the video ID for the URL {youtube_url}"
+ )
+ return video_id
+
+ @classmethod
+ def from_youtube_url(cls, youtube_url: str, **kwargs: Any) -> YoutubeLoader:
+ """Given youtube URL, load video."""
+ video_id = cls.extract_video_id(youtube_url)
+ return cls(video_id, **kwargs)
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ try:
+ from youtube_transcript_api import (
+ NoTranscriptFound,
+ TranscriptsDisabled,
+ YouTubeTranscriptApi,
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import youtube_transcript_api python package. "
+ "Please install it with `pip install youtube-transcript-api`."
+ )
+
+ metadata = {"source": self.video_id}
+
+ if self.add_video_info:
+ # Get more video meta info
+ # Such as title, description, thumbnail url, publish_date
+ video_info = self._get_video_info()
+ metadata.update(video_info)
+
+ try:
+ transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)
+ except TranscriptsDisabled:
+ return []
+
+ try:
+ transcript = transcript_list.find_transcript(self.language)
+ except NoTranscriptFound:
+ transcript = transcript_list.find_transcript(["en"])
+
+ if self.translation is not None:
+ transcript = transcript.translate(self.translation)
+
+ transcript_pieces = transcript.fetch()
+
+ transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
+
+ return [Document(page_content=transcript, metadata=metadata)]
+
+ def _get_video_info(self) -> dict:
+ """Get important video information.
+
+ Components are:
+ - title
+ - description
+ - thumbnail url,
+ - publish_date
+ - channel_author
+ - and more.
+ """
+ try:
+ from pytube import YouTube
+
+ except ImportError:
+ raise ImportError(
+ "Could not import pytube python package. "
+ "Please install it with `pip install pytube`."
+ )
+ yt = YouTube(f"https://www.youtube.com/watch?v={self.video_id}")
+ video_info = {
+ "title": yt.title or "Unknown",
+ "description": yt.description or "Unknown",
+ "view_count": yt.views or 0,
+ "thumbnail_url": yt.thumbnail_url or "Unknown",
+ "publish_date": yt.publish_date.strftime("%Y-%m-%d %H:%M:%S")
+ if yt.publish_date
+ else "Unknown",
+ "length": yt.length or 0,
+ "author": yt.author or "Unknown",
+ }
+ return video_info
+
+
+@dataclass
+class GoogleApiYoutubeLoader(BaseLoader):
+ """Load all Videos from a `YouTube` Channel.
+
+    To use, you should have the ``googleapiclient`` and ``youtube_transcript_api``
+    python packages installed.
+ As the service needs a google_api_client, you first have to initialize
+ the GoogleApiClient.
+
+    Additionally, you have to provide either a channel name or a list of video ids.
+ "https://developers.google.com/docs/api/quickstart/python"
+
+
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_loaders import GoogleApiClient
+ from langchain_community.document_loaders import GoogleApiYoutubeLoader
+ google_api_client = GoogleApiClient(
+ service_account_path=Path("path_to_your_sec_file.json")
+ )
+ loader = GoogleApiYoutubeLoader(
+ google_api_client=google_api_client,
+ channel_name = "CodeAesthetic"
+ )
+            loader.load()
+
+ """
+
+ google_api_client: GoogleApiClient
+ channel_name: Optional[str] = None
+ video_ids: Optional[List[str]] = None
+ add_video_info: bool = True
+ captions_language: str = "en"
+ continue_on_failure: bool = False
+
+ def __post_init__(self) -> None:
+ self.youtube_client = self._build_youtube_client(self.google_api_client.creds)
+
+ def _build_youtube_client(self, creds: Any) -> Any:
+ try:
+ from googleapiclient.discovery import build
+ from youtube_transcript_api import YouTubeTranscriptApi # noqa: F401
+ except ImportError:
+            raise ImportError(
+                "You must run "
+ "`pip install --upgrade "
+ "google-api-python-client google-auth-httplib2 "
+ "google-auth-oauthlib "
+ "youtube-transcript-api` "
+ "to use the Google Drive loader"
+ )
+
+ return build("youtube", "v3", credentials=creds)
+
+ @root_validator
+ def validate_channel_or_videoIds_is_set(
+ cls, values: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Validate that either channel_name or video_ids is set."""
+ if not values.get("channel_name") and not values.get("video_ids"):
+ raise ValueError("Must specify either channel_name or video_ids")
+ return values
+
+    def _get_transcript_for_video_id(self, video_id: str) -> str:
+ from youtube_transcript_api import NoTranscriptFound, YouTubeTranscriptApi
+
+ transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
+ try:
+ transcript = transcript_list.find_transcript([self.captions_language])
+ except NoTranscriptFound:
+ for available_transcript in transcript_list:
+ transcript = available_transcript.translate(self.captions_language)
+ continue
+
+ transcript_pieces = transcript.fetch()
+ return " ".join([t["text"].strip(" ") for t in transcript_pieces])
+
+ def _get_document_for_video_id(self, video_id: str, **kwargs: Any) -> Document:
+        captions = self._get_transcript_for_video_id(video_id)
+ video_response = (
+ self.youtube_client.videos()
+ .list(
+ part="id,snippet",
+ id=video_id,
+ )
+ .execute()
+ )
+ return Document(
+ page_content=captions,
+ metadata=video_response.get("items")[0],
+ )
+
+ def _get_channel_id(self, channel_name: str) -> str:
+ request = self.youtube_client.search().list(
+ part="id",
+ q=channel_name,
+ type="channel",
+ maxResults=1, # we only need one result since channel names are unique
+ )
+ response = request.execute()
+ channel_id = response["items"][0]["id"]["channelId"]
+ return channel_id
+
+ def _get_document_for_channel(self, channel: str, **kwargs: Any) -> List[Document]:
+ try:
+ from youtube_transcript_api import (
+ NoTranscriptFound,
+ TranscriptsDisabled,
+ )
+ except ImportError:
+            raise ImportError(
+                "You must run "
+ "`pip install --upgrade "
+ "youtube-transcript-api` "
+ "to use the youtube loader"
+ )
+
+ channel_id = self._get_channel_id(channel)
+ request = self.youtube_client.search().list(
+ part="id,snippet",
+ channelId=channel_id,
+ maxResults=50, # adjust this value to retrieve more or fewer videos
+ )
+        video_docs = []
+ while request is not None:
+ response = request.execute()
+
+            # Build a Document for each video in this page of results
+ for item in response["items"]:
+ if not item["id"].get("videoId"):
+ continue
+ meta_data = {"videoId": item["id"]["videoId"]}
+ if self.add_video_info:
+ item["snippet"].pop("thumbnails")
+ meta_data.update(item["snippet"])
+ try:
+                    page_content = self._get_transcript_for_video_id(
+ item["id"]["videoId"]
+ )
+                    video_docs.append(
+ Document(
+ page_content=page_content,
+ metadata=meta_data,
+ )
+ )
+ except (TranscriptsDisabled, NoTranscriptFound) as e:
+ if self.continue_on_failure:
+                        logger.error(
+                            "Error fetching transcript "
+                            + f"{item['id']['videoId']}, exception: {e}"
+ )
+ else:
+ raise e
+ request = self.youtube_client.search().list_next(request, response)
+
+        return video_docs
+
+ def load(self) -> List[Document]:
+ """Load documents."""
+ document_list = []
+ if self.channel_name:
+ document_list.extend(self._get_document_for_channel(self.channel_name))
+ elif self.video_ids:
+ document_list.extend(
+ [
+ self._get_document_for_video_id(video_id)
+ for video_id in self.video_ids
+ ]
+ )
+ else:
+ raise ValueError("Must specify either channel_name or video_ids")
+ return document_list
diff --git a/libs/community/langchain_community/document_transformers/__init__.py b/libs/community/langchain_community/document_transformers/__init__.py
new file mode 100644
index 00000000000..fce2184af3e
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/__init__.py
@@ -0,0 +1,62 @@
+"""**Document Transformers** are classes to transform Documents.
+
+**Document Transformers** are usually used to transform many Documents in a single run.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseDocumentTransformer --> # Examples: DoctranQATransformer, DoctranTextTranslator
+
+**Main helpers:**
+
+.. code-block::
+
+ Document
+""" # noqa: E501
+
+from langchain_community.document_transformers.beautiful_soup_transformer import (
+ BeautifulSoupTransformer,
+)
+from langchain_community.document_transformers.doctran_text_extract import (
+ DoctranPropertyExtractor,
+)
+from langchain_community.document_transformers.doctran_text_qa import (
+ DoctranQATransformer,
+)
+from langchain_community.document_transformers.doctran_text_translate import (
+ DoctranTextTranslator,
+)
+from langchain_community.document_transformers.embeddings_redundant_filter import (
+ EmbeddingsClusteringFilter,
+ EmbeddingsRedundantFilter,
+ get_stateful_documents,
+)
+from langchain_community.document_transformers.google_translate import (
+ GoogleTranslateTransformer,
+)
+from langchain_community.document_transformers.html2text import Html2TextTransformer
+from langchain_community.document_transformers.long_context_reorder import (
+ LongContextReorder,
+)
+from langchain_community.document_transformers.nuclia_text_transform import (
+ NucliaTextTransformer,
+)
+from langchain_community.document_transformers.openai_functions import (
+ OpenAIMetadataTagger,
+)
+
+__all__ = [
+ "BeautifulSoupTransformer",
+ "DoctranQATransformer",
+ "DoctranTextTranslator",
+ "DoctranPropertyExtractor",
+ "EmbeddingsClusteringFilter",
+ "EmbeddingsRedundantFilter",
+ "GoogleTranslateTransformer",
+ "get_stateful_documents",
+ "LongContextReorder",
+ "NucliaTextTransformer",
+ "OpenAIMetadataTagger",
+ "Html2TextTransformer",
+]
diff --git a/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py b/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py
new file mode 100644
index 00000000000..0e2b5d394c2
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py
@@ -0,0 +1,149 @@
+from typing import Any, Iterator, List, Sequence, cast
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+
+
+class BeautifulSoupTransformer(BaseDocumentTransformer):
+ """Transform HTML content by extracting specific tags and removing unwanted ones.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_transformers import BeautifulSoupTransformer
+
+ bs4_transformer = BeautifulSoupTransformer()
+ docs_transformed = bs4_transformer.transform_documents(docs)
+ """ # noqa: E501
+
+ def __init__(self) -> None:
+ """
+ Initialize the transformer.
+
+ This checks if the BeautifulSoup4 package is installed.
+ If not, it raises an ImportError.
+ """
+ try:
+ import bs4 # noqa:F401
+ except ImportError:
+ raise ImportError(
+ "BeautifulSoup4 is required for BeautifulSoupTransformer. "
+ "Please install it with `pip install beautifulsoup4`."
+ )
+
+ def transform_documents(
+ self,
+ documents: Sequence[Document],
+ unwanted_tags: List[str] = ["script", "style"],
+ tags_to_extract: List[str] = ["p", "li", "div", "a"],
+ remove_lines: bool = True,
+ **kwargs: Any,
+ ) -> Sequence[Document]:
+ """
+ Transform a list of Document objects by cleaning their HTML content.
+
+ Args:
+ documents: A sequence of Document objects containing HTML content.
+ unwanted_tags: A list of tags to be removed from the HTML.
+ tags_to_extract: A list of tags whose content will be extracted.
+ remove_lines: If set to True, unnecessary lines will be
+ removed from the HTML content.
+
+ Returns:
+ A sequence of Document objects with transformed content.
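+
+        Example (illustrative; reuses the transformer from the class-level example):
+            .. code-block:: python
+
+                docs_transformed = bs4_transformer.transform_documents(
+                    docs, tags_to_extract=["h1", "h2", "p"]
+                )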
+ """
+ for doc in documents:
+ cleaned_content = doc.page_content
+
+ cleaned_content = self.remove_unwanted_tags(cleaned_content, unwanted_tags)
+
+ cleaned_content = self.extract_tags(cleaned_content, tags_to_extract)
+
+ if remove_lines:
+ cleaned_content = self.remove_unnecessary_lines(cleaned_content)
+
+ doc.page_content = cleaned_content
+
+ return documents
+
+ @staticmethod
+ def remove_unwanted_tags(html_content: str, unwanted_tags: List[str]) -> str:
+ """
+ Remove unwanted tags from a given HTML content.
+
+ Args:
+ html_content: The original HTML content string.
+ unwanted_tags: A list of tags to be removed from the HTML.
+
+ Returns:
+ A cleaned HTML string with unwanted tags removed.
+ """
+ from bs4 import BeautifulSoup
+
+ soup = BeautifulSoup(html_content, "html.parser")
+ for tag in unwanted_tags:
+ for element in soup.find_all(tag):
+ element.decompose()
+ return str(soup)
+
+ @staticmethod
+ def extract_tags(html_content: str, tags: List[str]) -> str:
+ """
+ Extract specific tags from a given HTML content.
+
+ Args:
+ html_content: The original HTML content string.
+ tags: A list of tags to be extracted from the HTML.
+
+ Returns:
+ A string combining the content of the extracted tags.
+ """
+ from bs4 import BeautifulSoup
+
+ soup = BeautifulSoup(html_content, "html.parser")
+ text_parts: List[str] = []
+ for element in soup.find_all():
+ if element.name in tags:
+ # Extract all navigable strings recursively from this element.
+ text_parts += get_navigable_strings(element)
+
+ # To avoid duplicate text, remove all descendants from the soup.
+ element.decompose()
+
+ return " ".join(text_parts)
+
+ @staticmethod
+ def remove_unnecessary_lines(content: str) -> str:
+ """
+ Clean up the content by removing unnecessary lines.
+
+ Args:
+ content: A string, which may contain unnecessary lines or spaces.
+
+ Returns:
+ A cleaned string with unnecessary lines removed.
+ """
+ lines = content.split("\n")
+ stripped_lines = [line.strip() for line in lines]
+ non_empty_lines = [line for line in stripped_lines if line]
+ cleaned_content = " ".join(non_empty_lines)
+ return cleaned_content
+
+ async def atransform_documents(
+ self,
+ documents: Sequence[Document],
+ **kwargs: Any,
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+
+def get_navigable_strings(element: Any) -> Iterator[str]:
+ from bs4 import NavigableString, Tag
+
+ for child in cast(Tag, element).children:
+ if isinstance(child, Tag):
+ yield from get_navigable_strings(child)
+ elif isinstance(child, NavigableString):
+ if (element.name == "a") and (href := element.get("href")):
+ yield f"{child.strip()} ({href})"
+ else:
+ yield child.strip()
diff --git a/libs/community/langchain_community/document_transformers/doctran_text_extract.py b/libs/community/langchain_community/document_transformers/doctran_text_extract.py
new file mode 100644
index 00000000000..e3028229253
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/doctran_text_extract.py
@@ -0,0 +1,94 @@
+from typing import Any, List, Optional, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.utils import get_from_env
+
+
+class DoctranPropertyExtractor(BaseDocumentTransformer):
+ """Extract properties from text documents using doctran.
+
+ Arguments:
+ properties: A list of the properties to extract.
+ openai_api_key: OpenAI API key. Can also be specified via environment variable
+ ``OPENAI_API_KEY``.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_transformers import DoctranPropertyExtractor
+
+ properties = [
+ {
+ "name": "category",
+ "description": "What type of email this is.",
+ "type": "string",
+ "enum": ["update", "action_item", "customer_feedback", "announcement", "other"],
+ "required": True,
+ },
+ {
+ "name": "mentions",
+ "description": "A list of all people mentioned in this email.",
+ "type": "array",
+ "items": {
+ "name": "full_name",
+ "description": "The full name of the person mentioned.",
+ "type": "string",
+ },
+ "required": True,
+ },
+ {
+ "name": "eli5",
+ "description": "Explain this email to me like I'm 5 years old.",
+ "type": "string",
+ "required": True,
+ },
+ ]
+
+ # Pass in openai_api_key or set env var OPENAI_API_KEY
+ property_extractor = DoctranPropertyExtractor(properties)
+            transformed_document = await property_extractor.atransform_documents(documents)
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ properties: List[dict],
+ openai_api_key: Optional[str] = None,
+ openai_api_model: Optional[str] = None,
+ ) -> None:
+ self.properties = properties
+ self.openai_api_key = openai_api_key or get_from_env(
+ "openai_api_key", "OPENAI_API_KEY"
+ )
+ self.openai_api_model = openai_api_model or get_from_env(
+ "openai_api_model", "OPENAI_API_MODEL"
+ )
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Extracts properties from text documents using doctran."""
+ try:
+ from doctran import Doctran, ExtractProperty
+
+ doctran = Doctran(
+ openai_api_key=self.openai_api_key, openai_model=self.openai_api_model
+ )
+ except ImportError:
+ raise ImportError(
+ "Install doctran to use this parser. (pip install doctran)"
+ )
+ properties = [ExtractProperty(**property) for property in self.properties]
+ for d in documents:
+ doctran_doc = (
+ await doctran.parse(content=d.page_content)
+ .extract(properties=properties)
+ .execute()
+ )
+
+ d.metadata["extracted_properties"] = doctran_doc.extracted_properties
+ return documents
diff --git a/libs/community/langchain_community/document_transformers/doctran_text_qa.py b/libs/community/langchain_community/document_transformers/doctran_text_qa.py
new file mode 100644
index 00000000000..46940f6ddb8
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/doctran_text_qa.py
@@ -0,0 +1,63 @@
+from typing import Any, Optional, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.utils import get_from_env
+
+
+class DoctranQATransformer(BaseDocumentTransformer):
+ """Extract QA from text documents using doctran.
+
+ Arguments:
+ openai_api_key: OpenAI API key. Can also be specified via environment variable
+ ``OPENAI_API_KEY``.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_transformers import DoctranQATransformer
+
+ # Pass in openai_api_key or set env var OPENAI_API_KEY
+ qa_transformer = DoctranQATransformer()
+ transformed_document = await qa_transformer.atransform_documents(documents)
+ """
+
+ def __init__(
+ self,
+ openai_api_key: Optional[str] = None,
+ openai_api_model: Optional[str] = None,
+ ) -> None:
+ self.openai_api_key = openai_api_key or get_from_env(
+ "openai_api_key", "OPENAI_API_KEY"
+ )
+ self.openai_api_model = openai_api_model or get_from_env(
+ "openai_api_model", "OPENAI_API_MODEL"
+ )
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Extracts QA from text documents using doctran."""
+ try:
+ from doctran import Doctran
+
+ doctran = Doctran(
+ openai_api_key=self.openai_api_key, openai_model=self.openai_api_model
+ )
+ except ImportError:
+ raise ImportError(
+ "Install doctran to use this parser. (pip install doctran)"
+ )
+ for d in documents:
+ doctran_doc = (
+ await doctran.parse(content=d.page_content).interrogate().execute()
+ )
+ questions_and_answers = doctran_doc.extracted_properties.get(
+ "questions_and_answers"
+ )
+ d.metadata["questions_and_answers"] = questions_and_answers
+ return documents
diff --git a/libs/community/langchain_community/document_transformers/doctran_text_translate.py b/libs/community/langchain_community/document_transformers/doctran_text_translate.py
new file mode 100644
index 00000000000..d6acb5e68ba
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/doctran_text_translate.py
@@ -0,0 +1,67 @@
+from typing import Any, Optional, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.utils import get_from_env
+
+
+class DoctranTextTranslator(BaseDocumentTransformer):
+ """Translate text documents using doctran.
+
+ Arguments:
+ openai_api_key: OpenAI API key. Can also be specified via environment variable
+ ``OPENAI_API_KEY``.
+ language: The language to translate *to*.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.document_transformers import DoctranTextTranslator
+
+ # Pass in openai_api_key or set env var OPENAI_API_KEY
+ qa_translator = DoctranTextTranslator(language="spanish")
+ translated_document = await qa_translator.atransform_documents(documents)
+ """
+
+ def __init__(
+ self,
+ openai_api_key: Optional[str] = None,
+ language: str = "english",
+ openai_api_model: Optional[str] = None,
+ ) -> None:
+ self.openai_api_key = openai_api_key or get_from_env(
+ "openai_api_key", "OPENAI_API_KEY"
+ )
+ self.openai_api_model = openai_api_model or get_from_env(
+ "openai_api_model", "OPENAI_API_MODEL"
+ )
+ self.language = language
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Translates text documents using doctran."""
+ try:
+ from doctran import Doctran
+
+ doctran = Doctran(
+ openai_api_key=self.openai_api_key, openai_model=self.openai_api_model
+ )
+ except ImportError:
+ raise ImportError(
+ "Install doctran to use this parser. (pip install doctran)"
+ )
+ doctran_docs = [
+ doctran.parse(content=doc.page_content, metadata=doc.metadata)
+ for doc in documents
+ ]
+ for i, doc in enumerate(doctran_docs):
+ doctran_docs[i] = await doc.translate(language=self.language).execute()
+ return [
+ Document(page_content=doc.transformed_content, metadata=doc.metadata)
+ for doc in doctran_docs
+ ]
diff --git a/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py b/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py
new file mode 100644
index 00000000000..8e6cc89dc04
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py
@@ -0,0 +1,212 @@
+"""Transform documents"""
+from typing import Any, Callable, List, Sequence
+
+import numpy as np
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.utils.math import cosine_similarity
+
+
+class _DocumentWithState(Document):
+ """Wrapper for a document that includes arbitrary state."""
+
+ state: dict = Field(default_factory=dict)
+ """State associated with the document."""
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ def to_document(self) -> Document:
+ """Convert the DocumentWithState to a Document."""
+ return Document(page_content=self.page_content, metadata=self.metadata)
+
+ @classmethod
+ def from_document(cls, doc: Document) -> "_DocumentWithState":
+ """Create a DocumentWithState from a Document."""
+ if isinstance(doc, cls):
+ return doc
+ return cls(page_content=doc.page_content, metadata=doc.metadata)
+
+
+def get_stateful_documents(
+ documents: Sequence[Document],
+) -> Sequence[_DocumentWithState]:
+ """Convert a list of documents to a list of documents with state.
+
+ Args:
+ documents: The documents to convert.
+
+ Returns:
+ A list of documents with state.
+ """
+ return [_DocumentWithState.from_document(doc) for doc in documents]
+
+
+def _filter_similar_embeddings(
+ embedded_documents: List[List[float]], similarity_fn: Callable, threshold: float
+) -> List[int]:
+ """Filter redundant documents based on the similarity of their embeddings."""
+ similarity = np.tril(similarity_fn(embedded_documents, embedded_documents), k=-1)
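+    # Only the strict lower triangle is kept, so each unordered pair of documents is
+    # compared exactly once and self-similarity on the diagonal is zeroed out.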
+ redundant = np.where(similarity > threshold)
+ redundant_stacked = np.column_stack(redundant)
+ redundant_sorted = np.argsort(similarity[redundant])[::-1]
+ included_idxs = set(range(len(embedded_documents)))
+ for first_idx, second_idx in redundant_stacked[redundant_sorted]:
+ if first_idx in included_idxs and second_idx in included_idxs:
+ # Default to dropping the second document of any highly similar pair.
+ included_idxs.remove(second_idx)
+ return list(sorted(included_idxs))
+
+
+def _get_embeddings_from_stateful_docs(
+ embeddings: Embeddings, documents: Sequence[_DocumentWithState]
+) -> List[List[float]]:
+ if len(documents) and "embedded_doc" in documents[0].state:
+ embedded_documents = [doc.state["embedded_doc"] for doc in documents]
+ else:
+ embedded_documents = embeddings.embed_documents(
+ [d.page_content for d in documents]
+ )
+ for doc, embedding in zip(documents, embedded_documents):
+ doc.state["embedded_doc"] = embedding
+ return embedded_documents
+
+
+def _filter_cluster_embeddings(
+ embedded_documents: List[List[float]],
+ num_clusters: int,
+ num_closest: int,
+ random_state: int,
+ remove_duplicates: bool,
+) -> List[int]:
+ """Filter documents based on proximity of their embeddings to clusters."""
+
+ try:
+ from sklearn.cluster import KMeans
+ except ImportError:
+ raise ImportError(
+ "sklearn package not found, please install it with "
+ "`pip install scikit-learn`"
+ )
+
+ kmeans = KMeans(n_clusters=num_clusters, random_state=random_state).fit(
+ embedded_documents
+ )
+ closest_indices = []
+
+ # Loop through the number of clusters you have
+ for i in range(num_clusters):
+ # Get the list of distances from that particular cluster center
+ distances = np.linalg.norm(
+ embedded_documents - kmeans.cluster_centers_[i], axis=1
+ )
+
+ # Find the indices of the two unique closest ones
+ # (using argsort to find the smallest 2 distances)
+ if remove_duplicates:
+ # Only add not duplicated vectors.
+ closest_indices_sorted = [
+ x
+ for x in np.argsort(distances)[:num_closest]
+ if x not in closest_indices
+ ]
+ else:
+ # Skip duplicates and add the next closest vector.
+ closest_indices_sorted = [
+ x for x in np.argsort(distances) if x not in closest_indices
+ ][:num_closest]
+
+ # Append that position closest indices list
+ closest_indices.extend(closest_indices_sorted)
+
+ return closest_indices
+
+
+class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
+ """Filter that drops redundant documents by comparing their embeddings."""
+
+ embeddings: Embeddings
+ """Embeddings to use for embedding document contents."""
+ similarity_fn: Callable = cosine_similarity
+ """Similarity function for comparing documents. Function expected to take as input
+ two matrices (List[List[float]]) and return a matrix of scores where higher values
+ indicate greater similarity."""
+ similarity_threshold: float = 0.95
+ """Threshold for determining when two documents are similar enough
+ to be considered redundant."""
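+
+    # Example usage (a sketch; OpenAIEmbeddings is illustrative and needs an API key):
+    #   from langchain_community.embeddings import OpenAIEmbeddings
+    #   redundant_filter = EmbeddingsRedundantFilter(embeddings=OpenAIEmbeddings())
+    #   unique_docs = redundant_filter.transform_documents(docs)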
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Filter down documents."""
+ stateful_documents = get_stateful_documents(documents)
+ embedded_documents = _get_embeddings_from_stateful_docs(
+ self.embeddings, stateful_documents
+ )
+ included_idxs = _filter_similar_embeddings(
+ embedded_documents, self.similarity_fn, self.similarity_threshold
+ )
+ return [stateful_documents[i] for i in sorted(included_idxs)]
+
+
+class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
+ """Perform K-means clustering on document vectors.
+ Returns an arbitrary number of documents closest to center."""
+
+ embeddings: Embeddings
+ """Embeddings to use for embedding document contents."""
+
+ num_clusters: int = 5
+ """Number of clusters. Groups of documents with similar meaning."""
+
+ num_closest: int = 1
+ """The number of closest vectors to return for each cluster center."""
+
+ random_state: int = 42
+ """Controls the random number generator used to initialize the cluster centroids.
+ If you set the random_state parameter to None, the KMeans algorithm will use a
+ random number generator that is seeded with the current time. This means
+ that the results of the KMeans algorithm will be different each time you
+ run it."""
+
+ sorted: bool = False
+ """By default results are re-ordered "grouping" them by cluster, if sorted is true
+ result will be ordered by the original position from the retriever"""
+
+ remove_duplicates: bool = False
+ """ By default duplicated results are skipped and replaced by the next closest
+ vector in the cluster. If remove_duplicates is true no replacement will be done:
+ This could dramatically reduce results when there is a lot of overlap between
+ clusters.
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Filter down documents."""
+ stateful_documents = get_stateful_documents(documents)
+ embedded_documents = _get_embeddings_from_stateful_docs(
+ self.embeddings, stateful_documents
+ )
+ included_idxs = _filter_cluster_embeddings(
+ embedded_documents,
+ self.num_clusters,
+ self.num_closest,
+ self.random_state,
+ self.remove_duplicates,
+ )
+ results = sorted(included_idxs) if self.sorted else included_idxs
+ return [stateful_documents[i] for i in results]
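+
+
+# Minimal usage sketch for the clustering filter above. It assumes OpenAI
+# credentials are available via the OPENAI_API_KEY environment variable and that
+# scikit-learn is installed; the six toy documents are placeholders.
+if __name__ == "__main__":
+    from langchain_community.embeddings import OpenAIEmbeddings
+
+    docs = [Document(page_content=f"toy document number {i}") for i in range(6)]
+    clustering_filter = EmbeddingsClusteringFilter(
+        embeddings=OpenAIEmbeddings(), num_clusters=2, num_closest=1
+    )
+    representative_docs = clustering_filter.transform_documents(docs)
+    # One representative document per cluster: num_clusters * num_closest == 2.
+    print(len(representative_docs))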
diff --git a/libs/community/langchain_community/document_transformers/google_translate.py b/libs/community/langchain_community/document_transformers/google_translate.py
new file mode 100644
index 00000000000..b2653c48919
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/google_translate.py
@@ -0,0 +1,107 @@
+from typing import Any, Optional, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+
+from langchain_community.utilities.vertexai import get_client_info
+
+
+class GoogleTranslateTransformer(BaseDocumentTransformer):
+ """Translate text documents using Google Cloud Translation."""
+
+ def __init__(
+ self,
+ project_id: str,
+ *,
+ location: str = "global",
+ model_id: Optional[str] = None,
+ glossary_id: Optional[str] = None,
+ api_endpoint: Optional[str] = None,
+ ) -> None:
+ """
+ Arguments:
+ project_id: Google Cloud Project ID.
+ location: (Optional) Translate model location.
+ model_id: (Optional) Translate model ID to use.
+ glossary_id: (Optional) Translate glossary ID to use.
+ api_endpoint: (Optional) Regional endpoint to use.
+ """
+ try:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud import translate
+ except ImportError as exc:
+ raise ImportError(
+ "Install Google Cloud Translate to use this parser."
+ "(pip install google-cloud-translate)"
+ ) from exc
+
+ self.project_id = project_id
+ self.location = location
+ self.model_id = model_id
+ self.glossary_id = glossary_id
+
+ self._client = translate.TranslationServiceClient(
+ client_info=get_client_info("translate"),
+ client_options=(
+ ClientOptions(api_endpoint=api_endpoint) if api_endpoint else None
+ ),
+ )
+ self._parent_path = self._client.common_location_path(project_id, location)
+ # For some reason, there's no `model_path()` method for the client.
+ self._model_path = (
+ f"{self._parent_path}/models/{model_id}" if model_id else None
+ )
+ self._glossary_path = (
+ self._client.glossary_path(project_id, location, glossary_id)
+ if glossary_id
+ else None
+ )
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Translate text documents using Google Translate.
+
+ Arguments:
+ source_language_code: ISO 639 language code of the input document.
+ target_language_code: ISO 639 language code of the output document.
+ For supported languages, refer to:
+ https://cloud.google.com/translate/docs/languages
+ mime_type: (Optional) Media Type of input text.
+ Options: `text/plain`, `text/html`
+ """
+ try:
+ from google.cloud import translate
+ except ImportError as exc:
+ raise ImportError(
+ "Install Google Cloud Translate to use this parser."
+ "(pip install google-cloud-translate)"
+ ) from exc
+
+ response = self._client.translate_text(
+ request=translate.TranslateTextRequest(
+ contents=[doc.page_content for doc in documents],
+ parent=self._parent_path,
+ model=self._model_path,
+ glossary_config=translate.TranslateTextGlossaryConfig(
+ glossary=self._glossary_path
+ ),
+ source_language_code=kwargs.get("source_language_code", None),
+ target_language_code=kwargs.get("target_language_code"),
+ mime_type=kwargs.get("mime_type", "text/plain"),
+ )
+ )
+
+ # If using a glossary, the translations will be in `glossary_translations`.
+ translations = response.glossary_translations or response.translations
+
+ return [
+ Document(
+ page_content=translation.translated_text,
+ metadata={
+ **doc.metadata,
+ "model": translation.model,
+ "detected_language_code": translation.detected_language_code,
+ },
+ )
+ for doc, translation in zip(documents, translations)
+ ]
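+
+
+# Minimal usage sketch for the transformer above; "my-gcp-project" is a
+# placeholder project ID, and the call assumes Google Cloud credentials and the
+# google-cloud-translate package are available.
+if __name__ == "__main__":
+    translator = GoogleTranslateTransformer(project_id="my-gcp-project")
+    docs = [Document(page_content="Bonjour le monde")]
+    translated = translator.transform_documents(docs, target_language_code="en")
+    # Each output Document keeps the original metadata plus "model" and
+    # "detected_language_code" entries.
+    print(translated[0].page_content, translated[0].metadata)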
diff --git a/libs/community/langchain_community/document_transformers/html2text.py b/libs/community/langchain_community/document_transformers/html2text.py
new file mode 100644
index 00000000000..cbf7cf366e4
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/html2text.py
@@ -0,0 +1,56 @@
+from typing import Any, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+
+
+class Html2TextTransformer(BaseDocumentTransformer):
+ """Replace occurrences of a particular search pattern with a replacement string
+
+ Arguments:
+ ignore_links: Whether links should be ignored; defaults to True.
+ ignore_images: Whether images should be ignored; defaults to True.
+
+ Example:
+ .. code-block:: python
+ from langchain_community.document_transformers import Html2TextTransformer
+ html2text = Html2TextTransformer()
+ docs_transform = html2text.transform_documents(docs)
+ """
+
+ def __init__(self, ignore_links: bool = True, ignore_images: bool = True) -> None:
+ self.ignore_links = ignore_links
+ self.ignore_images = ignore_images
+
+ def transform_documents(
+ self,
+ documents: Sequence[Document],
+ **kwargs: Any,
+ ) -> Sequence[Document]:
+ try:
+ import html2text
+ except ImportError:
+ raise ImportError(
+ """html2text package not found, please
+ install it with `pip install html2text`"""
+ )
+
+ # Create a html2text.HTML2Text object and override some properties
+ h = html2text.HTML2Text()
+ h.ignore_links = self.ignore_links
+ h.ignore_images = self.ignore_images
+
+ new_documents = []
+
+ for d in documents:
+ new_document = Document(
+ page_content=h.handle(d.page_content), metadata={**d.metadata}
+ )
+ new_documents.append(new_document)
+ return new_documents
+
+ async def atransform_documents(
+ self,
+ documents: Sequence[Document],
+ **kwargs: Any,
+ ) -> Sequence[Document]:
+ raise NotImplementedError
diff --git a/libs/community/langchain_community/document_transformers/long_context_reorder.py b/libs/community/langchain_community/document_transformers/long_context_reorder.py
new file mode 100644
index 00000000000..32eda6b4826
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/long_context_reorder.py
@@ -0,0 +1,43 @@
+"""Reorder documents"""
+from typing import Any, List, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.pydantic_v1 import BaseModel
+
+
+def _litm_reordering(documents: List[Document]) -> List[Document]:
+ """Lost in the middle reorder: the less relevant documents will be at the
+ middle of the list and more relevant elements at beginning / end.
+    See: https://arxiv.org/abs/2307.03172"""
+
+ documents.reverse()
+ reordered_result = []
+ for i, value in enumerate(documents):
+ if i % 2 == 1:
+ reordered_result.append(value)
+ else:
+ reordered_result.insert(0, value)
+ return reordered_result
+
+
+class LongContextReorder(BaseDocumentTransformer, BaseModel):
+ """Lost in the middle:
+ Performance degrades when models must access relevant information
+ in the middle of long contexts.
+    See: https://arxiv.org/abs/2307.03172"""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Reorders documents."""
+ return _litm_reordering(list(documents))
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
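+
+
+# Minimal usage sketch for the reordering above, assuming the input list is
+# already sorted by decreasing relevance (as a retriever would return it).
+if __name__ == "__main__":
+    docs = [Document(page_content=f"doc {i}") for i in range(5)]
+    reordered = LongContextReorder().transform_documents(docs)
+    # The most relevant items (doc 0, doc 1) end up at the ends of the list and
+    # the least relevant item (doc 4) lands in the middle:
+    print([d.page_content for d in reordered])
+    # ['doc 0', 'doc 2', 'doc 4', 'doc 3', 'doc 1']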
diff --git a/libs/community/langchain_community/document_transformers/nuclia_text_transform.py b/libs/community/langchain_community/document_transformers/nuclia_text_transform.py
new file mode 100644
index 00000000000..ed62f527e1c
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/nuclia_text_transform.py
@@ -0,0 +1,48 @@
+import asyncio
+import json
+import uuid
+from typing import Any, Sequence
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+
+from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
+
+
+class NucliaTextTransformer(BaseDocumentTransformer):
+ """
+    The Nuclia Understanding API splits text into paragraphs and sentences,
+ identifies entities, provides a summary of the text and generates
+ embeddings for all sentences.
+ """
+
+ def __init__(self, nua: NucliaUnderstandingAPI):
+ self.nua = nua
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ tasks = [
+ self.nua.arun(
+ {
+ "action": "push",
+ "id": str(uuid.uuid4()),
+ "text": doc.page_content,
+ "path": None,
+ }
+ )
+ for doc in documents
+ ]
+ results = await asyncio.gather(*tasks)
+ for doc, result in zip(documents, results):
+ obj = json.loads(result)
+ metadata = {
+ "file": obj["file_extracted_data"][0],
+ "metadata": obj["field_metadata"][0],
+ }
+ doc.metadata["nuclia"] = metadata
+ return documents
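+
+
+# Minimal async usage sketch for the transformer above. It assumes a Nuclia NUA
+# key is configured in the environment (e.g. NUCLIA_NUA_KEY) and that
+# ``enable_ml=True`` is an acceptable way to construct the tool; both are
+# assumptions, not guarantees.
+if __name__ == "__main__":
+    nua = NucliaUnderstandingAPI(enable_ml=True)
+    transformer = NucliaTextTransformer(nua)
+    docs = [Document(page_content="LangChain integrates with Nuclia.")]
+    enriched = asyncio.run(transformer.atransform_documents(docs))
+    # Each document gains a "nuclia" metadata entry with extracted file and
+    # field metadata.
+    print(enriched[0].metadata["nuclia"])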
diff --git a/libs/community/langchain_community/document_transformers/openai_functions.py b/libs/community/langchain_community/document_transformers/openai_functions.py
new file mode 100644
index 00000000000..5c9dec22e61
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/openai_functions.py
@@ -0,0 +1,141 @@
+"""Document transformers that use OpenAI Functions models"""
+from typing import Any, Dict, Optional, Sequence, Type, Union
+
+from langchain_core.documents import BaseDocumentTransformer, Document
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel
+
+
+class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
+ """Extract metadata tags from document contents using OpenAI functions.
+
+ Example:
+ .. code-block:: python
+
+            from langchain.chains.openai_functions import create_tagging_chain
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_community.document_transformers import OpenAIMetadataTagger
+            from langchain_core.documents import Document
+
+ schema = {
+ "properties": {
+ "movie_title": { "type": "string" },
+ "critic": { "type": "string" },
+ "tone": {
+ "type": "string",
+ "enum": ["positive", "negative"]
+ },
+ "rating": {
+ "type": "integer",
+ "description": "The number of stars the critic rated the movie"
+ }
+ },
+ "required": ["movie_title", "critic", "tone"]
+ }
+
+ # Must be an OpenAI model that supports functions
+ llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
+ tagging_chain = create_tagging_chain(schema, llm)
+ document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
+ original_documents = [
+ Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
+ Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
+ ]
+
+ enhanced_documents = document_transformer.transform_documents(original_documents)
+ """ # noqa: E501
+
+ tagging_chain: Any
+ """The chain used to extract metadata from each document."""
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Automatically extract and populate metadata
+ for each document according to the provided schema."""
+
+ new_documents = []
+
+ for document in documents:
+ extracted_metadata: Dict = self.tagging_chain.run(document.page_content) # type: ignore[assignment] # noqa: E501
+ new_document = Document(
+ page_content=document.page_content,
+ metadata={**extracted_metadata, **document.metadata},
+ )
+ new_documents.append(new_document)
+ return new_documents
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+
+def create_metadata_tagger(
+ metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
+ llm: BaseLanguageModel,
+ prompt: Optional[ChatPromptTemplate] = None,
+ *,
+ tagging_chain_kwargs: Optional[Dict] = None,
+) -> OpenAIMetadataTagger:
+ """Create a DocumentTransformer that uses an OpenAI function chain to automatically
+ tag documents with metadata based on their content and an input schema.
+
+ Args:
+ metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
+ is passed in, it's assumed to already be a valid JsonSchema.
+ For best results, pydantic.BaseModels should have docstrings describing what
+ the schema represents and descriptions for the parameters.
+        llm: Language model to use, assumed to support the OpenAI function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        tagging_chain_kwargs: Additional keyword arguments passed to the underlying
+            tagging chain.
+
+    Returns:
+        An OpenAIMetadataTagger document transformer that uses the given model to
+        extract metadata matching the schema from each document.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.chat_models import ChatOpenAI
+ from langchain_community.document_transformers import create_metadata_tagger
+ from langchain_core.documents import Document
+
+ schema = {
+ "properties": {
+ "movie_title": { "type": "string" },
+ "critic": { "type": "string" },
+ "tone": {
+ "type": "string",
+ "enum": ["positive", "negative"]
+ },
+ "rating": {
+ "type": "integer",
+ "description": "The number of stars the critic rated the movie"
+ }
+ },
+ "required": ["movie_title", "critic", "tone"]
+ }
+
+ # Must be an OpenAI model that supports functions
+ llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
+
+ document_transformer = create_metadata_tagger(schema, llm)
+ original_documents = [
+ Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
+ Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
+ ]
+
+ enhanced_documents = document_transformer.transform_documents(original_documents)
+ """ # noqa: E501
+ from langchain.chains.openai_functions import create_tagging_chain
+
+ metadata_schema = (
+ metadata_schema
+ if isinstance(metadata_schema, dict)
+ else metadata_schema.schema()
+ )
+ _tagging_chain_kwargs = tagging_chain_kwargs or {}
+ tagging_chain = create_tagging_chain(
+ metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
+ )
+ return OpenAIMetadataTagger(tagging_chain=tagging_chain)
diff --git a/libs/community/langchain_community/document_transformers/xsl/html_chunks_with_headers.xslt b/libs/community/langchain_community/document_transformers/xsl/html_chunks_with_headers.xslt
new file mode 100644
index 00000000000..285edfe892d
--- /dev/null
+++ b/libs/community/langchain_community/document_transformers/xsl/html_chunks_with_headers.xslt
@@ -0,0 +1,199 @@
+<!-- html_chunks_with_headers.xslt: an XSLT stylesheet that chunks HTML by its
+     headers, treating elements such as div|p|blockquote|ol|ul as chunkable
+     content. The full stylesheet body is not reproduced here. -->
diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py
new file mode 100644
index 00000000000..ce9cfc7aa0b
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/__init__.py
@@ -0,0 +1,161 @@
+"""**Embedding models** are wrappers around embedding models
+from different APIs and services.
+
+**Embedding models** can be LLMs or not.
+
+**Class hierarchy:**
+
+.. code-block::
+
+    Embeddings --> <name>Embeddings  # Examples: OpenAIEmbeddings, HuggingFaceEmbeddings
+"""
+
+
+import logging
+from typing import Any
+
+from langchain_community.embeddings.aleph_alpha import (
+ AlephAlphaAsymmetricSemanticEmbedding,
+ AlephAlphaSymmetricSemanticEmbedding,
+)
+from langchain_community.embeddings.awa import AwaEmbeddings
+from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings
+from langchain_community.embeddings.baidu_qianfan_endpoint import (
+ QianfanEmbeddingsEndpoint,
+)
+from langchain_community.embeddings.bedrock import BedrockEmbeddings
+from langchain_community.embeddings.bookend import BookendEmbeddings
+from langchain_community.embeddings.clarifai import ClarifaiEmbeddings
+from langchain_community.embeddings.cohere import CohereEmbeddings
+from langchain_community.embeddings.dashscope import DashScopeEmbeddings
+from langchain_community.embeddings.databricks import DatabricksEmbeddings
+from langchain_community.embeddings.deepinfra import DeepInfraEmbeddings
+from langchain_community.embeddings.edenai import EdenAiEmbeddings
+from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings
+from langchain_community.embeddings.embaas import EmbaasEmbeddings
+from langchain_community.embeddings.ernie import ErnieEmbeddings
+from langchain_community.embeddings.fake import (
+ DeterministicFakeEmbedding,
+ FakeEmbeddings,
+)
+from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
+from langchain_community.embeddings.google_palm import GooglePalmEmbeddings
+from langchain_community.embeddings.gpt4all import GPT4AllEmbeddings
+from langchain_community.embeddings.gradient_ai import GradientEmbeddings
+from langchain_community.embeddings.huggingface import (
+ HuggingFaceBgeEmbeddings,
+ HuggingFaceEmbeddings,
+ HuggingFaceInferenceAPIEmbeddings,
+ HuggingFaceInstructEmbeddings,
+)
+from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
+from langchain_community.embeddings.infinity import InfinityEmbeddings
+from langchain_community.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
+from langchain_community.embeddings.jina import JinaEmbeddings
+from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
+from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
+from langchain_community.embeddings.localai import LocalAIEmbeddings
+from langchain_community.embeddings.minimax import MiniMaxEmbeddings
+from langchain_community.embeddings.mlflow import MlflowEmbeddings
+from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
+from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings
+from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings
+from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings
+from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings
+from langchain_community.embeddings.ollama import OllamaEmbeddings
+from langchain_community.embeddings.openai import OpenAIEmbeddings
+from langchain_community.embeddings.sagemaker_endpoint import (
+ SagemakerEndpointEmbeddings,
+)
+from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings
+from langchain_community.embeddings.self_hosted_hugging_face import (
+ SelfHostedHuggingFaceEmbeddings,
+ SelfHostedHuggingFaceInstructEmbeddings,
+)
+from langchain_community.embeddings.sentence_transformer import (
+ SentenceTransformerEmbeddings,
+)
+from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
+from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings
+from langchain_community.embeddings.vertexai import VertexAIEmbeddings
+from langchain_community.embeddings.voyageai import VoyageEmbeddings
+from langchain_community.embeddings.xinference import XinferenceEmbeddings
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+ "OpenAIEmbeddings",
+ "AzureOpenAIEmbeddings",
+ "ClarifaiEmbeddings",
+ "CohereEmbeddings",
+ "DatabricksEmbeddings",
+ "ElasticsearchEmbeddings",
+ "FastEmbedEmbeddings",
+ "HuggingFaceEmbeddings",
+ "HuggingFaceInferenceAPIEmbeddings",
+ "InfinityEmbeddings",
+ "GradientEmbeddings",
+ "JinaEmbeddings",
+ "LlamaCppEmbeddings",
+ "HuggingFaceHubEmbeddings",
+ "MlflowEmbeddings",
+ "MlflowAIGatewayEmbeddings",
+ "ModelScopeEmbeddings",
+ "TensorflowHubEmbeddings",
+ "SagemakerEndpointEmbeddings",
+ "HuggingFaceInstructEmbeddings",
+ "MosaicMLInstructorEmbeddings",
+ "SelfHostedEmbeddings",
+ "SelfHostedHuggingFaceEmbeddings",
+ "SelfHostedHuggingFaceInstructEmbeddings",
+ "FakeEmbeddings",
+ "DeterministicFakeEmbedding",
+ "AlephAlphaAsymmetricSemanticEmbedding",
+ "AlephAlphaSymmetricSemanticEmbedding",
+ "SentenceTransformerEmbeddings",
+ "GooglePalmEmbeddings",
+ "MiniMaxEmbeddings",
+ "VertexAIEmbeddings",
+ "BedrockEmbeddings",
+ "DeepInfraEmbeddings",
+ "EdenAiEmbeddings",
+ "DashScopeEmbeddings",
+ "EmbaasEmbeddings",
+ "OctoAIEmbeddings",
+ "SpacyEmbeddings",
+ "NLPCloudEmbeddings",
+ "GPT4AllEmbeddings",
+ "XinferenceEmbeddings",
+ "LocalAIEmbeddings",
+ "AwaEmbeddings",
+ "HuggingFaceBgeEmbeddings",
+ "ErnieEmbeddings",
+ "JavelinAIGatewayEmbeddings",
+ "OllamaEmbeddings",
+ "QianfanEmbeddingsEndpoint",
+ "JohnSnowLabsEmbeddings",
+ "VoyageEmbeddings",
+ "BookendEmbeddings",
+]
+
+
+# TODO: this is in here to maintain backwards compatibility
+class HypotheticalDocumentEmbedder:
+ def __init__(self, *args: Any, **kwargs: Any):
+ logger.warning(
+ "Using a deprecated class. Please use "
+ "`from langchain.chains import HypotheticalDocumentEmbedder` instead"
+ )
+ from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
+
+ return H(*args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
+ logger.warning(
+ "Using a deprecated class. Please use "
+ "`from langchain.chains import HypotheticalDocumentEmbedder` instead"
+ )
+ from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
+
+ return H.from_llm(*args, **kwargs)
diff --git a/libs/community/langchain_community/embeddings/aleph_alpha.py b/libs/community/langchain_community/embeddings/aleph_alpha.py
new file mode 100644
index 00000000000..41f970673fd
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/aleph_alpha.py
@@ -0,0 +1,255 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
+ """Aleph Alpha's asymmetric semantic embedding.
+
+ AA provides you with an endpoint to embed a document and a query.
+ The models were optimized to make the embeddings of documents and
+ the query for a document as similar as possible.
+ To learn more, check out: https://docs.aleph-alpha.com/docs/tasks/semantic_embed/
+
+ Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings import (
+                AlephAlphaAsymmetricSemanticEmbedding,
+            )
+
+ embeddings = AlephAlphaAsymmetricSemanticEmbedding(
+ normalize=True, compress_to_size=128
+ )
+
+ document = "This is a content of the document"
+ query = "What is the content of the document?"
+
+ doc_result = embeddings.embed_documents([document])
+ query_result = embeddings.embed_query(query)
+
+ """
+
+ client: Any #: :meta private:
+
+ # Embedding params
+ model: str = "luminous-base"
+ """Model name to use."""
+ compress_to_size: Optional[int] = None
+ """Should the returned embeddings come back as an original 5120-dim vector,
+ or should it be compressed to 128-dim."""
+ normalize: Optional[bool] = None
+ """Should returned embeddings be normalized"""
+ contextual_control_threshold: Optional[int] = None
+ """Attention control parameters only apply to those tokens that have
+ explicitly been set in the request."""
+ control_log_additive: bool = True
+ """Apply controls on prompt items by adding the log(control_factor)
+ to attention scores."""
+
+ # Client params
+ aleph_alpha_api_key: Optional[str] = None
+ """API key for Aleph Alpha API."""
+ host: str = "https://api.aleph-alpha.com"
+ """The hostname of the API host.
+    The default one is "https://api.aleph-alpha.com"."""
+ hosting: Optional[str] = None
+ """Determines in which datacenters the request may be processed.
+ You can either set the parameter to "aleph-alpha" or omit it (defaulting to None).
+ Not setting this value, or setting it to None, gives us maximal flexibility
+ in processing your request in our
+ own datacenters and on servers hosted with other providers.
+ Choose this option for maximal availability.
+ Setting it to "aleph-alpha" allows us to only process the request
+ in our own datacenters.
+ Choose this option for maximal data privacy."""
+ request_timeout_seconds: int = 305
+ """Client timeout that will be set for HTTP requests in the
+ `requests` library's API calls.
+ Server will close all requests after 300 seconds with an internal server error."""
+ total_retries: int = 8
+ """The number of retries made in case requests fail with certain retryable
+ status codes. If the last
+ retry fails a corresponding exception is raised. Note, that between retries
+ an exponential backoff
+ is applied, starting with 0.5 s after the first retry and doubling for each
+ retry made. So with the
+ default setting of 8 retries a total wait time of 63.5 s is added between
+ the retries."""
+ nice: bool = False
+ """Setting this to True, will signal to the API that you intend to be
+ nice to other users
+ by de-prioritizing your request below concurrent ones."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ aleph_alpha_api_key = get_from_dict_or_env(
+ values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY"
+ )
+ try:
+ from aleph_alpha_client import Client
+
+ values["client"] = Client(
+ token=aleph_alpha_api_key,
+ host=values["host"],
+ hosting=values["hosting"],
+ request_timeout_seconds=values["request_timeout_seconds"],
+ total_retries=values["total_retries"],
+ nice=values["nice"],
+ )
+ except ImportError:
+ raise ValueError(
+ "Could not import aleph_alpha_client python package. "
+ "Please install it with `pip install aleph_alpha_client`."
+ )
+
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Aleph Alpha's asymmetric Document endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ try:
+ from aleph_alpha_client import (
+ Prompt,
+ SemanticEmbeddingRequest,
+ SemanticRepresentation,
+ )
+ except ImportError:
+ raise ValueError(
+ "Could not import aleph_alpha_client python package. "
+ "Please install it with `pip install aleph_alpha_client`."
+ )
+ document_embeddings = []
+
+ for text in texts:
+ document_params = {
+ "prompt": Prompt.from_text(text),
+ "representation": SemanticRepresentation.Document,
+ "compress_to_size": self.compress_to_size,
+ "normalize": self.normalize,
+ "contextual_control_threshold": self.contextual_control_threshold,
+ "control_log_additive": self.control_log_additive,
+ }
+
+ document_request = SemanticEmbeddingRequest(**document_params)
+ document_response = self.client.semantic_embed(
+ request=document_request, model=self.model
+ )
+
+ document_embeddings.append(document_response.embedding)
+
+ return document_embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Aleph Alpha's asymmetric, query embedding endpoint
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ try:
+ from aleph_alpha_client import (
+ Prompt,
+ SemanticEmbeddingRequest,
+ SemanticRepresentation,
+ )
+ except ImportError:
+ raise ValueError(
+ "Could not import aleph_alpha_client python package. "
+ "Please install it with `pip install aleph_alpha_client`."
+ )
+ symmetric_params = {
+ "prompt": Prompt.from_text(text),
+ "representation": SemanticRepresentation.Query,
+ "compress_to_size": self.compress_to_size,
+ "normalize": self.normalize,
+ "contextual_control_threshold": self.contextual_control_threshold,
+ "control_log_additive": self.control_log_additive,
+ }
+
+ symmetric_request = SemanticEmbeddingRequest(**symmetric_params)
+ symmetric_response = self.client.semantic_embed(
+ request=symmetric_request, model=self.model
+ )
+
+ return symmetric_response.embedding
+
+
+class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding):
+ """The symmetric version of the Aleph Alpha's semantic embeddings.
+
+ The main difference is that here, both the documents and
+ queries are embedded with a SemanticRepresentation.Symmetric
+ Example:
+ .. code-block:: python
+
+            from langchain_community.embeddings import (
+                AlephAlphaSymmetricSemanticEmbedding,
+            )
+
+            embeddings = AlephAlphaSymmetricSemanticEmbedding(
+ normalize=True, compress_to_size=128
+ )
+ text = "This is a test text"
+
+ doc_result = embeddings.embed_documents([text])
+ query_result = embeddings.embed_query(text)
+ """
+
+ def _embed(self, text: str) -> List[float]:
+ try:
+ from aleph_alpha_client import (
+ Prompt,
+ SemanticEmbeddingRequest,
+ SemanticRepresentation,
+ )
+ except ImportError:
+ raise ValueError(
+ "Could not import aleph_alpha_client python package. "
+ "Please install it with `pip install aleph_alpha_client`."
+ )
+ query_params = {
+ "prompt": Prompt.from_text(text),
+ "representation": SemanticRepresentation.Symmetric,
+ "compress_to_size": self.compress_to_size,
+ "normalize": self.normalize,
+ "contextual_control_threshold": self.contextual_control_threshold,
+ "control_log_additive": self.control_log_additive,
+ }
+
+ query_request = SemanticEmbeddingRequest(**query_params)
+ query_response = self.client.semantic_embed(
+ request=query_request, model=self.model
+ )
+
+ return query_response.embedding
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Aleph Alpha's Document endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ document_embeddings = []
+
+ for text in texts:
+ document_embeddings.append(self._embed(text))
+ return document_embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Aleph Alpha's asymmetric, query embedding endpoint
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self._embed(text)
diff --git a/libs/community/langchain_community/embeddings/awa.py b/libs/community/langchain_community/embeddings/awa.py
new file mode 100644
index 00000000000..9145d8006a7
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/awa.py
@@ -0,0 +1,63 @@
+from typing import Any, Dict, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+
+class AwaEmbeddings(BaseModel, Embeddings):
+ """Embedding documents and queries with Awa DB.
+
+ Attributes:
+ client: The AwaEmbedding client.
+ model: The name of the model used for embedding.
+ Default is "all-mpnet-base-v2".
+ """
+
+ client: Any #: :meta private:
+ model: str = "all-mpnet-base-v2"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that awadb library is installed."""
+
+ try:
+ from awadb import AwaEmbedding
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import awadb library. "
+ "Please install it with `pip install awadb`"
+ ) from exc
+ values["client"] = AwaEmbedding()
+ return values
+
+ def set_model(self, model_name: str) -> None:
+ """Set the model used for embedding.
+ The default model used is all-mpnet-base-v2
+
+ Args:
+ model_name: A string which represents the name of model.
+ """
+ self.model = model_name
+ self.client.model_name = model_name
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using AwaEmbedding.
+
+ Args:
+            texts: The list of texts to be embedded.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return self.client.EmbeddingBatch(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using AwaEmbedding.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.client.Embedding(text)
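+
+
+# Minimal usage sketch for the embeddings class above, assuming the `awadb`
+# package is installed locally.
+if __name__ == "__main__":
+    embedder = AwaEmbeddings()
+    doc_vectors = embedder.embed_documents(["hello world", "goodbye world"])
+    query_vector = embedder.embed_query("hello world")
+    print(len(doc_vectors), len(query_vector))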
diff --git a/libs/community/langchain_community/embeddings/azure_openai.py b/libs/community/langchain_community/embeddings/azure_openai.py
new file mode 100644
index 00000000000..409a738de87
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/azure_openai.py
@@ -0,0 +1,156 @@
+"""Azure OpenAI embeddings wrapper."""
+from __future__ import annotations
+
+import os
+import warnings
+from typing import Dict, Optional, Union
+
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.embeddings.openai import OpenAIEmbeddings
+from langchain_community.utils.openai import is_openai_v1
+
+
+class AzureOpenAIEmbeddings(OpenAIEmbeddings):
+ """`Azure OpenAI` Embeddings API."""
+
+ azure_endpoint: Union[str, None] = None
+ """Your Azure endpoint, including the resource.
+
+ Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
+
+ Example: `https://example-resource.azure.openai.com/`
+ """
+ deployment: Optional[str] = Field(default=None, alias="azure_deployment")
+ """A model deployment.
+
+ If given sets the base client URL to include `/deployments/{azure_deployment}`.
+ Note: this means you won't be able to use non-deployment endpoints.
+ """
+ openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
+ azure_ad_token: Union[str, None] = None
+ """Your Azure Active Directory token.
+
+ Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
+
+ For more:
+ https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
+ """ # noqa: E501
+ azure_ad_token_provider: Union[str, None] = None
+ """A function that returns an Azure Active Directory token.
+
+ Will be invoked on every request.
+ """
+ openai_api_version: Optional[str] = Field(default=None, alias="api_version")
+ """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
+ validate_base_url: bool = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ # Check OPENAI_KEY for backwards compatibility.
+ # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
+ # other forms of azure credentials.
+ values["openai_api_key"] = (
+ values["openai_api_key"]
+ or os.getenv("AZURE_OPENAI_API_KEY")
+ or os.getenv("OPENAI_API_KEY")
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_api_version"] = values["openai_api_version"] or os.getenv(
+ "OPENAI_API_VERSION", default="2023-05-15"
+ )
+ values["openai_api_type"] = get_from_dict_or_env(
+ values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
+ )
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
+ "AZURE_OPENAI_ENDPOINT"
+ )
+ values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
+ "AZURE_OPENAI_AD_TOKEN"
+ )
+ # Azure OpenAI embedding models allow a maximum of 16 texts
+ # at a time in each batch
+ # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
+ values["chunk_size"] = min(values["chunk_size"], 16)
+ try:
+ import openai
+
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ if is_openai_v1():
+ # For backwards compatibility. Before openai v1, no distinction was made
+ # between azure_endpoint and base_url (openai_api_base).
+ openai_api_base = values["openai_api_base"]
+ if openai_api_base and values["validate_base_url"]:
+ if "/openai" not in openai_api_base:
+ values["openai_api_base"] += "/openai"
+ warnings.warn(
+ "As of openai>=1.0.0, Azure endpoints should be specified via "
+ f"the `azure_endpoint` param not `openai_api_base` "
+ f"(or alias `base_url`). Updating `openai_api_base` from "
+ f"{openai_api_base} to {values['openai_api_base']}."
+ )
+ if values["deployment"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `deployment` (or alias "
+ "`azure_deployment`) is specified then "
+ "`openai_api_base` (or alias `base_url`) should not be. "
+ "Instead use `deployment` (or alias `azure_deployment`) "
+ "and `azure_endpoint`."
+ )
+ if values["deployment"] not in values["openai_api_base"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `openai_api_base` "
+ "(or alias `base_url`) is specified it is expected to be "
+ "of the form "
+ "https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
+ f"Updating {openai_api_base} to "
+ f"{values['openai_api_base']}."
+ )
+ values["openai_api_base"] += (
+ "/deployments/" + values["deployment"]
+ )
+ values["deployment"] = None
+ client_params = {
+ "api_version": values["openai_api_version"],
+ "azure_endpoint": values["azure_endpoint"],
+ "azure_deployment": values["deployment"],
+ "api_key": values["openai_api_key"],
+ "azure_ad_token": values["azure_ad_token"],
+ "azure_ad_token_provider": values["azure_ad_token_provider"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+ values["client"] = openai.AzureOpenAI(**client_params).embeddings
+ values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings
+ else:
+ values["client"] = openai.Embedding
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ return "azure-openai-chat"
diff --git a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py
new file mode 100644
index 00000000000..01c440ab251
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py
@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
+ """`Baidu Qianfan Embeddings` embedding models."""
+
+ qianfan_ak: Optional[str] = None
+ """Qianfan application apikey"""
+
+ qianfan_sk: Optional[str] = None
+ """Qianfan application secretkey"""
+
+ chunk_size: int = 16
+ """Chunk size when multiple texts are input"""
+
+ model: str = "Embedding-V1"
+ """Model name
+ you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
+
+ for now, we support Embedding-V1 and
+ - Embedding-V1 οΌι»θ€ζ¨‘εοΌ
+ - bge-large-en
+ - bge-large-zh
+
+ preset models are mapping to an endpoint.
+ `model` will be ignored if `endpoint` is set
+ """
+
+ endpoint: str = ""
+ """Endpoint of the Qianfan Embedding, required if custom model used."""
+
+ client: Any
+ """Qianfan client"""
+
+ max_retries: int = 5
+ """Max reties times"""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """
+        Validate that qianfan_ak and qianfan_sk are available, either from the
+        environment variables or from the configuration file, and initialize the
+        qianfan embedding client with `ak`, `sk`, `model`, and `endpoint`.
+
+        Args:
+
+            values: a dictionary containing configuration information; it must
+                include the qianfan_ak and qianfan_sk fields.
+        Returns:
+
+            a dictionary containing the configuration information, updated with the
+            resolved qianfan_ak and qianfan_sk and the initialized client.
+        Raises:
+
+            ImportError: qianfan package not found, please install it with
+            `pip install qianfan`
+ """
+ values["qianfan_ak"] = get_from_dict_or_env(
+ values,
+ "qianfan_ak",
+ "QIANFAN_AK",
+ )
+ values["qianfan_sk"] = get_from_dict_or_env(
+ values,
+ "qianfan_sk",
+ "QIANFAN_SK",
+ )
+
+ try:
+ import qianfan
+
+ params = {
+ "ak": values["qianfan_ak"],
+ "sk": values["qianfan_sk"],
+ "model": values["model"],
+ }
+ if values["endpoint"] is not None and values["endpoint"] != "":
+ params["endpoint"] = values["endpoint"]
+ values["client"] = qianfan.Embedding(**params)
+ except ImportError:
+ raise ImportError(
+ "qianfan package not found, please install it with "
+ "`pip install qianfan`"
+ )
+ return values
+
+ def embed_query(self, text: str) -> List[float]:
+ resp = self.embed_documents([text])
+ return resp[0]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """
+        Embeds a list of text documents using the Qianfan Embeddings endpoint.
+
+ Args:
+ texts (List[str]): A list of text documents to embed.
+
+ Returns:
+ List[List[float]]: A list of embeddings for each document in the input list.
+ Each embedding is represented as a list of float values.
+ """
+ text_in_chunks = [
+ texts[i : i + self.chunk_size]
+ for i in range(0, len(texts), self.chunk_size)
+ ]
+ lst = []
+ for chunk in text_in_chunks:
+ resp = self.client.do(texts=chunk)
+ lst.extend([res["embedding"] for res in resp["data"]])
+ return lst
+
+ async def aembed_query(self, text: str) -> List[float]:
+ embeddings = await self.aembed_documents([text])
+ return embeddings[0]
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ text_in_chunks = [
+ texts[i : i + self.chunk_size]
+ for i in range(0, len(texts), self.chunk_size)
+ ]
+ lst = []
+ for chunk in text_in_chunks:
+ resp = await self.client.ado(texts=chunk)
+ for res in resp["data"]:
+ lst.extend([res["embedding"]])
+ return lst
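+
+
+# Minimal usage sketch for the class above, assuming QIANFAN_AK and QIANFAN_SK
+# are set in the environment and the `qianfan` package is installed.
+if __name__ == "__main__":
+    embedder = QianfanEmbeddingsEndpoint(model="Embedding-V1")
+    vectors = embedder.embed_documents(["hello", "world"])
+    print(len(vectors), len(vectors[0]))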
diff --git a/libs/community/langchain_community/embeddings/bedrock.py b/libs/community/langchain_community/embeddings/bedrock.py
new file mode 100644
index 00000000000..8e98bfe2858
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/bedrock.py
@@ -0,0 +1,200 @@
+import asyncio
+import json
+import os
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class BedrockEmbeddings(BaseModel, Embeddings):
+ """Bedrock embedding models.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+    access the Bedrock service.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings import BedrockEmbeddings
+
+ region_name ="us-east-1"
+ credentials_profile_name = "default"
+ model_id = "amazon.titan-embed-text-v1"
+
+ be = BedrockEmbeddings(
+ credentials_profile_name=credentials_profile_name,
+ region_name=region_name,
+ model_id=model_id
+ )
+ """
+
+ client: Any #: :meta private:
+ """Bedrock client."""
+ region_name: Optional[str] = None
+ """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
+ or region specified in ~/.aws/config in case it is not provided here.
+ """
+
+ credentials_profile_name: Optional[str] = None
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+ has either access keys or role information specified.
+ If not specified, the default credential profile or, if on an EC2 instance,
+ credentials from IMDS will be used.
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ """
+
+ model_id: str = "amazon.titan-embed-text-v1"
+ """Id of the model to call, e.g., amazon.titan-embed-text-v1, this is
+ equivalent to the modelId property in the list-foundation-models api"""
+
+ model_kwargs: Optional[Dict] = None
+ """Keyword arguments to pass to the model."""
+
+ endpoint_url: Optional[str] = None
+ """Needed if you don't want to default to us-east-1 endpoint"""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that AWS credentials to and python package exists in environment."""
+
+ if values["client"] is not None:
+ return values
+
+ try:
+ import boto3
+
+ if values["credentials_profile_name"] is not None:
+ session = boto3.Session(profile_name=values["credentials_profile_name"])
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {}
+ if values["region_name"]:
+ client_params["region_name"] = values["region_name"]
+
+ if values["endpoint_url"]:
+ client_params["endpoint_url"] = values["endpoint_url"]
+
+ values["client"] = session.client("bedrock-runtime", **client_params)
+
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ return values
+
+ def _embedding_func(self, text: str) -> List[float]:
+ """Call out to Bedrock embedding endpoint."""
+ # replace newlines, which can negatively affect performance.
+ text = text.replace(os.linesep, " ")
+
+ # format input body for provider
+ provider = self.model_id.split(".")[0]
+ _model_kwargs = self.model_kwargs or {}
+ input_body = {**_model_kwargs}
+ if provider == "cohere":
+ if "input_type" not in input_body.keys():
+ input_body["input_type"] = "search_document"
+ input_body["texts"] = [text]
+ else:
+ # includes common provider == "amazon"
+ input_body["inputText"] = text
+ body = json.dumps(input_body)
+
+ try:
+ # invoke bedrock API
+ response = self.client.invoke_model(
+ body=body,
+ modelId=self.model_id,
+ accept="application/json",
+ contentType="application/json",
+ )
+
+ # format output based on provider
+ response_body = json.loads(response.get("body").read())
+ if provider == "cohere":
+ return response_body.get("embeddings")[0]
+ else:
+ # includes common provider == "amazon"
+ return response_body.get("embedding")
+ except Exception as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a Bedrock model.
+
+ Args:
+ texts: The list of texts to embed
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ results = []
+ for text in texts:
+ response = self._embedding_func(text)
+ results.append(response)
+ return results
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a Bedrock model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self._embedding_func(text)
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Asynchronous compute query embeddings using a Bedrock model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+
+ return await asyncio.get_running_loop().run_in_executor(
+ None, partial(self.embed_query, text)
+ )
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Asynchronous compute doc embeddings using a Bedrock model.
+
+ Args:
+ texts: The list of texts to embed
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
+
+ return list(result)
diff --git a/libs/community/langchain_community/embeddings/bookend.py b/libs/community/langchain_community/embeddings/bookend.py
new file mode 100644
index 00000000000..f8388712d7c
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/bookend.py
@@ -0,0 +1,90 @@
+"""Wrapper around Bookend AI embedding models."""
+
+import json
+from typing import Any, List
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+API_URL = "https://api.bookend.ai/"
+DEFAULT_TASK = "embeddings"
+PATH = "/models/predict"
+
+
+class BookendEmbeddings(BaseModel, Embeddings):
+ """Bookend AI sentence_transformers embedding models.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import BookendEmbeddings
+
+ bookend = BookendEmbeddings(
+ domain={domain}
+ api_token={api_token}
+ model_id={model_id}
+ )
+ bookend.embed_documents([
+ "Please put on these earmuffs because I can't you hear.",
+ "Baby wipes are made of chocolate stardust.",
+ ])
+ bookend.embed_query(
+ "She only paints with bold colors; she does not like pastels."
+ )
+ """
+
+ domain: str
+ """Request for a domain at https://bookend.ai/ to use this embeddings module."""
+ api_token: str
+ """Request for an API token at https://bookend.ai/ to use this embeddings module."""
+ model_id: str
+ """Embeddings model ID to use."""
+ auth_header: dict = Field(default_factory=dict)
+
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
+ self.auth_header = {"Authorization": "Basic {}".format(self.api_token)}
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed documents using a Bookend deployed embeddings model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ result = []
+ headers = self.auth_header
+ headers["Content-Type"] = "application/json; charset=utf-8"
+ params = {
+ "model_id": self.model_id,
+ "task": DEFAULT_TASK,
+ }
+
+ for text in texts:
+ data = json.dumps(
+ {"text": text, "question": None, "context": None, "instruction": None}
+ )
+ r = requests.request(
+ "POST",
+ API_URL + self.domain + PATH,
+ headers=headers,
+ params=params,
+ data=data,
+ )
+ result.append(r.json()[0]["data"])
+
+ return result
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using a Bookend deployed embeddings model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/clarifai.py b/libs/community/langchain_community/embeddings/clarifai.py
new file mode 100644
index 00000000000..66f8116321c
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/clarifai.py
@@ -0,0 +1,163 @@
+import logging
+from typing import Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class ClarifaiEmbeddings(BaseModel, Embeddings):
+ """Clarifai embedding models.
+
+ To use, you should have the ``clarifai`` python package installed, and the
+ environment variable ``CLARIFAI_PAT`` set with your personal access token or pass it
+ as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import ClarifaiEmbeddings
+ clarifai = ClarifaiEmbeddings(user_id=USER_ID,
+ app_id=APP_ID,
+ model_id=MODEL_ID)
+ (or)
+             clarifai = ClarifaiEmbeddings(model_url=EXAMPLE_URL)
+ """
+
+ model_url: Optional[str] = None
+ """Model url to use."""
+ model_id: Optional[str] = None
+ """Model id to use."""
+ model_version_id: Optional[str] = None
+ """Model version id to use."""
+ app_id: Optional[str] = None
+ """Clarifai application id to use."""
+ user_id: Optional[str] = None
+ """Clarifai user id to use."""
+ pat: Optional[str] = None
+ """Clarifai personal access token to use."""
+ api_base: str = "https://api.clarifai.com"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that we have all required info to access Clarifai
+ platform and python package exists in environment."""
+
+ values["pat"] = get_from_dict_or_env(values, "pat", "CLARIFAI_PAT")
+ user_id = values.get("user_id")
+ app_id = values.get("app_id")
+ model_id = values.get("model_id")
+ model_url = values.get("model_url")
+
+ if model_url is not None and model_id is not None:
+ raise ValueError("Please provide either model_url or model_id, not both.")
+
+ if model_url is None and model_id is None:
+ raise ValueError("Please provide one of model_url or model_id.")
+
+ if model_url is None and model_id is not None:
+ if user_id is None or app_id is None:
+ raise ValueError("Please provide a user_id and app_id.")
+
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Clarifai's embedding models.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ try:
+ from clarifai.client.input import Inputs
+ from clarifai.client.model import Model
+ except ImportError:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ )
+ if self.pat is not None:
+ pat = self.pat
+ if self.model_url is not None:
+ _model_init = Model(url=self.model_url, pat=pat)
+ else:
+ _model_init = Model(
+ model_id=self.model_id,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ pat=pat,
+ )
+
+ input_obj = Inputs(pat=pat)
+ batch_size = 32
+ embeddings = []
+
+ try:
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i : i + batch_size]
+ input_batch = [
+ input_obj.get_text_input(input_id=str(id), raw_text=inp)
+ for id, inp in enumerate(batch)
+ ]
+ predict_response = _model_init.predict(input_batch)
+ embeddings.extend(
+ [
+ list(output.data.embeddings[0].vector)
+ for output in predict_response.outputs
+ ]
+ )
+
+ except Exception as e:
+ logger.error(f"Predict failed, exception: {e}")
+
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Clarifai's embedding models.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ try:
+ from clarifai.client.model import Model
+ except ImportError:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ )
+ if self.pat is not None:
+ pat = self.pat
+ if self.model_url is not None:
+ _model_init = Model(url=self.model_url, pat=pat)
+ else:
+ _model_init = Model(
+ model_id=self.model_id,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ pat=pat,
+ )
+
+ try:
+ predict_response = _model_init.predict_by_bytes(
+ bytes(text, "utf-8"), input_type="text"
+ )
+ embeddings = [
+ list(op.data.embeddings[0].vector) for op in predict_response.outputs
+ ]
+
+ except Exception as e:
+ logger.error(f"Predict failed, exception: {e}")
+
+ return embeddings[0]
diff --git a/libs/community/langchain_community/embeddings/cloudflare_workersai.py b/libs/community/langchain_community/embeddings/cloudflare_workersai.py
new file mode 100644
index 00000000000..81f6f83d4a5
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/cloudflare_workersai.py
@@ -0,0 +1,94 @@
+from typing import Any, Dict, List
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+DEFAULT_MODEL_NAME = "@cf/baai/bge-base-en-v1.5"
+
+
+class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings):
+ """Cloudflare Workers AI embedding model.
+
+ To use, you need to provide an API token and
+ account ID to access Cloudflare Workers AI.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import CloudflareWorkersAIEmbeddings
+
+ account_id = "my_account_id"
+ api_token = "my_secret_api_token"
+ model_name = "@cf/baai/bge-small-en-v1.5"
+
+ cf = CloudflareWorkersAIEmbeddings(
+ account_id=account_id,
+ api_token=api_token,
+ model_name=model_name
+ )
+ """
+
+ api_base_url: str = "https://api.cloudflare.com/client/v4/accounts"
+ account_id: str
+ api_token: str
+ model_name: str = DEFAULT_MODEL_NAME
+ batch_size: int = 50
+ strip_new_lines: bool = True
+ headers: Dict[str, str] = {"Authorization": "Bearer "}
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the Cloudflare Workers AI client."""
+ super().__init__(**kwargs)
+
+ self.headers = {"Authorization": f"Bearer {self.api_token}"}
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using Cloudflare Workers AI.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ if self.strip_new_lines:
+ texts = [text.replace("\n", " ") for text in texts]
+
+ batches = [
+ texts[i : i + self.batch_size]
+ for i in range(0, len(texts), self.batch_size)
+ ]
+ embeddings = []
+
+ for batch in batches:
+ response = requests.post(
+ f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}",
+ headers=self.headers,
+ json={"text": batch},
+ )
+ embeddings.extend(response.json()["result"]["data"])
+
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using Cloudflare Workers AI.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ text = text.replace("\n", " ") if self.strip_new_lines else text
+ response = requests.post(
+ f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}",
+ headers=self.headers,
+ json={"text": [text]},
+ )
+ return response.json()["result"]["data"][0]
diff --git a/libs/community/langchain_community/embeddings/cohere.py b/libs/community/langchain_community/embeddings/cohere.py
new file mode 100644
index 00000000000..fd95b58ea6c
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/cohere.py
@@ -0,0 +1,145 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class CohereEmbeddings(BaseModel, Embeddings):
+ """Cohere embedding models.
+
+ To use, you should have the ``cohere`` python package installed, and the
+ environment variable ``COHERE_API_KEY`` set with your API key or pass it
+ as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import CohereEmbeddings
+ cohere = CohereEmbeddings(
+ model="embed-english-light-v3.0",
+ cohere_api_key="my-api-key"
+ )
+ """
+
+ client: Any #: :meta private:
+ """Cohere client."""
+ async_client: Any #: :meta private:
+ """Cohere async client."""
+ model: str = "embed-english-v2.0"
+ """Model name to use."""
+
+ truncate: Optional[str] = None
+ """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
+
+ cohere_api_key: Optional[str] = None
+
+ max_retries: Optional[int] = None
+ """Maximum number of retries to make when generating."""
+ request_timeout: Optional[float] = None
+ """Timeout in seconds for the Cohere API request."""
+ user_agent: str = "langchain"
+ """Identifier for the application making the request."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ cohere_api_key = get_from_dict_or_env(
+ values, "cohere_api_key", "COHERE_API_KEY"
+ )
+ max_retries = values.get("max_retries")
+ request_timeout = values.get("request_timeout")
+
+ try:
+ import cohere
+
+ client_name = values["user_agent"]
+ values["client"] = cohere.Client(
+ cohere_api_key,
+ max_retries=max_retries,
+ timeout=request_timeout,
+ client_name=client_name,
+ )
+ values["async_client"] = cohere.AsyncClient(
+ cohere_api_key,
+ max_retries=max_retries,
+ timeout=request_timeout,
+ client_name=client_name,
+ )
+        except ImportError:
+            raise ImportError(
+                "Could not import cohere python package. "
+                "Please install it with `pip install cohere`."
+            )
+ return values
+
+ def embed(
+ self, texts: List[str], *, input_type: Optional[str] = None
+ ) -> List[List[float]]:
+ embeddings = self.client.embed(
+ model=self.model,
+ texts=texts,
+ input_type=input_type,
+ truncate=self.truncate,
+ ).embeddings
+ return [list(map(float, e)) for e in embeddings]
+
+ async def aembed(
+ self, texts: List[str], *, input_type: Optional[str] = None
+ ) -> List[List[float]]:
+        response = await self.async_client.embed(
+            model=self.model,
+            texts=texts,
+            input_type=input_type,
+            truncate=self.truncate,
+        )
+        return [list(map(float, e)) for e in response.embeddings]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of document texts.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return self.embed(texts, input_type="search_document")
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Async call out to Cohere's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return await self.aembed(texts, input_type="search_document")
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Cohere's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed([text], input_type="search_query")[0]
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Async call out to Cohere's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return (await self.aembed([text], input_type="search_query"))[0]
diff --git a/libs/community/langchain_community/embeddings/dashscope.py b/libs/community/langchain_community/embeddings/dashscope.py
new file mode 100644
index 00000000000..9b1b7db8874
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/dashscope.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+)
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from requests.exceptions import HTTPError
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(embeddings: DashScopeEmbeddings) -> Callable[[Any], Any]:
+ multiplier = 1
+ min_seconds = 1
+ max_seconds = 4
+    # Wait 2^x * 1 second between retries, starting with
+    # 1 second and capping at 4 seconds thereafter
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(multiplier, min=min_seconds, max=max_seconds),
+ retry=(retry_if_exception_type(HTTPError)),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
+ retry_decorator = _create_retry_decorator(embeddings)
+
+ @retry_decorator
+ def _embed_with_retry(**kwargs: Any) -> Any:
+ resp = embeddings.client.call(**kwargs)
+ if resp.status_code == 200:
+ return resp.output["embeddings"]
+ elif resp.status_code in [400, 401]:
+ raise ValueError(
+ f"status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}"
+ )
+ else:
+ raise HTTPError(
+ f"HTTP error occurred: status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}",
+ response=resp,
+ )
+
+ return _embed_with_retry(**kwargs)
+
+
+class DashScopeEmbeddings(BaseModel, Embeddings):
+ """DashScope embedding models.
+
+ To use, you should have the ``dashscope`` python package installed, and the
+ environment variable ``DASHSCOPE_API_KEY`` set with your API key or pass it
+ as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import DashScopeEmbeddings
+ embeddings = DashScopeEmbeddings(dashscope_api_key="my-api-key")
+
+ Example:
+ .. code-block:: python
+
+ import os
+ os.environ["DASHSCOPE_API_KEY"] = "your DashScope API KEY"
+
+ from langchain_community.embeddings.dashscope import DashScopeEmbeddings
+ embeddings = DashScopeEmbeddings(
+ model="text-embedding-v1",
+ )
+ text = "This is a test query."
+ query_result = embeddings.embed_query(text)
+
+ """
+
+ client: Any #: :meta private:
+ """The DashScope client."""
+ model: str = "text-embedding-v1"
+ dashscope_api_key: Optional[str] = None
+ max_retries: int = 5
+ """Maximum number of retries to make when generating."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["dashscope_api_key"] = get_from_dict_or_env(
+            values, "dashscope_api_key", "DASHSCOPE_API_KEY"
+        )
+        try:
+            import dashscope
+
+            dashscope.api_key = values["dashscope_api_key"]
+            values["client"] = dashscope.TextEmbedding
+        except ImportError:
+            raise ImportError(
+                "Could not import dashscope python package. "
+                "Please install it with `pip install dashscope`."
+            )
+        return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to DashScope's embedding endpoint for embedding search docs.
+
+        Args:
+            texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = embed_with_retry(
+ self, input=texts, text_type="document", model=self.model
+ )
+ embedding_list = [item["embedding"] for item in embeddings]
+ return embedding_list
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to DashScope's embedding endpoint for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ embedding = embed_with_retry(
+ self, input=text, text_type="query", model=self.model
+ )[0]["embedding"]
+ return embedding
diff --git a/libs/community/langchain_community/embeddings/databricks.py b/libs/community/langchain_community/embeddings/databricks.py
new file mode 100644
index 00000000000..91948dc96fc
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/databricks.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import Iterator, List
+from urllib.parse import urlparse
+
+from langchain_community.embeddings.mlflow import MlflowEmbeddings
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+ for i in range(0, len(texts), size):
+ yield texts[i : i + size]
+
+
+class DatabricksEmbeddings(MlflowEmbeddings):
+    """Wrapper around embedding models served on Databricks.
+
+ To use, you should have the ``mlflow`` python package installed.
+ For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import DatabricksEmbeddings
+
+ embeddings = DatabricksEmbeddings(
+ target_uri="databricks",
+ endpoint="embeddings",
+ )
+ """
+
+ target_uri: str = "databricks"
+ """The target URI to use. Defaults to ``databricks``."""
+
+ @property
+ def _mlflow_extras(self) -> str:
+ return ""
+
+ def _validate_uri(self) -> None:
+ if self.target_uri == "databricks":
+ return
+
+ if urlparse(self.target_uri).scheme != "databricks":
+ raise ValueError(
+ "Invalid target URI. The target URI must be a valid databricks URI."
+ )
diff --git a/libs/community/langchain_community/embeddings/deepinfra.py b/libs/community/langchain_community/embeddings/deepinfra.py
new file mode 100644
index 00000000000..046e1e481f1
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/deepinfra.py
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
+
+
+class DeepInfraEmbeddings(BaseModel, Embeddings):
+ """Deep Infra's embedding inference service.
+
+ To use, you should have the
+ environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+ There are multiple embeddings models available,
+ see https://deepinfra.com/models?type=embeddings.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import DeepInfraEmbeddings
+ deepinfra_emb = DeepInfraEmbeddings(
+ model_id="sentence-transformers/clip-ViT-B-32",
+ deepinfra_api_token="my-api-key"
+ )
+ r1 = deepinfra_emb.embed_documents(
+ [
+ "Alpha is the first letter of Greek alphabet",
+ "Beta is the second letter of Greek alphabet",
+ ]
+ )
+ r2 = deepinfra_emb.embed_query(
+ "What is the second letter of Greek alphabet"
+ )
+
+ """
+
+ model_id: str = DEFAULT_MODEL_ID
+ """Embeddings model to use."""
+ normalize: bool = False
+ """whether to normalize the computed embeddings"""
+ embed_instruction: str = "passage: "
+ """Instruction used to embed documents."""
+ query_instruction: str = "query: "
+ """Instruction used to embed the query."""
+ model_kwargs: Optional[dict] = None
+ """Other model keyword args"""
+
+ deepinfra_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ deepinfra_api_token = get_from_dict_or_env(
+ values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
+ )
+ values["deepinfra_api_token"] = deepinfra_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"model_id": self.model_id}
+
+ def _embed(self, input: List[str]) -> List[List[float]]:
+ _model_kwargs = self.model_kwargs or {}
+ # HTTP headers for authorization
+ headers = {
+ "Authorization": f"bearer {self.deepinfra_api_token}",
+ "Content-Type": "application/json",
+ }
+ # send request
+ try:
+ res = requests.post(
+ f"https://api.deepinfra.com/v1/inference/{self.model_id}",
+ headers=headers,
+ json={"inputs": input, "normalize": self.normalize, **_model_kwargs},
+ )
+ except requests.exceptions.RequestException as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ if res.status_code != 200:
+ raise ValueError(
+ "Error raised by inference API HTTP code: %s, %s"
+ % (res.status_code, res.text)
+ )
+ try:
+ t = res.json()
+ embeddings = t["embeddings"]
+ except requests.exceptions.JSONDecodeError as e:
+ raise ValueError(
+ f"Error raised by inference API: {e}.\nResponse: {res.text}"
+ )
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed documents using a Deep Infra deployed embedding model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts]
+ embeddings = self._embed(instruction_pairs)
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using a Deep Infra deployed embedding model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ instruction_pair = f"{self.query_instruction}{text}"
+ embedding = self._embed([instruction_pair])[0]
+ return embedding
diff --git a/libs/community/langchain_community/embeddings/edenai.py b/libs/community/langchain_community/embeddings/edenai.py
new file mode 100644
index 00000000000..5cb92e63007
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/edenai.py
@@ -0,0 +1,110 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.utilities.requests import Requests
+
+
+class EdenAiEmbeddings(BaseModel, Embeddings):
+    """EdenAI embedding models.
+
+    To use, you should have the environment variable ``EDENAI_API_KEY`` set
+    with your API key, or pass it as a named parameter.
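+
+    A minimal usage sketch (the provider and API key below are placeholders):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings import EdenAiEmbeddings
+
+        embeddings = EdenAiEmbeddings(
+            provider="openai",
+            edenai_api_key="my-api-key",
+        )
+        vectors = embeddings.embed_documents(["hello", "world"])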
+ """
+
+ edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token")
+
+ provider: str = "openai"
+ """embedding provider to use (eg: openai,google etc.)"""
+
+ model: Optional[str] = None
+    """
+    Model name for the above provider (e.g. 'text-davinci-003' for openai).
+    Available models are shown on https://docs.edenai.co/ under 'available providers'.
+    """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["edenai_api_key"] = get_from_dict_or_env(
+ values, "edenai_api_key", "EDENAI_API_KEY"
+ )
+ return values
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain/{__version__}"
+
+ def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+ """Compute embeddings using EdenAi api."""
+ url = "https://api.edenai.run/v2/text/embeddings"
+
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Bearer {self.edenai_api_key}",
+ "User-Agent": self.get_user_agent(),
+ }
+
+ payload: Dict[str, Any] = {"texts": texts, "providers": self.provider}
+
+ if self.model is not None:
+ payload["settings"] = {self.provider: self.model}
+
+ request = Requests(headers=headers)
+ response = request.post(url=url, data=payload)
+ if response.status_code >= 500:
+ raise Exception(f"EdenAI Server: Error {response.status_code}")
+ elif response.status_code >= 400:
+ raise ValueError(f"EdenAI received an invalid payload: {response.text}")
+ elif response.status_code != 200:
+ raise Exception(
+ f"EdenAI returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+
+ temp = response.json()
+
+ provider_response = temp[self.provider]
+ if provider_response.get("status") == "fail":
+ err_msg = provider_response.get("error", {}).get("message")
+ raise Exception(err_msg)
+
+        embeddings = [
+            embed_item["embedding"] for embed_item in provider_response["items"]
+        ]
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using EdenAI.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ return self._generate_embeddings(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using EdenAI.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self._generate_embeddings([text])[0]
diff --git a/libs/community/langchain_community/embeddings/elasticsearch.py b/libs/community/langchain_community/embeddings/elasticsearch.py
new file mode 100644
index 00000000000..9145a948ad7
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/elasticsearch.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain_core.utils import get_from_env
+
+if TYPE_CHECKING:
+ from elasticsearch import Elasticsearch
+ from elasticsearch.client import MlClient
+
+from langchain_core.embeddings import Embeddings
+
+
+class ElasticsearchEmbeddings(Embeddings):
+ """Elasticsearch embedding models.
+
+ This class provides an interface to generate embeddings using a model deployed
+ in an Elasticsearch cluster. It requires an Elasticsearch connection object
+ and the model_id of the model deployed in the cluster.
+
+ In Elasticsearch you need to have an embedding model loaded and deployed.
+ - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
+ - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ client: MlClient,
+ model_id: str,
+ *,
+ input_field: str = "text_field",
+ ):
+ """
+ Initialize the ElasticsearchEmbeddings instance.
+
+ Args:
+ client (MlClient): An Elasticsearch ML client object.
+ model_id (str): The model_id of the model deployed in the Elasticsearch
+ cluster.
+ input_field (str): The name of the key for the input text field in the
+ document. Defaults to 'text_field'.
+ """
+ self.client = client
+ self.model_id = model_id
+ self.input_field = input_field
+
+ @classmethod
+ def from_credentials(
+ cls,
+ model_id: str,
+ *,
+ es_cloud_id: Optional[str] = None,
+ es_user: Optional[str] = None,
+ es_password: Optional[str] = None,
+ input_field: str = "text_field",
+ ) -> ElasticsearchEmbeddings:
+ """Instantiate embeddings from Elasticsearch credentials.
+
+ Args:
+ model_id (str): The model_id of the model deployed in the Elasticsearch
+ cluster.
+ input_field (str): The name of the key for the input text field in the
+ document. Defaults to 'text_field'.
+ es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
+ es_user: (str, optional): Elasticsearch username.
+ es_password: (str, optional): Elasticsearch password.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import ElasticsearchEmbeddings
+
+ # Define the model ID and input field name (if different from default)
+ model_id = "your_model_id"
+ # Optional, only if different from 'text_field'
+ input_field = "your_input_field"
+
+ # Credentials can be passed in two ways. Either set the env vars
+ # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically
+ # pulled in, or pass them in directly as kwargs.
+ embeddings = ElasticsearchEmbeddings.from_credentials(
+ model_id,
+ input_field=input_field,
+ # es_cloud_id="foo",
+ # es_user="bar",
+ # es_password="baz",
+ )
+
+ documents = [
+ "This is an example document.",
+ "Another example document to generate embeddings for.",
+ ]
+            embeddings.embed_documents(documents)
+ """
+ try:
+ from elasticsearch import Elasticsearch
+ from elasticsearch.client import MlClient
+ except ImportError:
+ raise ImportError(
+ "elasticsearch package not found, please install with 'pip install "
+ "elasticsearch'"
+ )
+
+ es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
+ es_user = es_user or get_from_env("es_user", "ES_USER")
+ es_password = es_password or get_from_env("es_password", "ES_PASSWORD")
+
+ # Connect to Elasticsearch
+ es_connection = Elasticsearch(
+ cloud_id=es_cloud_id, basic_auth=(es_user, es_password)
+ )
+ client = MlClient(es_connection)
+ return cls(client, model_id, input_field=input_field)
+
+ @classmethod
+ def from_es_connection(
+ cls,
+ model_id: str,
+ es_connection: Elasticsearch,
+ input_field: str = "text_field",
+ ) -> ElasticsearchEmbeddings:
+ """
+ Instantiate embeddings from an existing Elasticsearch connection.
+
+ This method provides a way to create an instance of the ElasticsearchEmbeddings
+ class using an existing Elasticsearch connection. The connection object is used
+ to create an MlClient, which is then used to initialize the
+ ElasticsearchEmbeddings instance.
+
+        Args:
+            model_id (str): The model_id of the model deployed in the
+                Elasticsearch cluster.
+            es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
+                connection object.
+            input_field (str, optional): The name of the key for the input text
+                field in the document. Defaults to 'text_field'.
+
+ Returns:
+ ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class.
+
+ Example:
+ .. code-block:: python
+
+ from elasticsearch import Elasticsearch
+
+ from langchain_community.embeddings import ElasticsearchEmbeddings
+
+ # Define the model ID and input field name (if different from default)
+ model_id = "your_model_id"
+ # Optional, only if different from 'text_field'
+ input_field = "your_input_field"
+
+ # Create Elasticsearch connection
+ es_connection = Elasticsearch(
+ hosts=["localhost:9200"], http_auth=("user", "password")
+ )
+
+ # Instantiate ElasticsearchEmbeddings using the existing connection
+ embeddings = ElasticsearchEmbeddings.from_es_connection(
+ model_id,
+ es_connection,
+ input_field=input_field,
+ )
+
+ documents = [
+ "This is an example document.",
+ "Another example document to generate embeddings for.",
+ ]
+            embeddings.embed_documents(documents)
+ """
+ # Importing MlClient from elasticsearch.client within the method to
+ # avoid unnecessary import if the method is not used
+ from elasticsearch.client import MlClient
+
+ # Create an MlClient from the given Elasticsearch connection
+ client = MlClient(es_connection)
+
+ # Return a new instance of the ElasticsearchEmbeddings class with
+ # the MlClient, model_id, and input_field
+ return cls(client, model_id, input_field=input_field)
+
+ def _embedding_func(self, texts: List[str]) -> List[List[float]]:
+ """
+ Generate embeddings for the given texts using the Elasticsearch model.
+
+ Args:
+ texts (List[str]): A list of text strings to generate embeddings for.
+
+ Returns:
+ List[List[float]]: A list of embeddings, one for each text in the input
+ list.
+ """
+ response = self.client.infer_trained_model(
+ model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
+ )
+
+ embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """
+ Generate embeddings for a list of documents.
+
+ Args:
+ texts (List[str]): A list of document text strings to generate embeddings
+ for.
+
+ Returns:
+ List[List[float]]: A list of embeddings, one for each document in the input
+ list.
+ """
+ return self._embedding_func(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ """
+ Generate an embedding for a single query text.
+
+ Args:
+ text (str): The query text to generate an embedding for.
+
+ Returns:
+ List[float]: The embedding for the input query text.
+ """
+ return self._embedding_func([text])[0]
diff --git a/libs/community/langchain_community/embeddings/embaas.py b/libs/community/langchain_community/embeddings/embaas.py
new file mode 100644
index 00000000000..799517bb621
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/embaas.py
@@ -0,0 +1,156 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from requests.adapters import HTTPAdapter, Retry
+from typing_extensions import NotRequired, TypedDict
+
+# Currently supported maximum batch size for embedding requests
+MAX_BATCH_SIZE = 256
+EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
+
+
+class EmbaasEmbeddingsPayload(TypedDict):
+ """Payload for the Embaas embeddings API."""
+
+ model: str
+ texts: List[str]
+ instruction: NotRequired[str]
+
+
+class EmbaasEmbeddings(BaseModel, Embeddings):
+ """Embaas's embedding service.
+
+ To use, you should have the
+ environment variable ``EMBAAS_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ # Initialise with default model and instruction
+ from langchain_community.embeddings import EmbaasEmbeddings
+ emb = EmbaasEmbeddings()
+
+ # Initialise with custom model and instruction
+ from langchain_community.embeddings import EmbaasEmbeddings
+ emb_model = "instructor-large"
+ emb_inst = "Represent the Wikipedia document for retrieval"
+ emb = EmbaasEmbeddings(
+ model=emb_model,
+ instruction=emb_inst
+ )
+ """
+
+ model: str = "e5-large-v2"
+ """The model used for embeddings."""
+ instruction: Optional[str] = None
+ """Instruction used for domain-specific embeddings."""
+ api_url: str = EMBAAS_API_URL
+ """The URL for the embaas embeddings API."""
+    embaas_api_key: Optional[str] = None
+    """The embaas API key, read from the ``EMBAAS_API_KEY`` env var if not set."""
+    max_retries: Optional[int] = 3
+    """Max number of retries for requests."""
+    timeout: Optional[int] = 30
+    """Request timeout in seconds."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ embaas_api_key = get_from_dict_or_env(
+ values, "embaas_api_key", "EMBAAS_API_KEY"
+ )
+ values["embaas_api_key"] = embaas_api_key
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying params."""
+ return {"model": self.model, "instruction": self.instruction}
+
+ def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
+ """Generates payload for the API request."""
+ payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
+ if self.instruction:
+ payload["instruction"] = self.instruction
+ return payload
+
+ def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
+ """Sends a request to the Embaas API and handles the response."""
+ headers = {
+ "Authorization": f"Bearer {self.embaas_api_key}",
+ "Content-Type": "application/json",
+ }
+
+ session = requests.Session()
+ retries = Retry(
+ total=self.max_retries,
+ backoff_factor=0.5,
+ allowed_methods=["POST"],
+ raise_on_status=True,
+ )
+
+ session.mount("http://", HTTPAdapter(max_retries=retries))
+ session.mount("https://", HTTPAdapter(max_retries=retries))
+ response = session.post(
+ self.api_url,
+ headers=headers,
+ json=payload,
+ timeout=self.timeout,
+ )
+
+ parsed_response = response.json()
+ embeddings = [item["embedding"] for item in parsed_response["data"]]
+
+ return embeddings
+
+ def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+ """Generate embeddings using the Embaas API."""
+ payload = self._generate_payload(texts)
+ try:
+ return self._handle_request(payload)
+ except requests.exceptions.RequestException as e:
+ if e.response is None or not e.response.text:
+ raise ValueError(f"Error raised by embaas embeddings API: {e}")
+
+ parsed_response = e.response.json()
+ if "message" in parsed_response:
+ raise ValueError(
+ "Validation Error raised by embaas embeddings API:"
+ f"{parsed_response['message']}"
+ )
+ raise
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Get embeddings for a list of texts.
+
+ Args:
+ texts: The list of texts to get embeddings for.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ batches = [
+ texts[i : i + MAX_BATCH_SIZE] for i in range(0, len(texts), MAX_BATCH_SIZE)
+ ]
+ embeddings = [self._generate_embeddings(batch) for batch in batches]
+ # flatten the list of lists into a single list
+ return [embedding for batch in embeddings for embedding in batch]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Get embeddings for a single text.
+
+ Args:
+ text: The text to get embeddings for.
+
+ Returns:
+ List of embeddings.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/ernie.py b/libs/community/langchain_community/embeddings/ernie.py
new file mode 100644
index 00000000000..0e2d19f6b5d
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/ernie.py
@@ -0,0 +1,153 @@
+import asyncio
+import logging
+import threading
+from functools import partial
+from typing import Dict, List, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class ErnieEmbeddings(BaseModel, Embeddings):
+    """`Ernie Embeddings V1` embedding models.
+
+    To use, set the environment variables ``ERNIE_CLIENT_ID`` and
+    ``ERNIE_CLIENT_SECRET``, or pass ``ernie_client_id`` and
+    ``ernie_client_secret`` as named parameters.
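+
+    A minimal usage sketch (the credentials below are placeholders):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings import ErnieEmbeddings
+
+        embeddings = ErnieEmbeddings(
+            ernie_client_id="my-client-id",
+            ernie_client_secret="my-client-secret",
+        )
+        vector = embeddings.embed_query("hello")
+    """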
+
+ ernie_api_base: Optional[str] = None
+ ernie_client_id: Optional[str] = None
+ ernie_client_secret: Optional[str] = None
+ access_token: Optional[str] = None
+
+ chunk_size: int = 16
+
+ model_name = "ErnieBot-Embedding-V1"
+
+ _lock = threading.Lock()
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["ernie_api_base"] = get_from_dict_or_env(
+ values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
+ )
+ values["ernie_client_id"] = get_from_dict_or_env(
+ values,
+ "ernie_client_id",
+ "ERNIE_CLIENT_ID",
+ )
+ values["ernie_client_secret"] = get_from_dict_or_env(
+ values,
+ "ernie_client_secret",
+ "ERNIE_CLIENT_SECRET",
+ )
+ return values
+
+ def _embedding(self, json: object) -> dict:
+ base_url = (
+ f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
+ )
+ resp = requests.post(
+ f"{base_url}/embedding-v1",
+ headers={
+ "Content-Type": "application/json",
+ },
+ params={"access_token": self.access_token},
+ json=json,
+ )
+ return resp.json()
+
+ def _refresh_access_token_with_lock(self) -> None:
+ with self._lock:
+ logger.debug("Refreshing access token")
+ base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
+ resp = requests.post(
+ base_url,
+ headers={
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ },
+ params={
+ "grant_type": "client_credentials",
+ "client_id": self.ernie_client_id,
+ "client_secret": self.ernie_client_secret,
+ },
+ )
+ self.access_token = str(resp.json().get("access_token"))
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed search docs.
+
+ Args:
+ texts: The list of texts to embed
+
+ Returns:
+ List[List[float]]: List of embeddings, one for each text.
+ """
+
+ if not self.access_token:
+ self._refresh_access_token_with_lock()
+ text_in_chunks = [
+ texts[i : i + self.chunk_size]
+ for i in range(0, len(texts), self.chunk_size)
+ ]
+ lst = []
+ for chunk in text_in_chunks:
+ resp = self._embedding({"input": [text for text in chunk]})
+ if resp.get("error_code"):
+ if resp.get("error_code") == 111:
+ self._refresh_access_token_with_lock()
+ resp = self._embedding({"input": [text for text in chunk]})
+ else:
+ raise ValueError(f"Error from Ernie: {resp}")
+ lst.extend([i["embedding"] for i in resp["data"]])
+ return lst
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ List[float]: Embeddings for the text.
+ """
+
+ if not self.access_token:
+ self._refresh_access_token_with_lock()
+ resp = self._embedding({"input": [text]})
+ if resp.get("error_code"):
+ if resp.get("error_code") == 111:
+ self._refresh_access_token_with_lock()
+ resp = self._embedding({"input": [text]})
+ else:
+ raise ValueError(f"Error from Ernie: {resp}")
+ return resp["data"][0]["embedding"]
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Asynchronous Embed query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ List[float]: Embeddings for the text.
+ """
+
+ return await asyncio.get_running_loop().run_in_executor(
+ None, partial(self.embed_query, text)
+ )
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Asynchronous Embed search docs.
+
+ Args:
+ texts: The list of texts to embed
+
+ Returns:
+ List[List[float]]: List of embeddings, one for each text.
+ """
+
+ result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
+
+ return list(result)
diff --git a/libs/community/langchain_community/embeddings/fake.py b/libs/community/langchain_community/embeddings/fake.py
new file mode 100644
index 00000000000..fcc33496aa8
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/fake.py
@@ -0,0 +1,49 @@
+import hashlib
+from typing import List
+
+import numpy as np
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel
+
+
+class FakeEmbeddings(Embeddings, BaseModel):
+    """Fake embedding model that returns random vectors.
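+
+    Useful for testing pipelines without calling a real embedding service.
+    A minimal sketch (the vector size below is arbitrary):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings import FakeEmbeddings
+
+        fake = FakeEmbeddings(size=128)
+        vector = fake.embed_query("hello")
+    """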
+
+ size: int
+ """The size of the embedding vector."""
+
+ def _get_embedding(self) -> List[float]:
+ return list(np.random.normal(size=self.size))
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ return [self._get_embedding() for _ in texts]
+
+ def embed_query(self, text: str) -> List[float]:
+ return self._get_embedding()
+
+
+class DeterministicFakeEmbedding(Embeddings, BaseModel):
+ """
+ Fake embedding model that always returns
+ the same embedding vector for the same text.
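+
+    A minimal sketch (the vector size below is arbitrary):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings import DeterministicFakeEmbedding
+
+        fake = DeterministicFakeEmbedding(size=128)
+        assert fake.embed_query("hello") == fake.embed_query("hello")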
+ """
+
+ size: int
+ """The size of the embedding vector."""
+
+ def _get_embedding(self, seed: int) -> List[float]:
+ # set the seed for the random generator
+ np.random.seed(seed)
+ return list(np.random.normal(size=self.size))
+
+ def _get_seed(self, text: str) -> int:
+ """
+ Get a seed for the random generator, using the hash of the text.
+ """
+ return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self._get_embedding(seed=self._get_seed(text)) for text in texts]
+
+ def embed_query(self, text: str) -> List[float]:
+ return self._get_embedding(seed=self._get_seed(text))
diff --git a/libs/community/langchain_community/embeddings/fastembed.py b/libs/community/langchain_community/embeddings/fastembed.py
new file mode 100644
index 00000000000..0a29a784004
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/fastembed.py
@@ -0,0 +1,107 @@
+from typing import Any, Dict, List, Literal, Optional
+
+import numpy as np
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class FastEmbedEmbeddings(BaseModel, Embeddings):
+ """Qdrant FastEmbedding models.
+ FastEmbed is a lightweight, fast, Python library built for embedding generation.
+ See more documentation at:
+ * https://github.com/qdrant/fastembed/
+ * https://qdrant.github.io/fastembed/
+
+    To use this class, you must install the ``fastembed`` Python package.
+
+    ``pip install fastembed``
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings import FastEmbedEmbeddings
+
+            fastembed = FastEmbedEmbeddings()
+    """
+
+ model_name: str = "BAAI/bge-small-en-v1.5"
+ """Name of the FastEmbedding model to use
+ Defaults to "BAAI/bge-small-en-v1.5"
+ Find the list of supported models at
+ https://qdrant.github.io/fastembed/examples/Supported_Models/
+ """
+
+ max_length: int = 512
+ """The maximum number of tokens. Defaults to 512.
+ Unknown behavior for values > 512.
+ """
+
+ cache_dir: Optional[str]
+ """The path to the cache directory.
+ Defaults to `local_cache` in the parent directory
+ """
+
+ threads: Optional[int]
+ """The number of threads single onnxruntime session can use.
+ Defaults to None
+ """
+
+ doc_embed_type: Literal["default", "passage"] = "default"
+ """Type of embedding to use for documents
+ "default": Uses FastEmbed's default embedding method
+ "passage": Prefixes the text with "passage" before embedding.
+ """
+
+    _model: Any  #: :meta private:
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that FastEmbed has been installed."""
+ try:
+ from fastembed.embedding import FlagEmbedding
+
+ model_name = values.get("model_name")
+ max_length = values.get("max_length")
+ cache_dir = values.get("cache_dir")
+ threads = values.get("threads")
+ values["_model"] = FlagEmbedding(
+ model_name=model_name,
+ max_length=max_length,
+ cache_dir=cache_dir,
+ threads=threads,
+ )
+ except ImportError as ie:
+ raise ImportError(
+ "Could not import 'fastembed' Python package. "
+ "Please install it with `pip install fastembed`."
+ ) from ie
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Generate embeddings for documents using FastEmbed.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings: List[np.ndarray]
+ if self.doc_embed_type == "passage":
+ embeddings = self._model.passage_embed(texts)
+ else:
+ embeddings = self._model.embed(texts)
+ return [e.tolist() for e in embeddings]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Generate query embeddings using FastEmbed.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ query_embeddings: np.ndarray = next(self._model.query_embed(text))
+ return query_embeddings.tolist()
diff --git a/libs/community/langchain_community/embeddings/google_palm.py b/libs/community/langchain_community/embeddings/google_palm.py
new file mode 100644
index 00000000000..0ffb5cfb6fb
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/google_palm.py
@@ -0,0 +1,101 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+ """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+ import google.api_core.exceptions
+
+ multiplier = 2
+ min_seconds = 1
+ max_seconds = 60
+ max_retries = 10
+
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(max_retries),
+ wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+ | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+ | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def embed_with_retry(
+ embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator()
+
+ @retry_decorator
+ def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return embeddings.client.generate_embeddings(*args, **kwargs)
+
+ return _embed_with_retry(*args, **kwargs)
+
+
+class GooglePalmEmbeddings(BaseModel, Embeddings):
+    """Google's PaLM embedding models.
+
+    To use, you should have the ``google.generativeai`` python package installed,
+    and the environment variable ``GOOGLE_API_KEY`` set with your API key, or
+    pass it as the ``google_api_key`` named parameter.
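+
+    A minimal usage sketch (the API key below is a placeholder):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings import GooglePalmEmbeddings
+
+        embeddings = GooglePalmEmbeddings(google_api_key="my-api-key")
+        vector = embeddings.embed_query("hello")
+    """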
+
+ client: Any
+ google_api_key: Optional[str]
+ model_name: str = "models/embedding-gecko-001"
+ """Model name to use."""
+ show_progress_bar: bool = False
+ """Whether to show a tqdm progress bar. Must have `tqdm` installed."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate api key, python package exists."""
+ google_api_key = get_from_dict_or_env(
+ values, "google_api_key", "GOOGLE_API_KEY"
+ )
+ try:
+ import google.generativeai as genai
+
+ genai.configure(api_key=google_api_key)
+ except ImportError:
+ raise ImportError("Could not import google.generativeai python package.")
+
+ values["client"] = genai
+
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ if self.show_progress_bar:
+ try:
+ from tqdm import tqdm
+
+ iter_ = tqdm(texts, desc="GooglePalmEmbeddings")
+ except ImportError:
+ logger.warning(
+ "Unable to show progress bar because tqdm could not be imported. "
+ "Please install with `pip install tqdm`."
+ )
+ iter_ = texts
+ else:
+ iter_ = texts
+ return [self.embed_query(text) for text in iter_]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed query text."""
+ embedding = embed_with_retry(self, self.model_name, text)
+ return embedding["embedding"]
diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py
new file mode 100644
index 00000000000..f7983a5968b
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/gpt4all.py
@@ -0,0 +1,60 @@
+from typing import Any, Dict, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+
+class GPT4AllEmbeddings(BaseModel, Embeddings):
+ """GPT4All embedding models.
+
+ To use, you should have the gpt4all python package installed
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import GPT4AllEmbeddings
+
+ embeddings = GPT4AllEmbeddings()
+ """
+
+ client: Any #: :meta private:
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that GPT4All library is installed."""
+
+ try:
+ from gpt4all import Embed4All
+
+ values["client"] = Embed4All()
+ except ImportError:
+ raise ImportError(
+ "Could not import gpt4all library. "
+ "Please install the gpt4all library to "
+ "use this embedding model: pip install gpt4all"
+ )
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using GPT4All.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ embeddings = [self.client.embed(text) for text in texts]
+ return [list(map(float, e)) for e in embeddings]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using GPT4All.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/gradient_ai.py b/libs/community/langchain_community/embeddings/gradient_ai.py
new file mode 100644
index 00000000000..d5c9bf9a49c
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/gradient_ai.py
@@ -0,0 +1,379 @@
+import asyncio
+import logging
+import os
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import aiohttp
+import numpy as np
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+__all__ = ["GradientEmbeddings"]
+
+
+class GradientEmbeddings(BaseModel, Embeddings):
+ """Gradient.ai Embedding models.
+
+    GradientEmbeddings is a class to interact with embedding models on gradient.ai.
+
+ To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
+ API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
+ or alternatively provide them as keywords to the constructor of this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import GradientEmbeddings
+ GradientEmbeddings(
+ model="bge-large",
+ gradient_workspace_id="12345614fc0_workspace",
+ gradient_access_token="gradientai-access_token",
+ )
+ """
+
+ model: str
+ "Underlying gradient.ai model id."
+
+ gradient_workspace_id: Optional[str] = None
+ "Underlying gradient.ai workspace_id."
+
+ gradient_access_token: Optional[str] = None
+ """gradient.ai API Token, which can be generated by going to
+ https://auth.gradient.ai/select-workspace
+ and selecting "Access tokens" under the profile drop-down.
+ """
+
+ gradient_api_url: str = "https://api.gradient.ai/api"
+ """Endpoint URL to use."""
+
+ client: Any = None #: :meta private:
+ """Gradient client."""
+
+ # LLM call kwargs
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+
+ values["gradient_access_token"] = get_from_dict_or_env(
+ values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
+ )
+ values["gradient_workspace_id"] = get_from_dict_or_env(
+ values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
+ )
+
+ values["gradient_api_url"] = get_from_dict_or_env(
+ values, "gradient_api_url", "GRADIENT_API_URL"
+ )
+
+ values["client"] = TinyAsyncGradientEmbeddingClient(
+ access_token=values["gradient_access_token"],
+ workspace_id=values["gradient_workspace_id"],
+ host=values["gradient_api_url"],
+ )
+ try:
+ import gradientai # noqa
+ except ImportError:
+ logging.warning(
+ "DeprecationWarning: `GradientEmbeddings` will use "
+ "`pip install gradientai` in future releases of langchain."
+ )
+ except Exception:
+ pass
+
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Gradient's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = self.client.embed(
+ model=self.model,
+ texts=texts,
+ )
+ return embeddings
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Async call out to Gradient's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = await self.client.aembed(
+ model=self.model,
+ texts=texts,
+ )
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Gradient's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Async call out to Gradient's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ embeddings = await self.aembed_documents([text])
+ return embeddings[0]
+
+
+class TinyAsyncGradientEmbeddingClient: #: :meta private:
+    """A helper tool to embed with Gradient. Not part of Langchain's or Gradient's
+    stable API; direct use is discouraged.
+
+ To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
+ API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
+ or alternatively provide them as keywords to the constructor of this class.
+
+ Example:
+ .. code-block:: python
+
+
+ mini_client = TinyAsyncGradientEmbeddingClient(
+ workspace_id="12345614fc0_workspace",
+ access_token="gradientai-access_token",
+ )
+ embeds = mini_client.embed(
+ model="bge-large",
+ text=["doc1", "doc2"]
+ )
+ # or
+ embeds = await mini_client.aembed(
+ model="bge-large",
+ text=["doc1", "doc2"]
+ )
+
+ """
+
+ def __init__(
+ self,
+ access_token: Optional[str] = None,
+ workspace_id: Optional[str] = None,
+ host: str = "https://api.gradient.ai/api",
+ aiosession: Optional[aiohttp.ClientSession] = None,
+ ) -> None:
+ self.access_token = access_token or os.environ.get(
+ "GRADIENT_ACCESS_TOKEN", None
+ )
+ self.workspace_id = workspace_id or os.environ.get(
+ "GRADIENT_WORKSPACE_ID", None
+ )
+ self.host = host
+ self.aiosession = aiosession
+
+ if self.access_token is None or len(self.access_token) < 10:
+ raise ValueError(
+ "env variable `GRADIENT_ACCESS_TOKEN` or "
+ " param `access_token` must be set "
+ )
+
+ if self.workspace_id is None or len(self.workspace_id) < 3:
+ raise ValueError(
+ "env variable `GRADIENT_WORKSPACE_ID` or "
+ " param `workspace_id` must be set"
+ )
+
+ if self.host is None or len(self.host) < 3:
+ raise ValueError(" param `host` must be set to a valid url")
+ self._batch_size = 128
+
+ @staticmethod
+ def _permute(
+ texts: List[str], sorter: Callable = len
+ ) -> Tuple[List[str], Callable]:
+        """Sort texts in descending order of length and return a callable
+        that restores the original order of a list of the same length.
+        Adapted from https://github.com/UKPLab/sentence-transformers/blob/
+        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156
+
+        Args:
+            texts (List[str]): Texts to sort.
+            sorter (Callable, optional): Key function used for sorting.
+                Defaults to len.
+
+        Returns:
+            Tuple[List[str], Callable]: The sorted texts and a callable that
+                undoes the permutation.
+
+ Example:
+ ```
+ texts = ["one","three","four"]
+ perm_texts, undo = self._permute(texts)
+ texts == undo(perm_texts)
+ ```
+ """
+
+ if len(texts) == 1:
+ # special case query
+ return texts, lambda t: t
+ length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
+ texts_sorted = [texts[idx] for idx in length_sorted_idx]
+
+ return texts_sorted, lambda unsorted_embeddings: [ # noqa E731
+ unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
+ ]
+
+ def _batch(self, texts: List[str]) -> List[List[str]]:
+        """
+        Split a list of texts into batches of at most `self._batch_size`
+        items each, e.g. when encoding a vector database.
+
+        Args:
+            texts (List[str]): List of sentences to split; the batch size is
+                bounded by ``self._batch_size``.
+
+        Returns:
+            List[List[str]]: Batches of lists of sentences.
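+
+        Example (assuming the default ``_batch_size`` of 128):
+            ``self._batch(["text"] * 300)`` returns three batches of
+            sizes 128, 128 and 44.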
+ """
+ if len(texts) == 1:
+ # special case query
+ return [texts]
+ batches = []
+ for start_index in range(0, len(texts), self._batch_size):
+ batches.append(texts[start_index : start_index + self._batch_size])
+ return batches
+
+ @staticmethod
+ def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
+ if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
+ # special case query
+ return batch_of_texts[0]
+ texts = []
+ for sublist in batch_of_texts:
+ texts.extend(sublist)
+ return texts
+
+ def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
+        """Build the kwargs for the POST request, used by the sync path.
+
+        Args:
+            model (str): Name of the embedding model.
+            texts (List[str]): Texts to embed.
+
+        Returns:
+            Dict[str, Any]: Keyword arguments for the request.
+        """
+ return dict(
+ url=f"{self.host}/embeddings/{model}",
+ headers={
+ "authorization": f"Bearer {self.access_token}",
+ "x-gradient-workspace-id": f"{self.workspace_id}",
+ "accept": "application/json",
+ "content-type": "application/json",
+ },
+ json=dict(
+ inputs=[{"input": i} for i in texts],
+ ),
+ )
+
+ def _sync_request_embed(
+ self, model: str, batch_texts: List[str]
+ ) -> List[List[float]]:
+ response = requests.post(
+ **self._kwargs_post_request(model=model, texts=batch_texts)
+ )
+ if response.status_code != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+ return [e["embedding"] for e in response.json()["embeddings"]]
+
+ def embed(self, model: str, texts: List[str]) -> List[List[float]]:
+        """Embed a list of texts with the given model.
+
+        Args:
+            model (str): Name of the embedding model.
+            texts (List[str]): List of sentences to embed.
+
+        Returns:
+            List[List[float]]: List of vectors, one for each sentence.
+        """
+ perm_texts, unpermute_func = self._permute(texts)
+ perm_texts_batched = self._batch(perm_texts)
+
+ # Request
+ map_args = (
+ self._sync_request_embed,
+ [model] * len(perm_texts_batched),
+ perm_texts_batched,
+ )
+ if len(perm_texts_batched) == 1:
+ embeddings_batch_perm = list(map(*map_args))
+ else:
+ with ThreadPoolExecutor(32) as p:
+ embeddings_batch_perm = list(p.map(*map_args))
+
+ embeddings_perm = self._unbatch(embeddings_batch_perm)
+ embeddings = unpermute_func(embeddings_perm)
+ return embeddings
+
+ async def _async_request(
+ self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
+ ) -> List[List[float]]:
+ async with session.post(**kwargs) as response:
+            if response.status != 200:
+                # `response.text` is a coroutine in aiohttp; read it first.
+                body = await response.text()
+                raise Exception(
+                    f"Gradient returned an unexpected response with status "
+                    f"{response.status}: {body}"
+                )
+ embedding = (await response.json())["embeddings"]
+ return [e["embedding"] for e in embedding]
+
+ async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
+        """Embed a list of texts with the given model, asynchronously.
+
+        Args:
+            model (str): Name of the embedding model.
+            texts (List[str]): List of sentences to embed.
+
+        Returns:
+            List[List[float]]: List of vectors, one for each sentence.
+        """
+ perm_texts, unpermute_func = self._permute(texts)
+ perm_texts_batched = self._batch(perm_texts)
+
+ # Request
+ if self.aiosession is None:
+ self.aiosession = aiohttp.ClientSession(
+ trust_env=True, connector=aiohttp.TCPConnector(limit=32)
+ )
+ async with self.aiosession as session:
+ embeddings_batch_perm = await asyncio.gather(
+ *[
+ self._async_request(
+ session=session,
+                    kwargs=self._kwargs_post_request(model=model, texts=t),
+ )
+ for t in perm_texts_batched
+ ]
+ )
+
+ embeddings_perm = self._unbatch(embeddings_batch_perm)
+ embeddings = unpermute_func(embeddings_perm)
+ return embeddings
diff --git a/libs/community/langchain_community/embeddings/huggingface.py b/libs/community/langchain_community/embeddings/huggingface.py
new file mode 100644
index 00000000000..84a568866f1
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/huggingface.py
@@ -0,0 +1,343 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field
+
+DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
+DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
+DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
+DEFAULT_QUERY_INSTRUCTION = (
+ "Represent the question for retrieving supporting documents: "
+)
+DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
+ "Represent this question for searching relevant passages: "
+)
+DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "δΈΊθΏδΈͺε₯εηζ葨瀺δ»₯η¨δΊζ£η΄’ηΈε ³ζη« οΌ"
+
+
+class HuggingFaceEmbeddings(BaseModel, Embeddings):
+ """HuggingFace sentence_transformers embedding models.
+
+ To use, you should have the ``sentence_transformers`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+ model_name = "sentence-transformers/all-mpnet-base-v2"
+ model_kwargs = {'device': 'cpu'}
+ encode_kwargs = {'normalize_embeddings': False}
+ hf = HuggingFaceEmbeddings(
+ model_name=model_name,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs
+ )
+ """
+
+ client: Any #: :meta private:
+ model_name: str = DEFAULT_MODEL_NAME
+ """Model name to use."""
+ cache_folder: Optional[str] = None
+ """Path to store models.
+ Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass to the model."""
+ encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass when calling the `encode` method of the model."""
+ multi_process: bool = False
+ """Run encode() on multiple GPUs."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the sentence_transformer."""
+ super().__init__(**kwargs)
+ try:
+ import sentence_transformers
+
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import sentence_transformers python package. "
+ "Please install it with `pip install sentence-transformers`."
+ ) from exc
+
+ self.client = sentence_transformers.SentenceTransformer(
+ self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+ )
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a HuggingFace transformer model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ import sentence_transformers
+
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
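+        # multi_process uses sentence-transformers' worker pool (typically one
+        # process per available device) instead of encoding in-process.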
+ if self.multi_process:
+ pool = self.client.start_multi_process_pool()
+ embeddings = self.client.encode_multi_process(texts, pool)
+ sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
+ else:
+ embeddings = self.client.encode(texts, **self.encode_kwargs)
+
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace transformer model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
+
+
+class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
+ """Wrapper around sentence_transformers embedding models.
+
+ To use, you should have the ``sentence_transformers``
+ and ``InstructorEmbedding`` python packages installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+
+ model_name = "hkunlp/instructor-large"
+ model_kwargs = {'device': 'cpu'}
+ encode_kwargs = {'normalize_embeddings': True}
+ hf = HuggingFaceInstructEmbeddings(
+ model_name=model_name,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs
+ )
+ """
+
+ client: Any #: :meta private:
+ model_name: str = DEFAULT_INSTRUCT_MODEL
+ """Model name to use."""
+ cache_folder: Optional[str] = None
+ """Path to store models.
+ Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass to the model."""
+ encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass when calling the `encode` method of the model."""
+ embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
+ """Instruction to use for embedding documents."""
+ query_instruction: str = DEFAULT_QUERY_INSTRUCTION
+ """Instruction to use for embedding query."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the sentence_transformer."""
+ super().__init__(**kwargs)
+ try:
+ from InstructorEmbedding import INSTRUCTOR
+
+ self.client = INSTRUCTOR(
+ self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+ )
+ except ImportError as e:
+ raise ImportError("Dependencies for InstructorEmbedding not found.") from e
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a HuggingFace instruct model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ instruction_pairs = [[self.embed_instruction, text] for text in texts]
+ embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace instruct model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ instruction_pair = [self.query_instruction, text]
+ embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
+ return embedding.tolist()
+
+
+class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
+ """HuggingFace BGE sentence_transformers embedding models.
+
+ To use, you should have the ``sentence_transformers`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+
+ model_name = "BAAI/bge-large-en"
+ model_kwargs = {'device': 'cpu'}
+ encode_kwargs = {'normalize_embeddings': True}
+ hf = HuggingFaceBgeEmbeddings(
+ model_name=model_name,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs
+ )
+ """
+
+ client: Any #: :meta private:
+ model_name: str = DEFAULT_BGE_MODEL
+ """Model name to use."""
+ cache_folder: Optional[str] = None
+ """Path to store models.
+ Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass to the model."""
+ encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Keyword arguments to pass when calling the `encode` method of the model."""
+ query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
+ """Instruction to use for embedding query."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the sentence_transformer."""
+ super().__init__(**kwargs)
+ try:
+ import sentence_transformers
+
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import sentence_transformers python package. "
+ "Please install it with `pip install sentence_transformers`."
+ ) from exc
+
+ self.client = sentence_transformers.SentenceTransformer(
+ self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+ )
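+        # Chinese BGE checkpoints ("-zh" in the model name) expect the Chinese
+        # retrieval instruction, so swap out the default query instruction.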
+ if "-zh" in self.model_name:
+ self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a HuggingFace transformer model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ texts = [t.replace("\n", " ") for t in texts]
+ embeddings = self.client.encode(texts, **self.encode_kwargs)
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace transformer model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ text = text.replace("\n", " ")
+ embedding = self.client.encode(
+ self.query_instruction + text, **self.encode_kwargs
+ )
+ return embedding.tolist()
+
+
+class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
+ """Embed texts using the HuggingFace API.
+
+ Requires a HuggingFace Inference API key and a model name.
+ """
+
+ api_key: str
+ """Your API key for the HuggingFace Inference API."""
+ model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
+ """The name of the model to use for text embeddings."""
+ api_url: Optional[str] = None
+ """Custom inference endpoint url. None for using default public url."""
+
+ @property
+ def _api_url(self) -> str:
+ return self.api_url or self._default_api_url
+
+ @property
+ def _default_api_url(self) -> str:
+ return (
+ "https://api-inference.huggingface.co"
+ "/pipeline"
+ "/feature-extraction"
+ f"/{self.model_name}"
+ )
+
+ @property
+ def _headers(self) -> dict:
+ return {"Authorization": f"Bearer {self.api_key}"}
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Get the embeddings for a list of texts.
+
+ Args:
+ texts (Documents): A list of texts to get embeddings for.
+
+ Returns:
+ Embedded texts as List[List[float]], where each inner List[float]
+ corresponds to a single input text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+ hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
+ api_key="your_api_key",
+ model_name="sentence-transformers/all-MiniLM-l6-v2"
+ )
+ texts = ["Hello, world!", "How are you?"]
+ hf_embeddings.embed_documents(texts)
+ """ # noqa: E501
+ response = requests.post(
+ self._api_url,
+ headers=self._headers,
+ json={
+ "inputs": texts,
+ "options": {"wait_for_model": True, "use_cache": True},
+ },
+ )
+ return response.json()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace transformer model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/huggingface_hub.py b/libs/community/langchain_community/embeddings/huggingface_hub.py
new file mode 100644
index 00000000000..773dddad7db
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/huggingface_hub.py
@@ -0,0 +1,109 @@
+import json
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
+VALID_TASKS = ("feature-extraction",)
+
+
+class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
+ """HuggingFaceHub embedding models.
+
+ To use, you should have the ``huggingface_hub`` python package installed, and the
+ environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceHubEmbeddings
+ model = "sentence-transformers/all-mpnet-base-v2"
+ hf = HuggingFaceHubEmbeddings(
+ model=model,
+ task="feature-extraction",
+ huggingfacehub_api_token="my-api-key",
+ )
+ """
+
+ client: Any #: :meta private:
+ model: Optional[str] = None
+ """Model name to use."""
+ repo_id: Optional[str] = None
+ """Huggingfacehub repository id, for backward compatibility."""
+ task: Optional[str] = "feature-extraction"
+ """Task to call the model with."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ huggingfacehub_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ huggingfacehub_api_token = get_from_dict_or_env(
+ values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+ )
+ try:
+ from huggingface_hub import InferenceClient
+
+ if values["model"]:
+ values["repo_id"] = values["model"]
+ elif values["repo_id"]:
+ values["model"] = values["repo_id"]
+ else:
+ values["model"] = DEFAULT_MODEL
+ values["repo_id"] = DEFAULT_MODEL
+
+ client = InferenceClient(
+ model=values["model"],
+ token=huggingfacehub_api_token,
+ )
+ if values["task"] not in VALID_TASKS:
+ raise ValueError(
+ f"Got invalid task {values['task']}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ values["client"] = client
+ except ImportError:
+ raise ImportError(
+ "Could not import huggingface_hub python package. "
+ "Please install it with `pip install huggingface_hub`."
+ )
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ # replace newlines, which can negatively affect performance.
+ texts = [text.replace("\n", " ") for text in texts]
+ _model_kwargs = self.model_kwargs or {}
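+        # InferenceClient.post returns the raw response body as bytes,
+        # so decode it and parse the JSON list of embedding vectors below.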
+ responses = self.client.post(
+ json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
+ )
+ return json.loads(responses.decode())
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to HuggingFaceHub's embedding endpoint for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ response = self.embed_documents([text])[0]
+ return response
diff --git a/libs/community/langchain_community/embeddings/infinity.py b/libs/community/langchain_community/embeddings/infinity.py
new file mode 100644
index 00000000000..4cc5688e0e1
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/infinity.py
@@ -0,0 +1,322 @@
+"""written under MIT Licence, Michael Feil 2023."""
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import aiohttp
+import numpy as np
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+__all__ = ["InfinityEmbeddings"]
+
+
+class InfinityEmbeddings(BaseModel, Embeddings):
+ """Embedding models for self-hosted https://github.com/michaelfeil/infinity
+ This should also work for text-embeddings-inference and other
+ self-hosted openai-compatible servers.
+
+ Infinity is a class to interact with Embedding Models on https://github.com/michaelfeil/infinity
+
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import InfinityEmbeddings
+ InfinityEmbeddings(
+ model="BAAI/bge-small",
+ infinity_api_url="http://localhost:7797/v1",
+ )
+ """
+
+ model: str
+ "Underlying Infinity model id."
+
+ infinity_api_url: str = "http://localhost:7797/v1"
+ """Endpoint URL to use."""
+
+ client: Any = None #: :meta private:
+ """Infinity client."""
+
+ # LLM call kwargs
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+
+ values["infinity_api_url"] = get_from_dict_or_env(
+ values, "infinity_api_url", "INFINITY_API_URL"
+ )
+
+ values["client"] = TinyAsyncOpenAIInfinityEmbeddingClient(
+ host=values["infinity_api_url"],
+ )
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Infinity's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = self.client.embed(
+ model=self.model,
+ texts=texts,
+ )
+ return embeddings
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Async call out to Infinity's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = await self.client.aembed(
+ model=self.model,
+ texts=texts,
+ )
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Infinity's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Async call out to Infinity's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ embeddings = await self.aembed_documents([text])
+ return embeddings[0]
+
+
+class TinyAsyncOpenAIInfinityEmbeddingClient: #: :meta private:
+ """A helper tool to embed Infinity. Not part of Langchain's stable API,
+ direct use discouraged.
+
+ Example:
+ .. code-block:: python
+
+
+ mini_client = TinyAsyncInfinityEmbeddingClient(
+ )
+ embeds = mini_client.embed(
+ model="BAAI/bge-small",
+ text=["doc1", "doc2"]
+ )
+ # or
+ embeds = await mini_client.aembed(
+ model="BAAI/bge-small",
+ text=["doc1", "doc2"]
+ )
+
+ """
+
+ def __init__(
+ self,
+ host: str = "http://localhost:7797/v1",
+ aiosession: Optional[aiohttp.ClientSession] = None,
+ ) -> None:
+ self.host = host
+ self.aiosession = aiosession
+
+ if self.host is None or len(self.host) < 3:
+ raise ValueError(" param `host` must be set to a valid url")
+ self._batch_size = 128
+
+ @staticmethod
+ def _permute(
+ texts: List[str], sorter: Callable = len
+ ) -> Tuple[List[str], Callable]:
+ """Sort texts in ascending order, and
+ delivers a lambda expr, which can sort a same length list
+ https://github.com/UKPLab/sentence-transformers/blob/
+ c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156
+
+ Args:
+ texts (List[str]): _description_
+ sorter (Callable, optional): _description_. Defaults to len.
+
+ Returns:
+ Tuple[List[str], Callable]: _description_
+
+ Example:
+ ```
+ texts = ["one","three","four"]
+ perm_texts, undo = self._permute(texts)
+ texts == undo(perm_texts)
+ ```
+ """
+
+ if len(texts) == 1:
+ # special case query
+ return texts, lambda t: t
+ length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
+ texts_sorted = [texts[idx] for idx in length_sorted_idx]
+
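+        # np.argsort(length_sorted_idx) gives the inverse permutation, so the
+        # returned lambda maps results computed in sorted order back to input order.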
+ return texts_sorted, lambda unsorted_embeddings: [ # noqa E731
+ unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
+ ]
+
+ def _batch(self, texts: List[str]) -> List[List[str]]:
+ """
+        Split a list of texts into batches of at most `self._batch_size` items,
+        e.g. when encoding a whole corpus for a vector database.
+
+        Args:
+            texts (List[str]): List of sentences
+            self._batch_size (int, optional): max batch size of one request.
+
+        Returns:
+            List[List[str]]: Batches of lists of sentences
+ """
+ if len(texts) == 1:
+ # special case query
+ return [texts]
+ batches = []
+ for start_index in range(0, len(texts), self._batch_size):
+ batches.append(texts[start_index : start_index + self._batch_size])
+ return batches
+
+ @staticmethod
+ def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
+ if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
+ # special case query
+ return batch_of_texts[0]
+ texts = []
+ for sublist in batch_of_texts:
+ texts.extend(sublist)
+ return texts
+
+ def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
+ """Build the kwargs for the Post request, used by sync
+
+ Args:
+ model (str): _description_
+ texts (List[str]): _description_
+
+ Returns:
+ Dict[str, Collection[str]]: _description_
+ """
+ return dict(
+ url=f"{self.host}/embeddings",
+ headers={
+ # "accept": "application/json",
+ "content-type": "application/json",
+ },
+ json=dict(
+ input=texts,
+ model=model,
+ ),
+ )
+
+ def _sync_request_embed(
+ self, model: str, batch_texts: List[str]
+ ) -> List[List[float]]:
+ response = requests.post(
+ **self._kwargs_post_request(model=model, texts=batch_texts)
+ )
+ if response.status_code != 200:
+ raise Exception(
+ f"Infinity returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+ return [e["embedding"] for e in response.json()["data"]]
+
+ def embed(self, model: str, texts: List[str]) -> List[List[float]]:
+ """call the embedding of model
+
+ Args:
+ model (str): to embedding model
+ texts (List[str]): List of sentences to embed.
+
+ Returns:
+ List[List[float]]: List of vectors for each sentence
+ """
+ perm_texts, unpermute_func = self._permute(texts)
+ perm_texts_batched = self._batch(perm_texts)
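+        # Texts were length-sorted above, so each batch groups similar-length
+        # inputs (the sentence-transformers trick referenced in _permute).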
+
+ # Request
+ map_args = (
+ self._sync_request_embed,
+ [model] * len(perm_texts_batched),
+ perm_texts_batched,
+ )
+ if len(perm_texts_batched) == 1:
+ embeddings_batch_perm = list(map(*map_args))
+ else:
+ with ThreadPoolExecutor(32) as p:
+ embeddings_batch_perm = list(p.map(*map_args))
+
+ embeddings_perm = self._unbatch(embeddings_batch_perm)
+ embeddings = unpermute_func(embeddings_perm)
+ return embeddings
+
+ async def _async_request(
+ self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
+ ) -> List[List[float]]:
+ async with session.post(**kwargs) as response:
+            if response.status != 200:
+                detail = await response.text()
+                raise Exception(
+                    f"Infinity returned an unexpected response with status "
+                    f"{response.status}: {detail}"
+                )
+            embedding = (await response.json())["data"]
+ return [e["embedding"] for e in embedding]
+
+ async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
+ """call the embedding of model, async method
+
+ Args:
+ model (str): to embedding model
+ texts (List[str]): List of sentences to embed.
+
+ Returns:
+ List[List[float]]: List of vectors for each sentence
+ """
+ perm_texts, unpermute_func = self._permute(texts)
+ perm_texts_batched = self._batch(perm_texts)
+
+ # Request
+ if self.aiosession is None:
+ self.aiosession = aiohttp.ClientSession(
+ trust_env=True, connector=aiohttp.TCPConnector(limit=32)
+ )
+ async with self.aiosession as session:
+ embeddings_batch_perm = await asyncio.gather(
+ *[
+ self._async_request(
+ session=session,
+ **self._kwargs_post_request(model=model, texts=t),
+ )
+ for t in perm_texts_batched
+ ]
+ )
+
+ embeddings_perm = self._unbatch(embeddings_batch_perm)
+ embeddings = unpermute_func(embeddings_perm)
+ return embeddings
diff --git a/libs/community/langchain_community/embeddings/javelin_ai_gateway.py b/libs/community/langchain_community/embeddings/javelin_ai_gateway.py
new file mode 100644
index 00000000000..c91a003291d
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/javelin_ai_gateway.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+ for i in range(0, len(texts), size):
+ yield texts[i : i + size]
+
+
+class JavelinAIGatewayEmbeddings(Embeddings, BaseModel):
+ """
+    Wrapper around embedding models hosted in the Javelin AI Gateway.
+
+ To use, you should have the ``javelin_sdk`` python package installed.
+ For more information, see https://docs.getjavelin.io
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import JavelinAIGatewayEmbeddings
+
+ embeddings = JavelinAIGatewayEmbeddings(
+ gateway_uri="",
+ route=""
+ )
+ """
+
+ client: Any
+ """javelin client."""
+
+ route: str
+ """The route to use for the Javelin AI Gateway API."""
+
+ gateway_uri: Optional[str] = None
+ """The URI for the Javelin AI Gateway API."""
+
+ javelin_api_key: Optional[str] = None
+ """The API key for the Javelin AI Gateway API."""
+
+ def __init__(self, **kwargs: Any):
+ try:
+ from javelin_sdk import (
+ JavelinClient,
+ UnauthorizedError,
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import javelin_sdk python package. "
+ "Please install it with `pip install javelin_sdk`."
+ )
+
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ try:
+ self.client = JavelinClient(
+ base_url=self.gateway_uri, api_key=self.javelin_api_key
+ )
+ except UnauthorizedError as e:
+ raise ValueError("Javelin: Incorrect API Key.") from e
+
+ def _query(self, texts: List[str]) -> List[List[float]]:
+ embeddings = []
+ for txt in _chunk(texts, 20):
+ try:
+ resp = self.client.query_route(self.route, query_body={"input": txt})
+ resp_dict = resp.dict()
+
+ embeddings_chunk = resp_dict.get("llm_response", {}).get("data", [])
+ for item in embeddings_chunk:
+ if "embedding" in item:
+ embeddings.append(item["embedding"])
+ except ValueError as e:
+ print("Failed to query route: " + str(e))
+
+ return embeddings
+
+ async def _aquery(self, texts: List[str]) -> List[List[float]]:
+ embeddings = []
+ for txt in _chunk(texts, 20):
+ try:
+ resp = await self.client.aquery_route(
+ self.route, query_body={"input": txt}
+ )
+ resp_dict = resp.dict()
+
+ embeddings_chunk = resp_dict.get("llm_response", {}).get("data", [])
+ for item in embeddings_chunk:
+ if "embedding" in item:
+ embeddings.append(item["embedding"])
+ except ValueError as e:
+ print("Failed to query route: " + str(e))
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ return self._query(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ return self._query([text])[0]
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ return await self._aquery(texts)
+
+ async def aembed_query(self, text: str) -> List[float]:
+ result = await self._aquery([text])
+ return result[0]
diff --git a/libs/community/langchain_community/embeddings/jina.py b/libs/community/langchain_community/embeddings/jina.py
new file mode 100644
index 00000000000..783615e5943
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/jina.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
+
+
+class JinaEmbeddings(BaseModel, Embeddings):
+ """Jina embedding models."""
+
+ session: Any #: :meta private:
+ model_name: str = "jina-embeddings-v2-base-en"
+ jina_api_key: Optional[str] = None
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that auth token exists in environment."""
+ try:
+ jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
+ except ValueError as original_exc:
+ try:
+ jina_api_key = get_from_dict_or_env(
+ values, "jina_auth_token", "JINA_AUTH_TOKEN"
+ )
+ except ValueError:
+ raise original_exc
+ session = requests.Session()
+ session.headers.update(
+ {
+ "Authorization": f"Bearer {jina_api_key}",
+ "Accept-Encoding": "identity",
+ "Content-type": "application/json",
+ }
+ )
+ values["session"] = session
+ return values
+
+ def _embed(self, texts: List[str]) -> List[List[float]]:
+ # Call Jina AI Embedding API
+ resp = self.session.post( # type: ignore
+ JINA_API_URL, json={"input": texts, "model": self.model_name}
+ ).json()
+ if "data" not in resp:
+ raise RuntimeError(resp["detail"])
+
+ embeddings = resp["data"]
+
+ # Sort resulting embeddings by index
+ sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
+
+ # Return just the embeddings
+ return [result["embedding"] for result in sorted_embeddings]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Jina's embedding endpoint.
+ Args:
+ texts: The list of texts to embed.
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return self._embed(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Jina's embedding endpoint.
+ Args:
+ text: The text to embed.
+ Returns:
+ Embeddings for the text.
+ """
+ return self._embed([text])[0]
diff --git a/libs/community/langchain_community/embeddings/johnsnowlabs.py b/libs/community/langchain_community/embeddings/johnsnowlabs.py
new file mode 100644
index 00000000000..f183efe87b5
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/johnsnowlabs.py
@@ -0,0 +1,92 @@
+import os
+import sys
+from typing import Any, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+class JohnSnowLabsEmbeddings(BaseModel, Embeddings):
+ """JohnSnowLabs embedding models
+
+ To use, you should have the ``johnsnowlabs`` python package installed.
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
+
+ embedding = JohnSnowLabsEmbeddings(model='embed_sentence.bert')
+ output = embedding.embed_query("foo bar")
+ """ # noqa: E501
+
+ model: Any = "embed_sentence.bert"
+
+ def __init__(
+ self,
+ model: Any = "embed_sentence.bert",
+ hardware_target: str = "cpu",
+ **kwargs: Any,
+ ):
+ """Initialize the johnsnowlabs model."""
+ super().__init__(**kwargs)
+ # 1) Check imports
+ try:
+ from johnsnowlabs import nlp
+ from nlu.pipe.pipeline import NLUPipeline
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import johnsnowlabs python package. "
+ "Please install it with `pip install johnsnowlabs`."
+ ) from exc
+
+ # 2) Start a Spark Session
+ try:
+ os.environ["PYSPARK_PYTHON"] = sys.executable
+ os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
+ nlp.start(hardware_target=hardware_target)
+ except Exception as exc:
+ raise Exception("Failure starting Spark Session") from exc
+
+ # 3) Load the model
+ try:
+ if isinstance(model, str):
+ self.model = nlp.load(model)
+ elif isinstance(model, NLUPipeline):
+ self.model = model
+ else:
+ self.model = nlp.to_nlu_pipe(model)
+ except Exception as exc:
+ raise Exception("Failure loading model") from exc
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a JohnSnowLabs transformer model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ df = self.model.predict(texts, output_level="document")
+ emb_col = None
+ for c in df.columns:
+ if "embedding" in c:
+ emb_col = c
+ return [vec.tolist() for vec in df[emb_col].tolist()]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a JohnSnowLabs transformer model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py
new file mode 100644
index 00000000000..1f75171dcb3
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/llamacpp.py
@@ -0,0 +1,126 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+
+
+class LlamaCppEmbeddings(BaseModel, Embeddings):
+ """llama.cpp embedding models.
+
+ To use, you should have the llama-cpp-python library installed, and provide the
+ path to the Llama model as a named parameter to the constructor.
+ Check out: https://github.com/abetlen/llama-cpp-python
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import LlamaCppEmbeddings
+ llama = LlamaCppEmbeddings(model_path="/path/to/model.bin")
+ """
+
+ client: Any #: :meta private:
+ model_path: str
+
+ n_ctx: int = Field(512, alias="n_ctx")
+ """Token context window."""
+
+ n_parts: int = Field(-1, alias="n_parts")
+ """Number of parts to split the model into.
+ If -1, the number of parts is automatically determined."""
+
+ seed: int = Field(-1, alias="seed")
+ """Seed. If -1, a random seed is used."""
+
+ f16_kv: bool = Field(False, alias="f16_kv")
+ """Use half-precision for key/value cache."""
+
+ logits_all: bool = Field(False, alias="logits_all")
+ """Return logits for all tokens, not just the last token."""
+
+ vocab_only: bool = Field(False, alias="vocab_only")
+ """Only load the vocabulary, no weights."""
+
+ use_mlock: bool = Field(False, alias="use_mlock")
+ """Force system to keep model in RAM."""
+
+ n_threads: Optional[int] = Field(None, alias="n_threads")
+ """Number of threads to use. If None, the number
+ of threads is automatically determined."""
+
+ n_batch: Optional[int] = Field(8, alias="n_batch")
+ """Number of tokens to process in parallel.
+ Should be a number between 1 and n_ctx."""
+
+ n_gpu_layers: Optional[int] = Field(None, alias="n_gpu_layers")
+ """Number of layers to be loaded into gpu memory. Default None."""
+
+ verbose: bool = Field(True, alias="verbose")
+ """Print verbose output to stderr."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that llama-cpp-python library is installed."""
+ model_path = values["model_path"]
+ model_param_names = [
+ "n_ctx",
+ "n_parts",
+ "seed",
+ "f16_kv",
+ "logits_all",
+ "vocab_only",
+ "use_mlock",
+ "n_threads",
+ "n_batch",
+ "verbose",
+ ]
+ model_params = {k: values[k] for k in model_param_names}
+ # For backwards compatibility, only include if non-null.
+ if values["n_gpu_layers"] is not None:
+ model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+ try:
+ from llama_cpp import Llama
+
+ values["client"] = Llama(model_path, embedding=True, **model_params)
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import llama-cpp-python library. "
+ "Please install the llama-cpp-python library to "
+ "use this embedding model: pip install llama-cpp-python"
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Could not load Llama model from path: {model_path}. "
+ f"Received error {e}"
+ )
+
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using the Llama model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = [self.client.embed(text) for text in texts]
+ return [list(map(float, e)) for e in embeddings]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using the Llama model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ embedding = self.client.embed(text)
+ return list(map(float, embedding))
diff --git a/libs/community/langchain_community/embeddings/llm_rails.py b/libs/community/langchain_community/embeddings/llm_rails.py
new file mode 100644
index 00000000000..6f233d59d33
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/llm_rails.py
@@ -0,0 +1,71 @@
+""" This file is for LLMRails Embedding """
+import logging
+import os
+from typing import List, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+class LLMRailsEmbeddings(BaseModel, Embeddings):
+ """LLMRails embedding models.
+
+ To use, you should have the environment
+ variable ``LLM_RAILS_API_KEY`` set with your API key or pass it
+ as a named parameter to the constructor.
+
+    Model can be one of ["embedding-english-v1", "embedding-multi-v1"].
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import LLMRailsEmbeddings
+            embeddings = LLMRailsEmbeddings(
+                model="embedding-english-v1", api_key="my-api-key"
+            )
+ """
+
+ model: str = "embedding-english-v1"
+ """Model name to use."""
+
+ api_key: Optional[str] = None
+ """LLMRails API key."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Cohere's embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ api_key = self.api_key or os.environ.get("LLM_RAILS_API_KEY")
+ if api_key is None:
+ logging.warning("Can't find LLMRails credentials in environment.")
+ raise ValueError("LLM_RAILS_API_KEY is not set")
+
+ response = requests.post(
+ "https://api.llmrails.com/v1/embeddings",
+ headers={"X-API-KEY": api_key},
+ json={"input": texts, "model": self.model},
+ timeout=60,
+ )
+ return [item["embedding"] for item in response.json()["data"]]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Cohere's embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/localai.py b/libs/community/langchain_community/embeddings/localai.py
new file mode 100644
index 00000000000..b5a926e8fe2
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/localai.py
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+import logging
+import warnings
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
+from tenacity import (
+ AsyncRetrying,
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
+ import openai
+
+ min_seconds = 4
+ max_seconds = 10
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
+ import openai
+
+ min_seconds = 4
+ max_seconds = 10
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ async_retrying = AsyncRetrying(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+ def wrap(func: Callable) -> Callable:
+ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
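+            # AsyncRetrying is an async iterator: each pass through the loop is
+            # one attempt, and a successful call returns out of the loop.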
+ async for _ in async_retrying:
+ return await func(*args, **kwargs)
+ raise AssertionError("this is unreachable")
+
+ return wrapped_f
+
+ return wrap
+
+
+# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
+def _check_response(response: dict) -> dict:
+ if any(len(d["embedding"]) == 1 for d in response["data"]):
+ import openai
+
+ raise openai.error.APIError("LocalAI API returned an empty embedding")
+ return response
+
+
+def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
+ retry_decorator = _create_retry_decorator(embeddings)
+
+ @retry_decorator
+ def _embed_with_retry(**kwargs: Any) -> Any:
+ response = embeddings.client.create(**kwargs)
+ return _check_response(response)
+
+ return _embed_with_retry(**kwargs)
+
+
+async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
+
+ @_async_retry_decorator(embeddings)
+ async def _async_embed_with_retry(**kwargs: Any) -> Any:
+ response = await embeddings.client.acreate(**kwargs)
+ return _check_response(response)
+
+ return await _async_embed_with_retry(**kwargs)
+
+
+class LocalAIEmbeddings(BaseModel, Embeddings):
+ """LocalAI embedding models.
+
+ Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
+ uses the ``openai`` Python package's ``openai.Embedding`` as its client.
+    Thus, you should have the ``openai`` python package installed, and satisfy
+    the ``OPENAI_API_KEY`` environment variable by setting it to any string.
+ You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
+ service endpoint.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import LocalAIEmbeddings
+ openai = LocalAIEmbeddings(
+ openai_api_key="random-string",
+ openai_api_base="http://localhost:8080"
+ )
+
+ """
+
+ client: Any #: :meta private:
+ model: str = "text-embedding-ada-002"
+ deployment: str = model
+ openai_api_version: Optional[str] = None
+ openai_api_base: Optional[str] = None
+ # to support explicit proxy for LocalAI
+ openai_proxy: Optional[str] = None
+ embedding_ctx_length: int = 8191
+ """The maximum number of tokens to embed at once."""
+ openai_api_key: Optional[str] = None
+ openai_organization: Optional[str] = None
+ allowed_special: Union[Literal["all"], Set[str]] = set()
+ disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
+ chunk_size: int = 1000
+ """Maximum number of texts to embed in each batch"""
+ max_retries: int = 6
+ """Maximum number of retries to make when generating."""
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+ """Timeout in seconds for the LocalAI request."""
+ headers: Any = None
+ show_progress_bar: bool = False
+ """Whether to show a progress bar when embedding."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ warnings.warn(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ values["openai_api_base"] = get_from_dict_or_env(
+ values,
+ "openai_api_base",
+ "OPENAI_API_BASE",
+ default="",
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+
+ default_api_version = ""
+ values["openai_api_version"] = get_from_dict_or_env(
+ values,
+ "openai_api_version",
+ "OPENAI_API_VERSION",
+ default=default_api_version,
+ )
+ values["openai_organization"] = get_from_dict_or_env(
+ values,
+ "openai_organization",
+ "OPENAI_ORGANIZATION",
+ default="",
+ )
+ try:
+ import openai
+
+ values["client"] = openai.Embedding
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ return values
+
+ @property
+ def _invocation_params(self) -> Dict:
+ openai_args = {
+ "model": self.model,
+ "request_timeout": self.request_timeout,
+ "headers": self.headers,
+ "api_key": self.openai_api_key,
+ "organization": self.openai_organization,
+ "api_base": self.openai_api_base,
+ "api_version": self.openai_api_version,
+ **self.model_kwargs,
+ }
+ if self.openai_proxy:
+ import openai
+
+ openai.proxy = {
+ "http": self.openai_proxy,
+ "https": self.openai_proxy,
+ } # type: ignore[assignment] # noqa: E501
+ return openai_args
+
+ def _embedding_func(self, text: str, *, engine: str) -> List[float]:
+ """Call out to LocalAI's embedding endpoint."""
+ # handle large input text
+ if self.model.endswith("001"):
+ # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+ # replace newlines, which can negatively affect performance.
+ text = text.replace("\n", " ")
+ return embed_with_retry(
+ self,
+ input=[text],
+ **self._invocation_params,
+ )["data"][0]["embedding"]
+
+ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
+ """Call out to LocalAI's embedding endpoint."""
+ # handle large input text
+ if self.model.endswith("001"):
+ # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+ # replace newlines, which can negatively affect performance.
+ text = text.replace("\n", " ")
+ return (
+ await async_embed_with_retry(
+ self,
+ input=[text],
+ **self._invocation_params,
+ )
+ )["data"][0]["embedding"]
+
+ def embed_documents(
+ self, texts: List[str], chunk_size: Optional[int] = 0
+ ) -> List[List[float]]:
+ """Call out to LocalAI's embedding endpoint for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+ chunk_size: The chunk size of embeddings. If None, will use the chunk size
+ specified by the class.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ # call _embedding_func for each text
+ return [self._embedding_func(text, engine=self.deployment) for text in texts]
+
+ async def aembed_documents(
+ self, texts: List[str], chunk_size: Optional[int] = 0
+ ) -> List[List[float]]:
+ """Call out to LocalAI's embedding endpoint async for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+ chunk_size: The chunk size of embeddings. If None, will use the chunk size
+ specified by the class.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = []
+ for text in texts:
+ response = await self._aembedding_func(text, engine=self.deployment)
+ embeddings.append(response)
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to LocalAI's embedding endpoint for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ embedding = self._embedding_func(text, engine=self.deployment)
+ return embedding
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Call out to LocalAI's embedding endpoint async for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ embedding = await self._aembedding_func(text, engine=self.deployment)
+ return embedding
diff --git a/libs/community/langchain_community/embeddings/minimax.py b/libs/community/langchain_community/embeddings/minimax.py
new file mode 100644
index 00000000000..06482ae29d9
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/minimax.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+ before_sleep_log,
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+ """Returns a tenacity retry decorator."""
+
+ multiplier = 1
+ min_seconds = 1
+ max_seconds = 4
+ max_retries = 6
+
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(max_retries),
+ wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator()
+
+ @retry_decorator
+ def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return embeddings.embed(*args, **kwargs)
+
+ return _embed_with_retry(*args, **kwargs)
+
+
+class MiniMaxEmbeddings(BaseModel, Embeddings):
+ """MiniMax's embedding service.
+
+ To use, you should have the environment variable ``MINIMAX_GROUP_ID`` and
+ ``MINIMAX_API_KEY`` set with your API token, or pass it as a named parameter to
+ the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import MiniMaxEmbeddings
+ embeddings = MiniMaxEmbeddings()
+
+ query_text = "This is a test query."
+ query_result = embeddings.embed_query(query_text)
+
+ document_text = "This is a test document."
+ document_result = embeddings.embed_documents([document_text])
+
+ """
+
+ endpoint_url: str = "https://api.minimax.chat/v1/embeddings"
+ """Endpoint URL to use."""
+ model: str = "embo-01"
+ """Embeddings model name to use."""
+ embed_type_db: str = "db"
+ """For embed_documents"""
+ embed_type_query: str = "query"
+ """For embed_query"""
+
+ minimax_group_id: Optional[str] = None
+ """Group ID for MiniMax API."""
+ minimax_api_key: Optional[str] = None
+ """API Key for MiniMax API."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that group id and api key exists in environment."""
+ minimax_group_id = get_from_dict_or_env(
+ values, "minimax_group_id", "MINIMAX_GROUP_ID"
+ )
+ minimax_api_key = get_from_dict_or_env(
+ values, "minimax_api_key", "MINIMAX_API_KEY"
+ )
+ values["minimax_group_id"] = minimax_group_id
+ values["minimax_api_key"] = minimax_api_key
+ return values
+
+ def embed(
+ self,
+ texts: List[str],
+ embed_type: str,
+ ) -> List[List[float]]:
+ payload = {
+ "model": self.model,
+ "type": embed_type,
+ "texts": texts,
+ }
+
+ # HTTP headers for authorization
+ headers = {
+ "Authorization": f"Bearer {self.minimax_api_key}",
+ "Content-Type": "application/json",
+ }
+
+ params = {
+ "GroupId": self.minimax_group_id,
+ }
+
+ # send request
+ response = requests.post(
+ self.endpoint_url, params=params, headers=headers, json=payload
+ )
+ parsed_response = response.json()
+
+ # check for errors
+ if parsed_response["base_resp"]["status_code"] != 0:
+ raise ValueError(
+ f"MiniMax API returned an error: {parsed_response['base_resp']}"
+ )
+
+ embeddings = parsed_response["vectors"]
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed documents using a MiniMax embedding endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using a MiniMax embedding endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ embeddings = embed_with_retry(
+ self, texts=[text], embed_type=self.embed_type_query
+ )
+ return embeddings[0]
diff --git a/libs/community/langchain_community/embeddings/mlflow.py b/libs/community/langchain_community/embeddings/mlflow.py
new file mode 100644
index 00000000000..0ae46bcffdb
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/mlflow.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, List
+from urllib.parse import urlparse
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+ for i in range(0, len(texts), size):
+ yield texts[i : i + size]
+
+
+class MlflowEmbeddings(Embeddings, BaseModel):
+ """Wrapper around embeddings LLMs in MLflow.
+
+ To use, you should have the `mlflow[genai]` python package installed.
+ For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import MlflowEmbeddings
+
+ embeddings = MlflowEmbeddings(
+ target_uri="http://localhost:5000",
+ endpoint="embeddings",
+ )
+ """
+
+ endpoint: str
+ """The endpoint to use."""
+ target_uri: str
+ """The target URI to use."""
+ _client: Any = PrivateAttr()
+
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
+ self._validate_uri()
+ try:
+ from mlflow.deployments import get_deploy_client
+
+ self._client = get_deploy_client(self.target_uri)
+ except ImportError as e:
+ raise ImportError(
+ "Failed to create the client. "
+ f"Please run `pip install mlflow{self._mlflow_extras}` to install "
+ "required dependencies."
+ ) from e
+
+ @property
+ def _mlflow_extras(self) -> str:
+ return "[genai]"
+
+ def _validate_uri(self) -> None:
+ if self.target_uri == "databricks":
+ return
+ allowed = ["http", "https", "databricks"]
+ if urlparse(self.target_uri).scheme not in allowed:
+ raise ValueError(
+ f"Invalid target URI: {self.target_uri}. "
+ f"The scheme must be one of {allowed}."
+ )
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ embeddings: List[List[float]] = []
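+        # Send the texts in chunks of 20 to keep each request payload small.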
+ for txt in _chunk(texts, 20):
+ resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt})
+ embeddings.extend(r["embedding"] for r in resp["data"])
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/embeddings/mlflow_gateway.py b/libs/community/langchain_community/embeddings/mlflow_gateway.py
new file mode 100644
index 00000000000..ad54761cbe2
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/mlflow_gateway.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+ for i in range(0, len(texts), size):
+ yield texts[i : i + size]
+
+
+class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
+ """
+    Wrapper around embedding models hosted in the MLflow AI Gateway.
+
+ To use, you should have the ``mlflow[gateway]`` python package installed.
+ For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import MlflowAIGatewayEmbeddings
+
+ embeddings = MlflowAIGatewayEmbeddings(
+ gateway_uri="",
+ route=""
+ )
+ """
+
+ route: str
+ """The route to use for the MLflow AI Gateway API."""
+ gateway_uri: Optional[str] = None
+ """The URI for the MLflow AI Gateway API."""
+
+ def __init__(self, **kwargs: Any):
+ warnings.warn(
+ "`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
+ "`DatabricksEmbeddings` instead.",
+ DeprecationWarning,
+ )
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ mlflow.gateway.set_gateway_uri(self.gateway_uri)
+
+ def _query(self, texts: List[str]) -> List[List[float]]:
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ embeddings = []
+ for txt in _chunk(texts, 20):
+ resp = mlflow.gateway.query(self.route, data={"text": txt})
+ embeddings.append(resp["embeddings"])
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ return self._query(texts)
+
+ def embed_query(self, text: str) -> List[float]:
+ return self._query([text])[0]
diff --git a/libs/community/langchain_community/embeddings/modelscope_hub.py b/libs/community/langchain_community/embeddings/modelscope_hub.py
new file mode 100644
index 00000000000..dc2866dedaf
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/modelscope_hub.py
@@ -0,0 +1,73 @@
+from typing import Any, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+class ModelScopeEmbeddings(BaseModel, Embeddings):
+ """ModelScopeHub embedding models.
+
+ To use, you should have the ``modelscope`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import ModelScopeEmbeddings
+ model_id = "damo/nlp_corom_sentence-embedding_english-base"
+ embed = ModelScopeEmbeddings(model_id=model_id, model_revision="v1.0.0")
+ """
+
+ embed: Any
+ model_id: str = "damo/nlp_corom_sentence-embedding_english-base"
+ """Model name to use."""
+ model_revision: Optional[str] = None
+
+ def __init__(self, **kwargs: Any):
+        """Initialize the ModelScope embedding pipeline."""
+ super().__init__(**kwargs)
+ try:
+ from modelscope.pipelines import pipeline
+ from modelscope.utils.constant import Tasks
+ except ImportError as e:
+ raise ImportError(
+                "Could not import the modelscope python package. "
+                "Please install it with `pip install modelscope`."
+ ) from e
+ self.embed = pipeline(
+ Tasks.sentence_embedding,
+ model=self.model_id,
+ model_revision=self.model_revision,
+ )
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a modelscope embedding model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
+ inputs = {"source_sentence": texts}
+ embeddings = self.embed(input=inputs)["text_embedding"]
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a modelscope embedding model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ text = text.replace("\n", " ")
+ inputs = {"source_sentence": [text]}
+ embedding = self.embed(input=inputs)["text_embedding"][0]
+ return embedding.tolist()
diff --git a/libs/community/langchain_community/embeddings/mosaicml.py b/libs/community/langchain_community/embeddings/mosaicml.py
new file mode 100644
index 00000000000..bd8d97bb7e8
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/mosaicml.py
@@ -0,0 +1,147 @@
+from typing import Any, Dict, List, Mapping, Optional, Tuple
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
+ """MosaicML embedding service.
+
+ To use, you should have the
+ environment variable ``MOSAICML_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+            from langchain_community.embeddings import MosaicMLInstructorEmbeddings
+ endpoint_url = (
+ "https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
+ )
+ mosaic_llm = MosaicMLInstructorEmbeddings(
+ endpoint_url=endpoint_url,
+ mosaicml_api_token="my-api-key"
+ )
+ """
+
+ endpoint_url: str = (
+ "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+ )
+ """Endpoint URL to use."""
+ embed_instruction: str = "Represent the document for retrieval: "
+ """Instruction used to embed documents."""
+ query_instruction: str = (
+ "Represent the question for retrieving supporting documents: "
+ )
+ """Instruction used to embed the query."""
+ retry_sleep: float = 1.0
+    """How long to sleep if a rate limit is encountered."""
+
+ mosaicml_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ mosaicml_api_token = get_from_dict_or_env(
+ values, "mosaicml_api_token", "MOSAICML_API_TOKEN"
+ )
+ values["mosaicml_api_token"] = mosaicml_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"endpoint_url": self.endpoint_url}
+
+ def _embed(
+ self, input: List[Tuple[str, str]], is_retry: bool = False
+ ) -> List[List[float]]:
+ payload = {"inputs": input}
+
+ # HTTP headers for authorization
+ headers = {
+ "Authorization": f"{self.mosaicml_api_token}",
+ "Content-Type": "application/json",
+ }
+
+ # send request
+ try:
+ response = requests.post(self.endpoint_url, headers=headers, json=payload)
+ except requests.exceptions.RequestException as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
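+        # Retry once after a short sleep if the endpoint is rate limited (429).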
+ try:
+ if response.status_code == 429:
+ if not is_retry:
+ import time
+
+ time.sleep(self.retry_sleep)
+
+ return self._embed(input, is_retry=True)
+
+ raise ValueError(
+ f"Error raised by inference API: rate limit exceeded.\nResponse: "
+ f"{response.text}"
+ )
+
+ parsed_response = response.json()
+
+ # The inference API has changed a couple of times, so we add some handling
+ # to be robust to multiple response formats.
+ if isinstance(parsed_response, dict):
+ output_keys = ["data", "output", "outputs"]
+ for key in output_keys:
+ if key in parsed_response:
+ output_item = parsed_response[key]
+ break
+ else:
+ raise ValueError(
+                        f"No 'data', 'output', or 'outputs' key in response: "
+                        f"{parsed_response}"
+ )
+
+ if isinstance(output_item, list) and isinstance(output_item[0], list):
+ embeddings = output_item
+ else:
+ embeddings = [output_item]
+ else:
+ raise ValueError(f"Unexpected response type: {parsed_response}")
+
+ except requests.exceptions.JSONDecodeError as e:
+ raise ValueError(
+ f"Error raised by inference API: {e}.\nResponse: {response.text}"
+ )
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed documents using a MosaicML deployed instructor embedding model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ instruction_pairs = [(self.embed_instruction, text) for text in texts]
+ embeddings = self._embed(instruction_pairs)
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using a MosaicML deployed instructor embedding model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ instruction_pair = (self.query_instruction, text)
+ embedding = self._embed([instruction_pair])[0]
+ return embedding
diff --git a/libs/community/langchain_community/embeddings/nlpcloud.py b/libs/community/langchain_community/embeddings/nlpcloud.py
new file mode 100644
index 00000000000..748d63b9005
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/nlpcloud.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class NLPCloudEmbeddings(BaseModel, Embeddings):
+ """NLP Cloud embedding models.
+
+    To use, you should have the ``nlpcloud`` python package installed
+    and the environment variable ``NLPCLOUD_API_KEY`` set with your API token.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import NLPCloudEmbeddings
+
+ embeddings = NLPCloudEmbeddings()
+ """
+
+ model_name: str # Define model_name as a class attribute
+ gpu: bool # Define gpu as a class attribute
+ client: Any #: :meta private:
+
+ def __init__(
+ self,
+ model_name: str = "paraphrase-multilingual-mpnet-base-v2",
+ gpu: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(model_name=model_name, gpu=gpu, **kwargs)
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ nlpcloud_api_key = get_from_dict_or_env(
+ values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
+ )
+ try:
+ import nlpcloud
+
+ values["client"] = nlpcloud.Client(
+ values["model_name"], nlpcloud_api_key, gpu=values["gpu"], lang="en"
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import nlpcloud python package. "
+ "Please install it with `pip install nlpcloud`."
+ )
+ return values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using NLP Cloud.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ return self.client.embeddings(texts)["embeddings"]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query using NLP Cloud.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self.client.embeddings([text])["embeddings"][0]
diff --git a/libs/community/langchain_community/embeddings/octoai_embeddings.py b/libs/community/langchain_community/embeddings/octoai_embeddings.py
new file mode 100644
index 00000000000..bcdd412e051
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/octoai_embeddings.py
@@ -0,0 +1,90 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
+DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "
+
+
+class OctoAIEmbeddings(BaseModel, Embeddings):
+ """OctoAI Compute Service embedding models.
+
+ The environment variable ``OCTOAI_API_TOKEN`` should be set
+ with your API token, or it can be passed
+ as a named parameter to the constructor.
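+
+    Example (illustrative sketch; the token and endpoint URL below are
+    placeholders, not real values):
+        .. code-block:: python
+
+            from langchain_community.embeddings import OctoAIEmbeddings
+
+            embeddings = OctoAIEmbeddings(
+                octoai_api_token="my-api-token",
+                endpoint_url="<your-octoai-endpoint-url>",
+            )
+            doc_vectors = embeddings.embed_documents(["hello", "goodbye"])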
+ """
+
+ endpoint_url: Optional[str] = Field(None, description="Endpoint URL to use.")
+ model_kwargs: Optional[dict] = Field(
+ None, description="Keyword arguments to pass to the model."
+ )
+ octoai_api_token: Optional[str] = Field(None, description="OCTOAI API Token")
+ embed_instruction: str = Field(
+ DEFAULT_EMBED_INSTRUCTION,
+ description="Instruction to use for embedding documents.",
+ )
+ query_instruction: str = Field(
+ DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query."
+ )
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Ensure that the API key and python package exist in environment."""
+ values["octoai_api_token"] = get_from_dict_or_env(
+ values, "octoai_api_token", "OCTOAI_API_TOKEN"
+ )
+ values["endpoint_url"] = get_from_dict_or_env(
+ values, "endpoint_url", "ENDPOINT_URL"
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Return the identifying parameters."""
+ return {
+ "endpoint_url": self.endpoint_url,
+ "model_kwargs": self.model_kwargs or {},
+ }
+
+ def _compute_embeddings(
+ self, texts: List[str], instruction: str
+ ) -> List[List[float]]:
+ """Compute embeddings using an OctoAI instruct model."""
+ from octoai import client
+
+ embeddings = []
+ octoai_client = client.Client(token=self.octoai_api_token)
+
+ for text in texts:
+ parameter_payload = {
+                "sentence": str([text]),
+                "instruction": str([instruction]),
+ "parameters": self.model_kwargs or {},
+ }
+
+ try:
+ resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
+ embedding = resp_json["embeddings"]
+ except Exception as e:
+ raise ValueError(f"Error raised by the inference endpoint: {e}") from e
+
+ embeddings.append(embedding)
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute document embeddings using an OctoAI instruct model."""
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
+ return self._compute_embeddings(texts, self.embed_instruction)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embedding using an OctoAI instruct model."""
+ text = text.replace("\n", " ")
+ return self._compute_embeddings([text], self.query_instruction)[0]
diff --git a/libs/community/langchain_community/embeddings/ollama.py b/libs/community/langchain_community/embeddings/ollama.py
new file mode 100644
index 00000000000..9b9830fb0a0
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/ollama.py
@@ -0,0 +1,218 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddings(BaseModel, Embeddings):
+    """Ollama runs large language models locally.
+
+ To use, follow the instructions at https://ollama.ai/.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import OllamaEmbeddings
+ ollama_emb = OllamaEmbeddings(
+ model="llama:7b",
+ )
+ r1 = ollama_emb.embed_documents(
+ [
+ "Alpha is the first letter of Greek alphabet",
+ "Beta is the second letter of Greek alphabet",
+ ]
+ )
+ r2 = ollama_emb.embed_query(
+ "What is the second letter of Greek alphabet"
+ )
+
+ """
+
+ base_url: str = "http://localhost:11434"
+ """Base url the model is hosted under."""
+ model: str = "llama2"
+ """Model name to use."""
+
+ embed_instruction: str = "passage: "
+ """Instruction used to embed documents."""
+ query_instruction: str = "query: "
+ """Instruction used to embed the query."""
+
+ mirostat: Optional[int] = None
+ """Enable Mirostat sampling for controlling perplexity.
+ (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
+
+ mirostat_eta: Optional[float] = None
+ """Influences how quickly the algorithm responds to feedback
+ from the generated text. A lower learning rate will result in
+ slower adjustments, while a higher learning rate will make
+ the algorithm more responsive. (Default: 0.1)"""
+
+ mirostat_tau: Optional[float] = None
+ """Controls the balance between coherence and diversity
+ of the output. A lower value will result in more focused and
+ coherent text. (Default: 5.0)"""
+
+ num_ctx: Optional[int] = None
+ """Sets the size of the context window used to generate the
+ next token. (Default: 2048) """
+
+ num_gpu: Optional[int] = None
+ """The number of GPUs to use. On macOS it defaults to 1 to
+ enable metal support, 0 to disable."""
+
+ num_thread: Optional[int] = None
+ """Sets the number of threads to use during computation.
+ By default, Ollama will detect this for optimal performance.
+ It is recommended to set this value to the number of physical
+ CPU cores your system has (as opposed to the logical number of cores)."""
+
+ repeat_last_n: Optional[int] = None
+    """Sets how far back the model looks to prevent
+    repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
+
+ repeat_penalty: Optional[float] = None
+ """Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
+ will penalize repetitions more strongly, while a lower value (e.g., 0.9)
+ will be more lenient. (Default: 1.1)"""
+
+ temperature: Optional[float] = None
+ """The temperature of the model. Increasing the temperature will
+ make the model answer more creatively. (Default: 0.8)"""
+
+ stop: Optional[List[str]] = None
+ """Sets the stop tokens to use."""
+
+ tfs_z: Optional[float] = None
+ """Tail free sampling is used to reduce the impact of less probable
+ tokens from the output. A higher value (e.g., 2.0) will reduce the
+ impact more, while a value of 1.0 disables this setting. (default: 1)"""
+
+ top_k: Optional[int] = None
+ """Reduces the probability of generating nonsense. A higher value (e.g. 100)
+ will give more diverse answers, while a lower value (e.g. 10)
+ will be more conservative. (Default: 40)"""
+
+    top_p: Optional[float] = None
+ """Works together with top-k. A higher value (e.g., 0.95) will lead
+ to more diverse text, while a lower value (e.g., 0.5) will
+ generate more focused and conservative text. (Default: 0.9)"""
+
+ show_progress: bool = False
+ """Whether to show a tqdm progress bar. Must have `tqdm` installed."""
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Ollama."""
+ return {
+ "model": self.model,
+ "options": {
+ "mirostat": self.mirostat,
+ "mirostat_eta": self.mirostat_eta,
+ "mirostat_tau": self.mirostat_tau,
+ "num_ctx": self.num_ctx,
+ "num_gpu": self.num_gpu,
+ "num_thread": self.num_thread,
+ "repeat_last_n": self.repeat_last_n,
+ "repeat_penalty": self.repeat_penalty,
+ "temperature": self.temperature,
+ "stop": self.stop,
+ "tfs_z": self.tfs_z,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ },
+ }
+
+ model_kwargs: Optional[dict] = None
+ """Other model keyword args"""
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _process_emb_response(self, input: str) -> List[float]:
+        """Request an embedding for a single prompt and process the response.
+
+        Args:
+            input: The text to embed.
+
+        Returns:
+            The embedding for the text as a list of floats.
+        """
+ headers = {
+ "Content-Type": "application/json",
+ }
+
+ try:
+ res = requests.post(
+ f"{self.base_url}/api/embeddings",
+ headers=headers,
+ json={"model": self.model, "prompt": input, **self._default_params},
+ )
+ except requests.exceptions.RequestException as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ if res.status_code != 200:
+ raise ValueError(
+ "Error raised by inference API HTTP code: %s, %s"
+ % (res.status_code, res.text)
+ )
+ try:
+ t = res.json()
+ return t["embedding"]
+ except requests.exceptions.JSONDecodeError as e:
+ raise ValueError(
+ f"Error raised by inference API: {e}.\nResponse: {res.text}"
+ )
+
+ def _embed(self, input: List[str]) -> List[List[float]]:
+ if self.show_progress:
+ try:
+ from tqdm import tqdm
+
+ iter_ = tqdm(input, desc="OllamaEmbeddings")
+ except ImportError:
+ logger.warning(
+ "Unable to show progress bar because tqdm could not be imported. "
+ "Please install with `pip install tqdm`."
+ )
+ iter_ = input
+ else:
+ iter_ = input
+ return [self._process_emb_response(prompt) for prompt in iter_]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed documents using an Ollama deployed embedding model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts]
+ embeddings = self._embed(instruction_pairs)
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+        """Embed a query using an Ollama deployed embedding model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ instruction_pair = f"{self.query_instruction}{text}"
+ embedding = self._embed([instruction_pair])[0]
+ return embedding
diff --git a/libs/community/langchain_community/embeddings/openai.py b/libs/community/langchain_community/embeddings/openai.py
new file mode 100644
index 00000000000..a44b20527e1
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/openai.py
@@ -0,0 +1,708 @@
+from __future__ import annotations
+
+import logging
+import os
+import warnings
+from importlib.metadata import version
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+ cast,
+)
+
+import numpy as np
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
+from packaging.version import Version, parse
+from tenacity import (
+ AsyncRetrying,
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]:
+ import openai
+
+ # Wait 2^x * 1 second between each retry starting with
+ # retry_min_seconds seconds, then up to retry_max_seconds seconds,
+ # then retry_max_seconds seconds afterwards
+ # retry_min_seconds and retry_max_seconds are optional arguments of
+ # OpenAIEmbeddings
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(
+ multiplier=1,
+ min=embeddings.retry_min_seconds,
+ max=embeddings.retry_max_seconds,
+ ),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
+ import openai
+
+ # Wait 2^x * 1 second between each retry starting with
+ # retry_min_seconds seconds, then up to retry_max_seconds seconds,
+ # then retry_max_seconds seconds afterwards
+ # retry_min_seconds and retry_max_seconds are optional arguments of
+ # OpenAIEmbeddings
+ async_retrying = AsyncRetrying(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(
+ multiplier=1,
+ min=embeddings.retry_min_seconds,
+ max=embeddings.retry_max_seconds,
+ ),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+ def wrap(func: Callable) -> Callable:
+ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
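+            # Each pass through the AsyncRetrying iterator is one attempt; a
+            # successful call returns immediately and ends the retry loop.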
+ async for _ in async_retrying:
+ return await func(*args, **kwargs)
+ raise AssertionError("this is unreachable")
+
+ return wrapped_f
+
+ return wrap
+
+
+# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
+def _check_response(response: dict, skip_empty: bool = False) -> dict:
+ if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
+ import openai
+
+ raise openai.error.APIError("OpenAI API returned an empty embedding")
+ return response
+
+
+def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
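+    # With openai>=1.0 the client is constructed with max_retries, so the
+    # tenacity wrapper below is only needed for the legacy client.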
+ if _is_openai_v1():
+ return embeddings.client.create(**kwargs)
+ retry_decorator = _create_retry_decorator(embeddings)
+
+ @retry_decorator
+ def _embed_with_retry(**kwargs: Any) -> Any:
+ response = embeddings.client.create(**kwargs)
+ return _check_response(response, skip_empty=embeddings.skip_empty)
+
+ return _embed_with_retry(**kwargs)
+
+
+async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
+
+ if _is_openai_v1():
+ return await embeddings.async_client.create(**kwargs)
+
+ @_async_retry_decorator(embeddings)
+ async def _async_embed_with_retry(**kwargs: Any) -> Any:
+ response = await embeddings.client.acreate(**kwargs)
+ return _check_response(response, skip_empty=embeddings.skip_empty)
+
+ return await _async_embed_with_retry(**kwargs)
+
+
+def _is_openai_v1() -> bool:
+ _version = parse(version("openai"))
+ return _version >= Version("1.0.0")
+
+
+class OpenAIEmbeddings(BaseModel, Embeddings):
+ """OpenAI embedding models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``OPENAI_API_KEY`` set with your API key or pass it
+ as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ openai = OpenAIEmbeddings(openai_api_key="my-api-key")
+
+ In order to use the library with Microsoft Azure endpoints, you need to set
+ the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and OPENAI_API_VERSION.
+ The OPENAI_API_TYPE must be set to 'azure' and the others correspond to
+ the properties of your endpoint.
+ In addition, the deployment name must be passed as the model parameter.
+
+ Example:
+ .. code-block:: python
+
+ import os
+
+ os.environ["OPENAI_API_TYPE"] = "azure"
+            os.environ["OPENAI_API_BASE"] = "https://..."
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ warnings.warn(
+                    f"""WARNING! {field_name} is not a default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_api_type"] = get_from_dict_or_env(
+ values,
+ "openai_api_type",
+ "OPENAI_API_TYPE",
+ default="",
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
+ default_api_version = "2023-05-15"
+ # Azure OpenAI embedding models allow a maximum of 16 texts
+ # at a time in each batch
+ # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
+ values["chunk_size"] = min(values["chunk_size"], 16)
+ else:
+ default_api_version = ""
+ values["openai_api_version"] = get_from_dict_or_env(
+ values,
+ "openai_api_version",
+ "OPENAI_API_VERSION",
+ default=default_api_version,
+ )
+ # Check OPENAI_ORGANIZATION for backwards compatibility.
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ try:
+ import openai
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ else:
+ if _is_openai_v1():
+ if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
+ warnings.warn(
+ "If you have openai>=1.0.0 installed and are using Azure, "
+ "please use the `AzureOpenAIEmbeddings` class."
+ )
+ client_params = {
+ "api_key": values["openai_api_key"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+ if not values.get("client"):
+ values["client"] = openai.OpenAI(**client_params).embeddings
+ if not values.get("async_client"):
+ values["async_client"] = openai.AsyncOpenAI(
+ **client_params
+ ).embeddings
+ elif not values.get("client"):
+ values["client"] = openai.Embedding
+ else:
+ pass
+ return values
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ if _is_openai_v1():
+ openai_args: Dict = {"model": self.model, **self.model_kwargs}
+ else:
+ openai_args = {
+ "model": self.model,
+ "request_timeout": self.request_timeout,
+ "headers": self.headers,
+ "api_key": self.openai_api_key,
+ "organization": self.openai_organization,
+ "api_base": self.openai_api_base,
+ "api_type": self.openai_api_type,
+ "api_version": self.openai_api_version,
+ **self.model_kwargs,
+ }
+ if self.openai_api_type in ("azure", "azure_ad", "azuread"):
+ openai_args["engine"] = self.deployment
+ # TODO: Look into proxy with openai v1.
+ if self.openai_proxy:
+ try:
+ import openai
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+
+ openai.proxy = {
+ "http": self.openai_proxy,
+ "https": self.openai_proxy,
+ } # type: ignore[assignment] # noqa: E501
+ return openai_args
+
+ # please refer to
+ # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
+ def _get_len_safe_embeddings(
+ self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+ ) -> List[List[float]]:
+ """
+ Generate length-safe embeddings for a list of texts.
+
+ This method handles tokenization and embedding generation, respecting the
+ set embedding context length and chunk size. It supports both tiktoken
+ and HuggingFace tokenizer based on the tiktoken_enabled flag.
+
+ Args:
+ texts (List[str]): A list of texts to embed.
+ engine (str): The engine or model to use for embeddings.
+ chunk_size (Optional[int]): The size of chunks for processing embeddings.
+
+ Returns:
+ List[List[float]]: A list of embeddings for each input text.
+ """
+
+ tokens = []
+ indices = []
+ model_name = self.tiktoken_model_name or self.model
+ _chunk_size = chunk_size or self.chunk_size
+
+ # If tiktoken flag set to False
+ if not self.tiktoken_enabled:
+ try:
+ from transformers import AutoTokenizer
+ except ImportError:
+ raise ValueError(
+ "Could not import transformers python package. "
+ "This is needed in order to for OpenAIEmbeddings without "
+                    "This is needed in order to use OpenAIEmbeddings without "
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ pretrained_model_name_or_path=model_name
+ )
+ for i, text in enumerate(texts):
+ # Tokenize the text using HuggingFace transformers
+ tokenized = tokenizer.encode(text, add_special_tokens=False)
+
+ # Split tokens into chunks respecting the embedding_ctx_length
+ for j in range(0, len(tokenized), self.embedding_ctx_length):
+ token_chunk = tokenized[j : j + self.embedding_ctx_length]
+
+ # Convert token IDs back to a string
+ chunk_text = tokenizer.decode(token_chunk)
+ tokens.append(chunk_text)
+ indices.append(i)
+ else:
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+                    "This is needed in order to use OpenAIEmbeddings. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ try:
+ encoding = tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken.get_encoding(model)
+ for i, text in enumerate(texts):
+ if self.model.endswith("001"):
+ # See: https://github.com/openai/openai-python/
+ # issues/418#issuecomment-1525939500
+ # replace newlines, which can negatively affect performance.
+ text = text.replace("\n", " ")
+
+ token = encoding.encode(
+ text=text,
+ allowed_special=self.allowed_special,
+ disallowed_special=self.disallowed_special,
+ )
+
+ # Split tokens into chunks respecting the embedding_ctx_length
+ for j in range(0, len(token), self.embedding_ctx_length):
+ tokens.append(token[j : j + self.embedding_ctx_length])
+ indices.append(i)
+
+ if self.show_progress_bar:
+ try:
+ from tqdm.auto import tqdm
+
+ _iter = tqdm(range(0, len(tokens), _chunk_size))
+ except ImportError:
+ _iter = range(0, len(tokens), _chunk_size)
+ else:
+ _iter = range(0, len(tokens), _chunk_size)
+
+ batched_embeddings: List[List[float]] = []
+ for i in _iter:
+ response = embed_with_retry(
+ self,
+ input=tokens[i : i + _chunk_size],
+ **self._invocation_params,
+ )
+ if not isinstance(response, dict):
+ response = response.dict()
+ batched_embeddings.extend(r["embedding"] for r in response["data"])
+
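+        # Group the chunk embeddings back by original text and combine them with
+        # a token-count-weighted average, then L2-normalize the result.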
+ results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+ for i in range(len(indices)):
+ if self.skip_empty and len(batched_embeddings[i]) == 1:
+ continue
+ results[indices[i]].append(batched_embeddings[i])
+ num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
+ for i in range(len(texts)):
+ _result = results[i]
+ if len(_result) == 0:
+ average_embedded = embed_with_retry(
+ self,
+ input="",
+ **self._invocation_params,
+ )
+ if not isinstance(average_embedded, dict):
+ average_embedded = average_embedded.dict()
+ average = average_embedded["data"][0]["embedding"]
+ else:
+ average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+ embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+ return embeddings
+
+ # please refer to
+ # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
+ async def _aget_len_safe_embeddings(
+ self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+ ) -> List[List[float]]:
+ """
+ Asynchronously generate length-safe embeddings for a list of texts.
+
+ This method handles tokenization and asynchronous embedding generation,
+ respecting the set embedding context length and chunk size. It supports both
+ `tiktoken` and HuggingFace `tokenizer` based on the tiktoken_enabled flag.
+
+ Args:
+ texts (List[str]): A list of texts to embed.
+ engine (str): The engine or model to use for embeddings.
+ chunk_size (Optional[int]): The size of chunks for processing embeddings.
+
+ Returns:
+ List[List[float]]: A list of embeddings for each input text.
+ """
+
+ tokens = []
+ indices = []
+ model_name = self.tiktoken_model_name or self.model
+ _chunk_size = chunk_size or self.chunk_size
+
+ # If tiktoken flag set to False
+ if not self.tiktoken_enabled:
+ try:
+ from transformers import AutoTokenizer
+ except ImportError:
+ raise ValueError(
+ "Could not import transformers python package. "
+                    "This is needed in order to use OpenAIEmbeddings without "
+                    "`tiktoken`. Please install it with `pip install transformers`."
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ pretrained_model_name_or_path=model_name
+ )
+ for i, text in enumerate(texts):
+ # Tokenize the text using HuggingFace transformers
+ tokenized = tokenizer.encode(text, add_special_tokens=False)
+
+ # Split tokens into chunks respecting the embedding_ctx_length
+ for j in range(0, len(tokenized), self.embedding_ctx_length):
+ token_chunk = tokenized[j : j + self.embedding_ctx_length]
+
+ # Convert token IDs back to a string
+ chunk_text = tokenizer.decode(token_chunk)
+ tokens.append(chunk_text)
+ indices.append(i)
+ else:
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+                    "This is needed in order to use OpenAIEmbeddings. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ try:
+ encoding = tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken.get_encoding(model)
+ for i, text in enumerate(texts):
+ if self.model.endswith("001"):
+ # See: https://github.com/openai/openai-python/
+ # issues/418#issuecomment-1525939500
+ # replace newlines, which can negatively affect performance.
+ text = text.replace("\n", " ")
+
+ token = encoding.encode(
+ text=text,
+ allowed_special=self.allowed_special,
+ disallowed_special=self.disallowed_special,
+ )
+
+ # Split tokens into chunks respecting the embedding_ctx_length
+ for j in range(0, len(token), self.embedding_ctx_length):
+ tokens.append(token[j : j + self.embedding_ctx_length])
+ indices.append(i)
+
+ batched_embeddings: List[List[float]] = []
+ _chunk_size = chunk_size or self.chunk_size
+ for i in range(0, len(tokens), _chunk_size):
+ response = await async_embed_with_retry(
+ self,
+ input=tokens[i : i + _chunk_size],
+ **self._invocation_params,
+ )
+
+ if not isinstance(response, dict):
+ response = response.dict()
+ batched_embeddings.extend(r["embedding"] for r in response["data"])
+
+ results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+ for i in range(len(indices)):
+ results[indices[i]].append(batched_embeddings[i])
+ num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
+ for i in range(len(texts)):
+ _result = results[i]
+ if len(_result) == 0:
+ average_embedded = await async_embed_with_retry(
+ self,
+ input="",
+ **self._invocation_params,
+ )
+ if not isinstance(average_embedded, dict):
+ average_embedded = average_embedded.dict()
+ average = average_embedded["data"][0]["embedding"]
+ else:
+ average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+ embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+ return embeddings
+
+ def embed_documents(
+ self, texts: List[str], chunk_size: Optional[int] = 0
+ ) -> List[List[float]]:
+ """Call out to OpenAI's embedding endpoint for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+ chunk_size: The chunk size of embeddings. If None, will use the chunk size
+ specified by the class.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ # NOTE: to keep things simple, we assume the list may contain texts longer
+ # than the maximum context and use length-safe embedding function.
+ engine = cast(str, self.deployment)
+ return self._get_len_safe_embeddings(texts, engine=engine)
+
+ async def aembed_documents(
+ self, texts: List[str], chunk_size: Optional[int] = 0
+ ) -> List[List[float]]:
+ """Call out to OpenAI's embedding endpoint async for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+ chunk_size: The chunk size of embeddings. If None, will use the chunk size
+ specified by the class.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ # NOTE: to keep things simple, we assume the list may contain texts longer
+ # than the maximum context and use length-safe embedding function.
+ engine = cast(str, self.deployment)
+ return await self._aget_len_safe_embeddings(texts, engine=engine)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to OpenAI's embedding endpoint for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ return self.embed_documents([text])[0]
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Call out to OpenAI's embedding endpoint async for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ embeddings = await self.aembed_documents([text])
+ return embeddings[0]
diff --git a/libs/community/langchain_community/embeddings/sagemaker_endpoint.py b/libs/community/langchain_community/embeddings/sagemaker_endpoint.py
new file mode 100644
index 00000000000..4e906133f76
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/sagemaker_endpoint.py
@@ -0,0 +1,211 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+from langchain_community.llms.sagemaker_endpoint import ContentHandlerBase
+
+
+class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
+    """Content handler for embedding models on SageMaker endpoints."""
+
+
+class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
+ """Custom Sagemaker Inference Endpoints.
+
+ To use, you must supply the endpoint name from your deployed
+ Sagemaker model & the region where it is deployed.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+ access the Sagemaker endpoint.
+ See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
+ """
+
+ """
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import SagemakerEndpointEmbeddings
+ endpoint_name = (
+ "my-endpoint-name"
+ )
+ region_name = (
+ "us-west-2"
+ )
+ credentials_profile_name = (
+ "default"
+ )
+ se = SagemakerEndpointEmbeddings(
+ endpoint_name=endpoint_name,
+ region_name=region_name,
+ credentials_profile_name=credentials_profile_name
+ )
+
+ #Use with boto3 client
+ client = boto3.client(
+ "sagemaker-runtime",
+ region_name=region_name
+ )
+ se = SagemakerEndpointEmbeddings(
+ endpoint_name=endpoint_name,
+ client=client
+ )
+ """
+ client: Any = None
+
+ endpoint_name: str = ""
+ """The name of the endpoint from the deployed Sagemaker model.
+ Must be unique within an AWS Region."""
+
+ region_name: str = ""
+    """The AWS region where the Sagemaker model is deployed, e.g. `us-west-2`."""
+
+ credentials_profile_name: Optional[str] = None
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+ has either access keys or role information specified.
+ If not specified, the default credential profile or, if on an EC2 instance,
+ credentials from IMDS will be used.
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ """
+
+ content_handler: EmbeddingsContentHandler
+ """The content handler class that provides an input and
+ output transform functions to handle formats between LLM
+ and the endpoint.
+ """
+
+ """
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
+
+ class ContentHandler(EmbeddingsContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
+                    input_str = json.dumps({"inputs": prompts, **model_kwargs})
+ return input_str.encode('utf-8')
+
+ def transform_output(self, output: bytes) -> List[List[float]]:
+ response_json = json.loads(output.read().decode("utf-8"))
+ return response_json["vectors"]
+ """ # noqa: E501
+
+ model_kwargs: Optional[Dict] = None
+ """Keyword arguments to pass to the model."""
+
+ endpoint_kwargs: Optional[Dict] = None
+ """Optional attributes passed to the invoke_endpoint
+ function. See `boto3`_. docs for more info.
+ .. _boto3:
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+        """Don't do anything if a client is provided externally."""
+ if values.get("client") is not None:
+ return values
+
+        """Validate that AWS credentials and the boto3 package are available."""
+ try:
+ import boto3
+
+ try:
+ if values["credentials_profile_name"] is not None:
+ session = boto3.Session(
+ profile_name=values["credentials_profile_name"]
+ )
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ values["client"] = session.client(
+ "sagemaker-runtime", region_name=values["region_name"]
+ )
+
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ except ImportError:
+ raise ImportError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ return values
+
+ def _embedding_func(self, texts: List[str]) -> List[List[float]]:
+ """Call out to SageMaker Inference embedding endpoint."""
+ # replace newlines, which can negatively affect performance.
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
+ _model_kwargs = self.model_kwargs or {}
+ _endpoint_kwargs = self.endpoint_kwargs or {}
+
+ body = self.content_handler.transform_input(texts, _model_kwargs)
+ content_type = self.content_handler.content_type
+ accepts = self.content_handler.accepts
+
+ # send request
+ try:
+ response = self.client.invoke_endpoint(
+ EndpointName=self.endpoint_name,
+ Body=body,
+ ContentType=content_type,
+ Accept=accepts,
+ **_endpoint_kwargs,
+ )
+ except Exception as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ return self.content_handler.transform_output(response["Body"])
+
+ def embed_documents(
+ self, texts: List[str], chunk_size: int = 64
+ ) -> List[List[float]]:
+ """Compute doc embeddings using a SageMaker Inference Endpoint.
+
+ Args:
+ texts: The list of texts to embed.
+ chunk_size: The chunk size defines how many input texts will
+                be grouped together as a request. If None, will use the
+ chunk size specified by the class.
+
+
+ Returns:
+ List of embeddings, one for each text.
+ """
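+        # Invoke the endpoint in batches, capping the batch size at len(texts).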
+ results = []
+ _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
+ for i in range(0, len(texts), _chunk_size):
+ response = self._embedding_func(texts[i : i + _chunk_size])
+ results.extend(response)
+ return results
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a SageMaker inference endpoint.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ return self._embedding_func([text])[0]
diff --git a/libs/community/langchain_community/embeddings/self_hosted.py b/libs/community/langchain_community/embeddings/self_hosted.py
new file mode 100644
index 00000000000..925b64d7943
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/self_hosted.py
@@ -0,0 +1,102 @@
+from typing import Any, Callable, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.llms.self_hosted import SelfHostedPipeline
+
+
+def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[float]]:
+ """Inference function to send to the remote hardware.
+
+    Accepts a pipeline callable and
+ returns a list of embeddings for each document in the batch.
+ """
+ return pipeline(*args, **kwargs)
+
+
+class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
+ """Custom embedding models on self-hosted remote hardware.
+
+ Supported hardware includes auto-launched instances on AWS, GCP, Azure,
+ and Lambda, as well as servers specified
+ by IP address and SSH credentials (such as on-prem, or another
+ cloud like Paperspace, Coreweave, etc.).
+
+ To use, you should have the ``runhouse`` python package installed.
+
+ Example using a model load function:
+ .. code-block:: python
+
+ from langchain_community.embeddings import SelfHostedEmbeddings
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import runhouse as rh
+
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ def get_pipeline():
+ model_id = "facebook/bart-large"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+ return pipeline("feature-extraction", model=model, tokenizer=tokenizer)
+ embeddings = SelfHostedEmbeddings(
+ model_load_fn=get_pipeline,
+                hardware=gpu,
+ model_reqs=["./", "torch", "transformers"],
+ )
+ Example passing in a pipeline path:
+ .. code-block:: python
+
+ from langchain_community.embeddings import SelfHostedHFEmbeddings
+ import runhouse as rh
+ from transformers import pipeline
+
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ pipeline = pipeline(model="bert-base-uncased", task="feature-extraction")
+ rh.blob(pickle.dumps(pipeline),
+ path="models/pipeline.pkl").save().to(gpu, path="models")
+ embeddings = SelfHostedHFEmbeddings.from_pipeline(
+ pipeline="models/pipeline.pkl",
+ hardware=gpu,
+ model_reqs=["./", "torch", "transformers"],
+ )
+ """
+
+ inference_fn: Callable = _embed_documents
+ """Inference function to extract the embeddings on the remote hardware."""
+ inference_kwargs: Any = None
+ """Any kwargs to pass to the model's inference function."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a HuggingFace transformer model.
+
+ Args:
+            texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
+ embeddings = self.client(self.pipeline_ref, texts)
+ if not isinstance(embeddings, list):
+ return embeddings.tolist()
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace transformer model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ text = text.replace("\n", " ")
+ embeddings = self.client(self.pipeline_ref, text)
+ if not isinstance(embeddings, list):
+ return embeddings.tolist()
+ return embeddings
diff --git a/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py b/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py
new file mode 100644
index 00000000000..0b706532cf2
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py
@@ -0,0 +1,168 @@
+import importlib
+import logging
+from typing import Any, Callable, List, Optional
+
+from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings
+
+DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
+DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
+DEFAULT_QUERY_INSTRUCTION = (
+ "Represent the question for retrieving supporting documents: "
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _embed_documents(client: Any, *args: Any, **kwargs: Any) -> List[List[float]]:
+ """Inference function to send to the remote hardware.
+
+    Accepts a sentence_transformers model object and
+ returns a list of embeddings for each document in the batch.
+ """
+ return client.encode(*args, **kwargs)
+
+
+def load_embedding_model(model_id: str, instruct: bool = False, device: int = 0) -> Any:
+ """Load the embedding model."""
+ if not instruct:
+ import sentence_transformers
+
+ client = sentence_transformers.SentenceTransformer(model_id)
+ else:
+ from InstructorEmbedding import INSTRUCTOR
+
+ client = INSTRUCTOR(model_id)
+
+ if importlib.util.find_spec("torch") is not None:
+ import torch
+
+ cuda_device_count = torch.cuda.device_count()
+ if device < -1 or (device >= cuda_device_count):
+ raise ValueError(
+ f"Got device=={device}, "
+ f"device is required to be within [-1, {cuda_device_count})"
+ )
+ if device < 0 and cuda_device_count > 0:
+ logger.warning(
+ "Device has %d GPUs available. "
+                "Provide device={deviceId} to `from_model_id` to use available "
+ "GPUs for execution. deviceId is -1 for CPU and "
+ "can be a positive integer associated with CUDA device id.",
+ cuda_device_count,
+ )
+
+ client = client.to(device)
+ return client
+
+
+class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings):
+ """HuggingFace embedding models on self-hosted remote hardware.
+
+ Supported hardware includes auto-launched instances on AWS, GCP, Azure,
+ and Lambda, as well as servers specified
+ by IP address and SSH credentials (such as on-prem, or another cloud
+ like Paperspace, Coreweave, etc.).
+
+ To use, you should have the ``runhouse`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings
+ import runhouse as rh
+ model_name = "sentence-transformers/all-mpnet-base-v2"
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ hf = SelfHostedHuggingFaceEmbeddings(model_name=model_name, hardware=gpu)
+ """
+
+ client: Any #: :meta private:
+ model_id: str = DEFAULT_MODEL_NAME
+ """Model name to use."""
+ model_reqs: List[str] = ["./", "sentence_transformers", "torch"]
+ """Requirements to install on hardware to inference the model."""
+ hardware: Any
+ """Remote hardware to send the inference function to."""
+ model_load_fn: Callable = load_embedding_model
+ """Function to load the model remotely on the server."""
+ load_fn_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model load function."""
+ inference_fn: Callable = _embed_documents
+ """Inference function to extract the embeddings."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the remote inference function."""
+ load_fn_kwargs = kwargs.pop("load_fn_kwargs", {})
+ load_fn_kwargs["model_id"] = load_fn_kwargs.get("model_id", DEFAULT_MODEL_NAME)
+ load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", False)
+ load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
+ super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
+
+
+class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
+ """HuggingFace InstructEmbedding models on self-hosted remote hardware.
+
+ Supported hardware includes auto-launched instances on AWS, GCP, Azure,
+ and Lambda, as well as servers specified
+ by IP address and SSH credentials (such as on-prem, or another
+ cloud like Paperspace, Coreweave, etc.).
+
+ To use, you should have the ``runhouse`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import SelfHostedHuggingFaceInstructEmbeddings
+ import runhouse as rh
+ model_name = "hkunlp/instructor-large"
+ gpu = rh.cluster(name='rh-a10x', instance_type='A100:1')
+ hf = SelfHostedHuggingFaceInstructEmbeddings(
+ model_name=model_name, hardware=gpu)
+ """ # noqa: E501
+
+ model_id: str = DEFAULT_INSTRUCT_MODEL
+ """Model name to use."""
+ embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
+ """Instruction to use for embedding documents."""
+ query_instruction: str = DEFAULT_QUERY_INSTRUCTION
+ """Instruction to use for embedding query."""
+ model_reqs: List[str] = ["./", "InstructorEmbedding", "torch"]
+ """Requirements to install on hardware to inference the model."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the remote inference function."""
+ load_fn_kwargs = kwargs.pop("load_fn_kwargs", {})
+ load_fn_kwargs["model_id"] = load_fn_kwargs.get(
+ "model_id", DEFAULT_INSTRUCT_MODEL
+ )
+ load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", True)
+ load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
+ super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a HuggingFace instruct model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ instruction_pairs = []
+ for text in texts:
+ instruction_pairs.append([self.embed_instruction, text])
+ embeddings = self.client(self.pipeline_ref, instruction_pairs)
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a HuggingFace instruct model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ instruction_pair = [self.query_instruction, text]
+ embedding = self.client(self.pipeline_ref, [instruction_pair])[0]
+ return embedding.tolist()
diff --git a/libs/community/langchain_community/embeddings/sentence_transformer.py b/libs/community/langchain_community/embeddings/sentence_transformer.py
new file mode 100644
index 00000000000..08da53bd447
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/sentence_transformer.py
@@ -0,0 +1,4 @@
+"""HuggingFace sentence_transformer embedding models."""
+from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
+
+SentenceTransformerEmbeddings = HuggingFaceEmbeddings
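+
+# Usage sketch (the model name below is only an example):
+#
+#     from langchain_community.embeddings import SentenceTransformerEmbeddings
+#
+#     emb = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+#     vectors = emb.embed_documents(["hello", "world"])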
diff --git a/libs/community/langchain_community/embeddings/spacy_embeddings.py b/libs/community/langchain_community/embeddings/spacy_embeddings.py
new file mode 100644
index 00000000000..eb581d73849
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/spacy_embeddings.py
@@ -0,0 +1,113 @@
+import importlib.util
+from typing import Any, Dict, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class SpacyEmbeddings(BaseModel, Embeddings):
+ """Embeddings by SpaCy models.
+
+ It only supports the 'en_core_web_sm' model.
+
+ Attributes:
+ nlp (Any): The Spacy model loaded into memory.
+
+ Methods:
+ embed_documents(texts: List[str]) -> List[List[float]]:
+ Generates embeddings for a list of documents.
+ embed_query(text: str) -> List[float]:
+ Generates an embedding for a single piece of text.
+ """
+
+ nlp: Any # The Spacy model loaded into memory
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid # Forbid extra attributes during model initialization
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """
+ Validates that the Spacy package and the 'en_core_web_sm' model are installed.
+
+ Args:
+ values (Dict): The values provided to the class constructor.
+
+ Returns:
+ The validated values.
+
+ Raises:
+ ValueError: If the Spacy package or the 'en_core_web_sm'
+ model are not installed.
+ """
+ # Check if the Spacy package is installed
+ if importlib.util.find_spec("spacy") is None:
+ raise ValueError(
+ "Spacy package not found. "
+ "Please install it with `pip install spacy`."
+ )
+ try:
+ # Try to load the 'en_core_web_sm' Spacy model
+ import spacy
+
+ values["nlp"] = spacy.load("en_core_web_sm")
+ except OSError:
+ # If the model is not found, raise a ValueError
+ raise ValueError(
+ "Spacy model 'en_core_web_sm' not found. "
+ "Please install it with"
+ " `python -m spacy download en_core_web_sm`."
+ )
+ return values # Return the validated values
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """
+ Generates embeddings for a list of documents.
+
+ Args:
+ texts (List[str]): The documents to generate embeddings for.
+
+ Returns:
+ A list of embeddings, one for each document.
+ """
+ return [self.nlp(text).vector.tolist() for text in texts]
+
+ def embed_query(self, text: str) -> List[float]:
+ """
+ Generates an embedding for a single piece of text.
+
+ Args:
+ text (str): The text to generate an embedding for.
+
+ Returns:
+ The embedding for the text.
+ """
+ return self.nlp(text).vector.tolist()
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """
+ Asynchronously generates embeddings for a list of documents.
+ This method is not implemented and raises a NotImplementedError.
+
+ Args:
+ texts (List[str]): The documents to generate embeddings for.
+
+ Raises:
+ NotImplementedError: This method is not implemented.
+ """
+ raise NotImplementedError("Asynchronous embedding generation is not supported.")
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """
+ Asynchronously generates an embedding for a single piece of text.
+ This method is not implemented and raises a NotImplementedError.
+
+ Args:
+ text (str): The text to generate an embedding for.
+
+ Raises:
+ NotImplementedError: This method is not implemented.
+ """
+ raise NotImplementedError("Asynchronous embedding generation is not supported.")
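A short usage sketch for `SpacyEmbeddings`; it assumes `pip install spacy` and `python -m spacy download en_core_web_sm` have been run, as the validator requires:

```python
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings

embedder = SpacyEmbeddings()  # loads en_core_web_sm via the root validator
doc_vectors = embedder.embed_documents(["This is a test document."])
query_vector = embedder.embed_query("This is a test query.")
```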
diff --git a/libs/community/langchain_community/embeddings/tensorflow_hub.py b/libs/community/langchain_community/embeddings/tensorflow_hub.py
new file mode 100644
index 00000000000..1cc01ca91ca
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/tensorflow_hub.py
@@ -0,0 +1,75 @@
+from typing import Any, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
+
+
+class TensorflowHubEmbeddings(BaseModel, Embeddings):
+ """TensorflowHub embedding models.
+
+ To use, you should have the ``tensorflow_text`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import TensorflowHubEmbeddings
+ url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
+ tf = TensorflowHubEmbeddings(model_url=url)
+ """
+
+ embed: Any #: :meta private:
+ model_url: str = DEFAULT_MODEL_URL
+ """Model name to use."""
+
+ def __init__(self, **kwargs: Any):
+ """Initialize the tensorflow_hub and tensorflow_text."""
+ super().__init__(**kwargs)
+ try:
+ import tensorflow_hub
+ except ImportError:
+ raise ImportError(
+ "Could not import tensorflow-hub python package. "
+ "Please install it with `pip install tensorflow-hub``."
+ )
+ try:
+ import tensorflow_text # noqa
+ except ImportError:
+ raise ImportError(
+ "Could not import tensorflow_text python package. "
+ "Please install it with `pip install tensorflow_text``."
+ )
+
+ self.embed = tensorflow_hub.load(self.model_url)
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Compute doc embeddings using a TensorflowHub embedding model.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
+ embeddings = self.embed(texts).numpy()
+ return embeddings.tolist()
+
+ def embed_query(self, text: str) -> List[float]:
+ """Compute query embeddings using a TensorflowHub embedding model.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embeddings for the text.
+ """
+ text = text.replace("\n", " ")
+ embedding = self.embed([text]).numpy()[0]
+ return embedding.tolist()
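A usage sketch for `TensorflowHubEmbeddings`, assuming `tensorflow`, `tensorflow-hub`, and `tensorflow_text` are installed and the model can be downloaded from TF Hub on first use:

```python
from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings

embedder = TensorflowHubEmbeddings()  # falls back to DEFAULT_MODEL_URL
doc_vectors = embedder.embed_documents(["first text", "second text"])
query_vector = embedder.embed_query("a query")
```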
diff --git a/libs/community/langchain_community/embeddings/vertexai.py b/libs/community/langchain_community/embeddings/vertexai.py
new file mode 100644
index 00000000000..3264fef18d5
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/vertexai.py
@@ -0,0 +1,56 @@
+from typing import Dict, List
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import root_validator
+
+from langchain_community.llms.vertexai import _VertexAICommon
+from langchain_community.utilities.vertexai import raise_vertex_import_error
+
+
+class VertexAIEmbeddings(_VertexAICommon, Embeddings):
+ """Google Cloud VertexAI embedding models."""
+
+ model_name: str = "textembedding-gecko"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validates that the python package exists in environment."""
+ cls._try_init_vertexai(values)
+ try:
+ from vertexai.language_models import TextEmbeddingModel
+ except ImportError:
+ raise_vertex_import_error()
+ values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
+ return values
+
+ def embed_documents(
+ self, texts: List[str], batch_size: int = 5
+ ) -> List[List[float]]:
+ """Embed a list of strings. Vertex AI currently
+ sets a max batch size of 5 strings.
+
+ Args:
+ texts: List[str] The list of strings to embed.
+ batch_size: [int] The batch size of embeddings to send to the model
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ embeddings = []
+ for batch in range(0, len(texts), batch_size):
+ text_batch = texts[batch : batch + batch_size]
+ embeddings_batch = self.client.get_embeddings(text_batch)
+ embeddings.extend([el.values for el in embeddings_batch])
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ embeddings = self.client.get_embeddings([text])
+ return embeddings[0].values
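A usage sketch for `VertexAIEmbeddings`; it assumes Google Cloud credentials and a project are already configured for Vertex AI and that the Vertex AI SDK is installed:

```python
from langchain_community.embeddings.vertexai import VertexAIEmbeddings

embeddings = VertexAIEmbeddings()  # defaults to the "textembedding-gecko" model
# Requests are sent in batches of at most 5 texts, per the Vertex AI limit.
doc_vectors = embeddings.embed_documents(["doc one", "doc two", "doc three"])
query_vector = embeddings.embed_query("a search query")
```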
diff --git a/libs/community/langchain_community/embeddings/voyageai.py b/libs/community/langchain_community/embeddings/voyageai.py
new file mode 100644
index 00000000000..93109d45c65
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/voyageai.py
@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from tenacity import (
+ before_sleep_log,
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(embeddings: VoyageEmbeddings) -> Callable[[Any], Any]:
+ min_seconds = 4
+ max_seconds = 10
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(embeddings.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def _check_response(response: dict) -> dict:
+ if "data" not in response:
+ raise RuntimeError(f"Voyage API Error. Message: {json.dumps(response)}")
+ return response
+
+
+def embed_with_retry(embeddings: VoyageEmbeddings, **kwargs: Any) -> Any:
+ """Use tenacity to retry the embedding call."""
+ retry_decorator = _create_retry_decorator(embeddings)
+
+ @retry_decorator
+ def _embed_with_retry(**kwargs: Any) -> Any:
+ response = requests.post(**kwargs)
+ return _check_response(response.json())
+
+ return _embed_with_retry(**kwargs)
+
+
+class VoyageEmbeddings(BaseModel, Embeddings):
+ """Voyage embedding models.
+
+ To use, you should have the environment variable ``VOYAGE_API_KEY`` set with
+ your API key or pass it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import VoyageEmbeddings
+
+ voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
+ text = "This is a test query."
+ query_result = voyage.embed_query(text)
+ """
+
+ model: str = "voyage-01"
+ voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
+ voyage_api_key: Optional[SecretStr] = None
+ batch_size: int = 8
+ """Maximum number of texts to embed in each API request."""
+ max_retries: int = 6
+ """Maximum number of retries to make when generating."""
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+ """Timeout in seconds for the API request."""
+ show_progress_bar: bool = False
+ """Whether to show a progress bar when embedding. Must have tqdm installed if set
+ to True."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["voyage_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
+ )
+ return values
+
+ def _invocation_params(
+ self, input: List[str], input_type: Optional[str] = None
+ ) -> Dict:
+ api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
+ params = {
+ "url": self.voyage_api_base,
+ "headers": {"Authorization": f"Bearer {api_key}"},
+ "json": {"model": self.model, "input": input, "input_type": input_type},
+ "timeout": self.request_timeout,
+ }
+ return params
+
+ def _get_embeddings(
+ self,
+ texts: List[str],
+ batch_size: Optional[int] = None,
+ input_type: Optional[str] = None,
+ ) -> List[List[float]]:
+ embeddings: List[List[float]] = []
+
+ if batch_size is None:
+ batch_size = self.batch_size
+
+ if self.show_progress_bar:
+ try:
+ from tqdm.auto import tqdm
+ except ImportError as e:
+ raise ImportError(
+ "Must have tqdm installed if `show_progress_bar` is set to True. "
+ "Please install with `pip install tqdm`."
+ ) from e
+
+ _iter = tqdm(range(0, len(texts), batch_size))
+ else:
+ _iter = range(0, len(texts), batch_size)
+
+ if input_type and input_type not in ["query", "document"]:
+ raise ValueError(
+ f"input_type {input_type} is invalid. Options: None, 'query', "
+ "'document'."
+ )
+
+ for i in _iter:
+ response = embed_with_retry(
+ self,
+ **self._invocation_params(
+ input=texts[i : i + batch_size], input_type=input_type
+ ),
+ )
+ embeddings.extend(r["embedding"] for r in response["data"])
+
+ return embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Call out to Voyage Embedding endpoint for embedding search docs.
+
+ Args:
+ texts: The list of texts to embed.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return self._get_embeddings(
+ texts, batch_size=self.batch_size, input_type="document"
+ )
+
+ def embed_query(self, text: str) -> List[float]:
+ """Call out to Voyage Embedding endpoint for embedding query text.
+
+ Args:
+ text: The text to embed.
+
+ Returns:
+ Embedding for the text.
+ """
+ return self._get_embeddings([text], input_type="query")[0]
+
+ def embed_general_texts(
+ self, texts: List[str], *, input_type: Optional[str] = None
+ ) -> List[List[float]]:
+ """Call out to Voyage Embedding endpoint for embedding general text.
+
+ Args:
+ texts: The list of texts to embed.
+ input_type: Type of the input text. Defaults to None, meaning the type is
+ unspecified. Other options: query, document.
+
+ Returns:
+ List of embeddings, one for each text.
+ """
+ return self._get_embeddings(
+ texts, batch_size=self.batch_size, input_type=input_type
+ )
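A usage sketch for `VoyageEmbeddings`; the API key is a placeholder and network access to the Voyage endpoint is assumed:

```python
from langchain_community.embeddings.voyageai import VoyageEmbeddings

voyage = VoyageEmbeddings(voyage_api_key="your-api-key")  # or set VOYAGE_API_KEY
doc_vectors = voyage.embed_documents(["doc one", "doc two"])  # sent with input_type="document"
query_vector = voyage.embed_query("what is in doc one?")      # sent with input_type="query"
general_vectors = voyage.embed_general_texts(["some text"])   # input_type left unspecified
```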
diff --git a/libs/community/langchain_community/embeddings/xinference.py b/libs/community/langchain_community/embeddings/xinference.py
new file mode 100644
index 00000000000..db9b56fcbf9
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/xinference.py
@@ -0,0 +1,124 @@
+"""Wrapper around Xinference embedding models."""
+from typing import Any, List, Optional
+
+from langchain_core.embeddings import Embeddings
+
+
+class XinferenceEmbeddings(Embeddings):
+
+ """Xinference embedding models.
+
+ To use, you should have the xinference library installed:
+
+ .. code-block:: bash
+
+ pip install xinference
+
+ Check out: https://github.com/xorbitsai/inference
+ To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers.
+
+ Example:
+ To start a local instance of Xinference, run
+
+ .. code-block:: bash
+
+ $ xinference
+
+ You can also deploy Xinference in a distributed cluster. Here are the steps:
+
+ Starting the supervisor:
+
+ .. code-block:: bash
+
+ $ xinference-supervisor
+
+ Starting the worker:
+
+ .. code-block:: bash
+
+ $ xinference-worker
+
+ Then, launch a model using command line interface (CLI).
+
+ Example:
+
+ .. code-block:: bash
+
+ $ xinference launch -n orca -s 3 -q q4_0
+
+ It will return a model UID. Then you can use Xinference Embedding with LangChain.
+
+ Example:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import XinferenceEmbeddings
+
+ xinference = XinferenceEmbeddings(
+ server_url="http://0.0.0.0:9997",
+ model_uid = {model_uid} # replace model_uid with the model UID returned from launching the model
+ )
+
+ """ # noqa: E501
+
+ client: Any
+ server_url: Optional[str]
+ """URL of the xinference server"""
+ model_uid: Optional[str]
+ """UID of the launched model"""
+
+ def __init__(
+ self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+ ):
+ try:
+ from xinference.client import RESTfulClient
+ except ImportError as e:
+ raise ImportError(
+ "Could not import RESTfulClient from xinference. Please install it"
+ " with `pip install xinference`."
+ ) from e
+
+ super().__init__()
+
+ if server_url is None:
+ raise ValueError("Please provide server URL")
+
+ if model_uid is None:
+ raise ValueError("Please provide the model UID")
+
+ self.server_url = server_url
+
+ self.model_uid = model_uid
+
+ self.client = RESTfulClient(server_url)
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents using Xinference.
+ Args:
+ texts: The list of texts to embed.
+ Returns:
+ List of embeddings, one for each text.
+ """
+
+ model = self.client.get_model(self.model_uid)
+
+ embeddings = [
+ model.create_embedding(text)["data"][0]["embedding"] for text in texts
+ ]
+ return [list(map(float, e)) for e in embeddings]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query of documents using Xinference.
+ Args:
+ text: The text to embed.
+ Returns:
+ Embeddings for the text.
+ """
+
+ model = self.client.get_model(self.model_uid)
+
+ embedding_res = model.create_embedding(text)
+
+ embedding = embedding_res["data"][0]["embedding"]
+
+ return list(map(float, embedding))
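A usage sketch for `XinferenceEmbeddings`; it assumes a running Xinference server at the URL below and a previously launched model whose UID replaces the placeholder:

```python
from langchain_community.embeddings.xinference import XinferenceEmbeddings

xinference = XinferenceEmbeddings(
    server_url="http://0.0.0.0:9997",
    model_uid="7167b2b0-2a04-11ee-83a0-d29396a3f064",  # placeholder UID
)
doc_vectors = xinference.embed_documents(["alpha", "beta"])
query_vector = xinference.embed_query("gamma")
```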
diff --git a/libs/community/langchain_community/graphs/__init__.py b/libs/community/langchain_community/graphs/__init__.py
new file mode 100644
index 00000000000..7de3bdbc7bd
--- /dev/null
+++ b/libs/community/langchain_community/graphs/__init__.py
@@ -0,0 +1,25 @@
+"""**Graphs** provide a natural language interface to graph databases."""
+
+from langchain_community.graphs.arangodb_graph import ArangoGraph
+from langchain_community.graphs.falkordb_graph import FalkorDBGraph
+from langchain_community.graphs.hugegraph import HugeGraph
+from langchain_community.graphs.kuzu_graph import KuzuGraph
+from langchain_community.graphs.memgraph_graph import MemgraphGraph
+from langchain_community.graphs.nebula_graph import NebulaGraph
+from langchain_community.graphs.neo4j_graph import Neo4jGraph
+from langchain_community.graphs.neptune_graph import NeptuneGraph
+from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
+from langchain_community.graphs.rdf_graph import RdfGraph
+
+__all__ = [
+ "MemgraphGraph",
+ "NetworkxEntityGraph",
+ "Neo4jGraph",
+ "NebulaGraph",
+ "NeptuneGraph",
+ "KuzuGraph",
+ "HugeGraph",
+ "RdfGraph",
+ "ArangoGraph",
+ "FalkorDBGraph",
+]
diff --git a/libs/community/langchain_community/graphs/arangodb_graph.py b/libs/community/langchain_community/graphs/arangodb_graph.py
new file mode 100644
index 00000000000..b9e4530058e
--- /dev/null
+++ b/libs/community/langchain_community/graphs/arangodb_graph.py
@@ -0,0 +1,182 @@
+import os
+from math import ceil
+from typing import Any, Dict, List, Optional
+
+
+class ArangoGraph:
+ """ArangoDB wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(self, db: Any) -> None:
+ """Create a new ArangoDB graph wrapper instance."""
+ self.set_db(db)
+ self.set_schema()
+
+ @property
+ def db(self) -> Any:
+ return self.__db
+
+ @property
+ def schema(self) -> Dict[str, Any]:
+ return self.__schema
+
+ def set_db(self, db: Any) -> None:
+ from arango.database import Database
+
+ if not isinstance(db, Database):
+ msg = "**db** parameter must inherit from arango.database.Database"
+ raise TypeError(msg)
+
+ self.__db: Database = db
+ self.set_schema()
+
+ def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
+ """
+ Set the schema of the ArangoDB Database.
+ Auto-generates Schema if **schema** is None.
+ """
+ self.__schema = self.generate_schema() if schema is None else schema
+
+ def generate_schema(
+ self, sample_ratio: float = 0
+ ) -> Dict[str, List[Dict[str, Any]]]:
+ """
+ Generates the schema of the ArangoDB Database and returns it.
+ The user can specify a **sample_ratio** (0 to 1) to determine the
+ ratio of documents/edges used (in relation to the Collection size)
+ to render each Collection Schema.
+ """
+ if not 0 <= sample_ratio <= 1:
+ raise ValueError("**sample_ratio** value must be in between 0 to 1")
+
+ # Stores the Edge Relationships between each ArangoDB Document Collection
+ graph_schema: List[Dict[str, Any]] = [
+ {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
+ for g in self.db.graphs()
+ ]
+
+ # Stores the schema of every ArangoDB Document/Edge collection
+ collection_schema: List[Dict[str, Any]] = []
+
+ for collection in self.db.collections():
+ if collection["system"]:
+ continue
+
+ # Extract collection name, type, and size
+ col_name: str = collection["name"]
+ col_type: str = collection["type"]
+ col_size: int = self.db.collection(col_name).count()
+
+ # Skip collection if empty
+ if col_size == 0:
+ continue
+
+ # Set number of ArangoDB documents/edges to retrieve
+ limit_amount = ceil(sample_ratio * col_size) or 1
+
+ aql = f"""
+ FOR doc in {col_name}
+ LIMIT {limit_amount}
+ RETURN doc
+ """
+
+ doc: Dict[str, Any]
+ properties: List[Dict[str, str]] = []
+ for doc in self.__db.aql.execute(aql):
+ for key, value in doc.items():
+ properties.append({"name": key, "type": type(value).__name__})
+
+ collection_schema.append(
+ {
+ "collection_name": col_name,
+ "collection_type": col_type,
+ f"{col_type}_properties": properties,
+ f"example_{col_type}": doc,
+ }
+ )
+
+ return {"Graph Schema": graph_schema, "Collection Schema": collection_schema}
+
+ def query(
+ self, query: str, top_k: Optional[int] = None, **kwargs: Any
+ ) -> List[Dict[str, Any]]:
+ """Query the ArangoDB database."""
+ import itertools
+
+ cursor = self.__db.aql.execute(query, **kwargs)
+ return [doc for doc in itertools.islice(cursor, top_k)]
+
+ @classmethod
+ def from_db_credentials(
+ cls,
+ url: Optional[str] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ ) -> Any:
+ """Convenience constructor that builds Arango DB from credentials.
+
+ Args:
+ url: Arango DB url. Can be passed in as named arg or set as environment
+ var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
+ dbname: Arango DB name. Can be passed in as named arg or set as
+ environment var ``ARANGODB_DBNAME``. Defaults to "_system".
+ username: Can be passed in as named arg or set as environment var
+ ``ARANGODB_USERNAME``. Defaults to "root".
+ password: Can be passed in as named arg or set as environment var
+ ``ARANGODB_PASSWORD``. Defaults to "".
+
+ Returns:
+ An arango.database.StandardDatabase.
+ """
+ db = get_arangodb_client(
+ url=url, dbname=dbname, username=username, password=password
+ )
+ return cls(db)
+
+
+def get_arangodb_client(
+ url: Optional[str] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+) -> Any:
+ """Get the Arango DB client from credentials.
+
+ Args:
+ url: Arango DB url. Can be passed in as named arg or set as environment
+ var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
+ dbname: Arango DB name. Can be passed in as named arg or set as
+ environment var ``ARANGODB_DBNAME``. Defaults to "_system".
+ username: Can be passed in as named arg or set as environment var
+ ``ARANGODB_USERNAME``. Defaults to "root".
+ password: Can be passed in as named arg or set as environment var
+ ``ARANGODB_PASSWORD``. Defaults to "".
+
+ Returns:
+ An arango.database.StandardDatabase.
+ """
+ try:
+ from arango import ArangoClient
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import arango, please install with `pip install python-arango`."
+ ) from e
+
+ _url: str = url or os.environ.get("ARANGODB_URL", "http://localhost:8529") # type: ignore[assignment] # noqa: E501
+ _dbname: str = dbname or os.environ.get("ARANGODB_DBNAME", "_system") # type: ignore[assignment] # noqa: E501
+ _username: str = username or os.environ.get("ARANGODB_USERNAME", "root") # type: ignore[assignment] # noqa: E501
+ _password: str = password or os.environ.get("ARANGODB_PASSWORD", "") # type: ignore[assignment] # noqa: E501
+
+ return ArangoClient(_url).db(_dbname, _username, _password, verify=True)
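A usage sketch for `ArangoGraph`; the URL and credentials are the stated defaults and assume a local ArangoDB instance with `python-arango` installed (the collection name in the AQL query is illustrative):

```python
from langchain_community.graphs import ArangoGraph

# Builds the arango client from credentials (or ARANGODB_* env vars) and wraps it.
graph = ArangoGraph.from_db_credentials(
    url="http://localhost:8529", dbname="_system", username="root", password=""
)
print(graph.schema)  # schema is auto-generated on construction
rows = graph.query("FOR doc IN my_collection RETURN doc", top_k=5)
```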
diff --git a/libs/community/langchain_community/graphs/falkordb_graph.py b/libs/community/langchain_community/graphs/falkordb_graph.py
new file mode 100644
index 00000000000..e23d01d4ed6
--- /dev/null
+++ b/libs/community/langchain_community/graphs/falkordb_graph.py
@@ -0,0 +1,147 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_community.graphs.graph_document import GraphDocument
+from langchain_community.graphs.graph_store import GraphStore
+
+node_properties_query = """
+MATCH (n)
+WITH keys(n) as keys, labels(n) AS labels
+WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, labels
+UNWIND labels AS label
+UNWIND keys AS key
+WITH label, collect(DISTINCT key) AS keys
+RETURN {label:label, keys:keys} AS output
+"""
+
+rel_properties_query = """
+MATCH ()-[r]->()
+WITH keys(r) as keys, type(r) AS types
+WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, types
+UNWIND types AS type
+UNWIND keys AS key WITH type,
+collect(DISTINCT key) AS keys
+RETURN {types:type, keys:keys} AS output
+"""
+
+rel_query = """
+MATCH (n)-[r]->(m)
+UNWIND labels(n) as src_label
+UNWIND labels(m) as dst_label
+UNWIND type(r) as rel_type
+RETURN DISTINCT {start: src_label, type: rel_type, end: dst_label} AS output
+"""
+
+
+class FalkorDBGraph(GraphStore):
+ """FalkorDB wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ database: str,
+ host: str = "localhost",
+ port: int = 6379,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ ssl: bool = False,
+ ) -> None:
+ """Create a new FalkorDB graph wrapper instance."""
+ try:
+ import redis
+ from redis.commands.graph import Graph
+ except ImportError:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+
+ self._driver = redis.Redis(
+ host=host, port=port, username=username, password=password, ssl=ssl
+ )
+ self._graph = Graph(self._driver, database)
+ self.schema: str = ""
+ self.structured_schema: Dict[str, Any] = {}
+
+ try:
+ self.refresh_schema()
+ except Exception as e:
+ raise ValueError(f"Could not refresh schema. Error: {e}")
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the FalkorDB database"""
+ return self.schema
+
+ @property
+ def get_structured_schema(self) -> Dict[str, Any]:
+ """Returns the structured schema of the Graph"""
+ return self.structured_schema
+
+ def refresh_schema(self) -> None:
+ """Refreshes the schema of the FalkorDB database"""
+ node_properties: List[Any] = self.query(node_properties_query)
+ rel_properties: List[Any] = self.query(rel_properties_query)
+ relationships: List[Any] = self.query(rel_query)
+
+ self.structured_schema = {
+ "node_props": {el[0]["label"]: el[0]["keys"] for el in node_properties},
+ "rel_props": {el[0]["types"]: el[0]["keys"] for el in rel_properties},
+ "relationships": [el[0] for el in relationships],
+ }
+
+ self.schema = (
+ f"Node properties: {node_properties}\n"
+ f"Relationships properties: {rel_properties}\n"
+ f"Relationships: {relationships}\n"
+ )
+
+ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+ """Query FalkorDB database."""
+
+ try:
+ data = self._graph.query(query, params)
+ return data.result_set
+ except Exception as e:
+ raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")
+
+ def add_graph_documents(
+ self, graph_documents: List[GraphDocument], include_source: bool = False
+ ) -> None:
+ """
+ Takes GraphDocument as input and uses it to construct a graph.
+ """
+ for document in graph_documents:
+ # Import nodes
+ for node in document.nodes:
+ self.query(
+ (
+ f"MERGE (n:{node.type} {{id:'{node.id}'}}) "
+ "SET n += $properties "
+ "RETURN distinct 'done' AS result"
+ ),
+ {"properties": node.properties},
+ )
+
+ # Import relationships
+ for rel in document.relationships:
+ self.query(
+ (
+ f"MATCH (a:{rel.source.type} {{id:'{rel.source.id}'}}), "
+ f"(b:{rel.target.type} {{id:'{rel.target.id}'}}) "
+ f"MERGE (a)-[r:{(rel.type.replace(' ', '_').upper())}]->(b) "
+ "SET r += $properties "
+ "RETURN distinct 'done' AS result"
+ ),
+ {"properties": rel.properties},
+ )
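A usage sketch for `FalkorDBGraph`, assuming a FalkorDB instance listening on localhost:6379 and `pip install redis`:

```python
from langchain_community.graphs import FalkorDBGraph

graph = FalkorDBGraph(database="demo")  # refreshes the schema during init
print(graph.get_schema)
rows = graph.query("MATCH (n) RETURN n LIMIT 5")
```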
diff --git a/libs/community/langchain_community/graphs/graph_document.py b/libs/community/langchain_community/graphs/graph_document.py
new file mode 100644
index 00000000000..3e9a597cc56
--- /dev/null
+++ b/libs/community/langchain_community/graphs/graph_document.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from typing import List, Union
+
+from langchain_core.documents import Document
+from langchain_core.load.serializable import Serializable
+from langchain_core.pydantic_v1 import Field
+
+
+class Node(Serializable):
+ """Represents a node in a graph with associated properties.
+
+ Attributes:
+ id (Union[str, int]): A unique identifier for the node.
+ type (str): The type or label of the node, default is "Node".
+ properties (dict): Additional properties and metadata associated with the node.
+ """
+
+ id: Union[str, int]
+ type: str = "Node"
+ properties: dict = Field(default_factory=dict)
+
+
+class Relationship(Serializable):
+ """Represents a directed relationship between two nodes in a graph.
+
+ Attributes:
+ source (Node): The source node of the relationship.
+ target (Node): The target node of the relationship.
+ type (str): The type of the relationship.
+ properties (dict): Additional properties associated with the relationship.
+ """
+
+ source: Node
+ target: Node
+ type: str
+ properties: dict = Field(default_factory=dict)
+
+
+class GraphDocument(Serializable):
+ """Represents a graph document consisting of nodes and relationships.
+
+ Attributes:
+ nodes (List[Node]): A list of nodes in the graph.
+ relationships (List[Relationship]): A list of relationships in the graph.
+ source (Document): The document from which the graph information is derived.
+ """
+
+ nodes: List[Node]
+ relationships: List[Relationship]
+ source: Document
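A small sketch of assembling a `GraphDocument` by hand (the node ids, types, and properties are illustrative); the result can later be passed to a `GraphStore.add_graph_documents` implementation:

```python
from langchain_core.documents import Document

from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship

alice = Node(id="alice", type="Person")
acme = Node(id="acme", type="Company")
works_at = Relationship(source=alice, target=acme, type="WORKS_AT", properties={"since": 2020})

graph_doc = GraphDocument(
    nodes=[alice, acme],
    relationships=[works_at],
    source=Document(page_content="Alice works at Acme."),
)
```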
diff --git a/libs/community/langchain_community/graphs/graph_store.py b/libs/community/langchain_community/graphs/graph_store.py
new file mode 100644
index 00000000000..0618eae48dd
--- /dev/null
+++ b/libs/community/langchain_community/graphs/graph_store.py
@@ -0,0 +1,37 @@
+from abc import abstractmethod
+from typing import Any, Dict, List
+
+from langchain_community.graphs.graph_document import GraphDocument
+
+
+class GraphStore:
+ """An abstract class wrapper for graph operations."""
+
+ @property
+ @abstractmethod
+ def get_schema(self) -> str:
+ """Returns the schema of the Graph database"""
+ pass
+
+ @property
+ @abstractmethod
+ def get_structured_schema(self) -> Dict[str, Any]:
+ """Returns the schema of the Graph database"""
+ pass
+
+ @abstractmethod
+ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+ """Query the graph."""
+ pass
+
+ @abstractmethod
+ def refresh_schema(self) -> None:
+ """Refreshes the graph schema information."""
+ pass
+
+ @abstractmethod
+ def add_graph_documents(
+ self, graph_documents: List[GraphDocument], include_source: bool = False
+ ) -> None:
+ """Take GraphDocument as input as uses it to construct a graph."""
+ pass
diff --git a/libs/community/langchain_community/graphs/hugegraph.py b/libs/community/langchain_community/graphs/hugegraph.py
new file mode 100644
index 00000000000..a052efce4d0
--- /dev/null
+++ b/libs/community/langchain_community/graphs/hugegraph.py
@@ -0,0 +1,74 @@
+from typing import Any, Dict, List
+
+
+class HugeGraph:
+ """HugeGraph wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ username: str = "default",
+ password: str = "default",
+ address: str = "127.0.0.1",
+ port: int = 8081,
+ graph: str = "hugegraph",
+ ) -> None:
+ """Create a new HugeGraph wrapper instance."""
+ try:
+ from hugegraph.connection import PyHugeGraph
+ except ImportError:
+ raise ValueError(
+ "Please install HugeGraph Python client first: "
+ "`pip3 install hugegraph-python`"
+ )
+
+ self.username = username
+ self.password = password
+ self.address = address
+ self.port = port
+ self.graph = graph
+ self.client = PyHugeGraph(
+ address, port, user=username, pwd=password, graph=graph
+ )
+ self.schema = ""
+ # Set schema
+ try:
+ self.refresh_schema()
+ except Exception as e:
+ raise ValueError(f"Could not refresh schema. Error: {e}")
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the HugeGraph database"""
+ return self.schema
+
+ def refresh_schema(self) -> None:
+ """
+ Refreshes the HugeGraph schema information.
+ """
+ schema = self.client.schema()
+ vertex_schema = schema.getVertexLabels()
+ edge_schema = schema.getEdgeLabels()
+ relationships = schema.getRelations()
+
+ self.schema = (
+ f"Node properties: {vertex_schema}\n"
+ f"Edge properties: {edge_schema}\n"
+ f"Relationships: {relationships}\n"
+ )
+
+ def query(self, query: str) -> List[Dict[str, Any]]:
+ g = self.client.gremlin()
+ res = g.exec(query)
+ return res["data"]
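A usage sketch for `HugeGraph`; it assumes a HugeGraph server reachable with the credentials below (all values are illustrative) and `pip3 install hugegraph-python`:

```python
from langchain_community.graphs import HugeGraph

graph = HugeGraph(
    username="admin", password="xxx", address="127.0.0.1", port=8081, graph="hugegraph"
)
print(graph.get_schema)
rows = graph.query("g.V().limit(5)")  # Gremlin query
```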
diff --git a/libs/community/langchain_community/graphs/kuzu_graph.py b/libs/community/langchain_community/graphs/kuzu_graph.py
new file mode 100644
index 00000000000..eda7417f940
--- /dev/null
+++ b/libs/community/langchain_community/graphs/kuzu_graph.py
@@ -0,0 +1,102 @@
+from typing import Any, Dict, List
+
+
+class KuzuGraph:
+ """KΓΉzu wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(self, db: Any, database: str = "kuzu") -> None:
+ try:
+ import kuzu
+ except ImportError:
+ raise ImportError(
+ "Could not import KΓΉzu python package."
+ "Please install KΓΉzu with `pip install kuzu`."
+ )
+ self.db = db
+ self.conn = kuzu.Connection(self.db)
+ self.database = database
+ self.refresh_schema()
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the KΓΉzu database"""
+ return self.schema
+
+ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+ """Query KΓΉzu database"""
+ params_list = []
+ for param_name in params:
+ params_list.append([param_name, params[param_name]])
+ result = self.conn.execute(query, params_list)
+ column_names = result.get_column_names()
+ return_list = []
+ while result.has_next():
+ row = result.get_next()
+ return_list.append(dict(zip(column_names, row)))
+ return return_list
+
+ def refresh_schema(self) -> None:
+ """Refreshes the KΓΉzu graph schema information"""
+ node_properties = []
+ node_table_names = self.conn._get_node_table_names()
+ for table_name in node_table_names:
+ current_table_schema = {"properties": [], "label": table_name}
+ properties = self.conn._get_node_property_names(table_name)
+ for property_name in properties:
+ property_type = properties[property_name]["type"]
+ list_type_flag = ""
+ if properties[property_name]["dimension"] > 0:
+ if "shape" in properties[property_name]:
+ for s in properties[property_name]["shape"]:
+ list_type_flag += "[%s]" % s
+ else:
+ for i in range(properties[property_name]["dimension"]):
+ list_type_flag += "[]"
+ property_type += list_type_flag
+ current_table_schema["properties"].append(
+ (property_name, property_type)
+ )
+ node_properties.append(current_table_schema)
+
+ relationships = []
+ rel_tables = self.conn._get_rel_table_names()
+ for table in rel_tables:
+ relationships.append(
+ "(:%s)-[:%s]->(:%s)" % (table["src"], table["name"], table["dst"])
+ )
+
+ rel_properties = []
+ for table in rel_tables:
+ current_table_schema = {"properties": [], "label": table["name"]}
+ properties_text = self.conn._connection.get_rel_property_names(
+ table["name"]
+ ).split("\n")
+ for i, line in enumerate(properties_text):
+ # The first 3 lines define src, dst and name, so we skip them
+ if i < 3:
+ continue
+ if not line:
+ continue
+ property_name, property_type = line.strip().split(" ")
+ current_table_schema["properties"].append(
+ (property_name, property_type)
+ )
+ rel_properties.append(current_table_schema)
+
+ self.schema = (
+ f"Node properties: {node_properties}\n"
+ f"Relationships properties: {rel_properties}\n"
+ f"Relationships: {relationships}\n"
+ )
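A usage sketch for `KuzuGraph`, assuming `pip install kuzu`; the on-disk database path is illustrative:

```python
import kuzu

from langchain_community.graphs import KuzuGraph

db = kuzu.Database("./kuzu_db")  # embedded database created on disk
graph = KuzuGraph(db)
print(graph.get_schema)
rows = graph.query("MATCH (n) RETURN n LIMIT 5")
```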
diff --git a/libs/community/langchain_community/graphs/memgraph_graph.py b/libs/community/langchain_community/graphs/memgraph_graph.py
new file mode 100644
index 00000000000..2df4612a2c0
--- /dev/null
+++ b/libs/community/langchain_community/graphs/memgraph_graph.py
@@ -0,0 +1,48 @@
+from langchain_community.graphs.neo4j_graph import Neo4jGraph
+
+SCHEMA_QUERY = """
+CALL llm_util.schema("prompt_ready")
+YIELD *
+RETURN *
+"""
+
+RAW_SCHEMA_QUERY = """
+CALL llm_util.schema("raw")
+YIELD *
+RETURN *
+"""
+
+
+class MemgraphGraph(Neo4jGraph):
+ """Memgraph wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self, url: str, username: str, password: str, *, database: str = "memgraph"
+ ) -> None:
+ """Create a new Memgraph graph wrapper instance."""
+ super().__init__(url, username, password, database=database)
+
+ def refresh_schema(self) -> None:
+ """
+ Refreshes the Memgraph graph schema information.
+ """
+
+ db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
+ assert db_schema is not None
+ self.schema = db_schema
+
+ db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema")
+ assert db_structured_schema is not None
+ self.structured_schema = db_structured_schema
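A usage sketch for `MemgraphGraph`; it assumes a running Memgraph instance with the MAGE `llm_util` module available, and the connection details below are placeholders:

```python
from langchain_community.graphs import MemgraphGraph

graph = MemgraphGraph(url="bolt://localhost:7687", username="memgraph", password="memgraph")
print(graph.get_schema)  # populated from llm_util.schema("prompt_ready")
```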
diff --git a/libs/community/langchain_community/graphs/nebula_graph.py b/libs/community/langchain_community/graphs/nebula_graph.py
new file mode 100644
index 00000000000..a1b25e81c28
--- /dev/null
+++ b/libs/community/langchain_community/graphs/nebula_graph.py
@@ -0,0 +1,216 @@
+import logging
+from string import Template
+from typing import Any, Dict, Optional
+
+logger = logging.getLogger(__name__)
+
+rel_query = Template(
+ """
+MATCH ()-[e:`$edge_type`]->()
+ WITH e limit 1
+MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
+RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
+"""
+)
+
+RETRY_TIMES = 3
+
+
+class NebulaGraph:
+ """NebulaGraph wrapper for graph operations.
+
+ NebulaGraph provides an interface similar to Neo4jGraph for ease of use.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ space: str,
+ username: str = "root",
+ password: str = "nebula",
+ address: str = "127.0.0.1",
+ port: int = 9669,
+ session_pool_size: int = 30,
+ ) -> None:
+ """Create a new NebulaGraph wrapper instance."""
+ try:
+ import nebula3 # noqa: F401
+ import pandas # noqa: F401
+ except ImportError:
+ raise ValueError(
+ "Please install NebulaGraph Python client and pandas first: "
+ "`pip install nebula3-python pandas`"
+ )
+
+ self.username = username
+ self.password = password
+ self.address = address
+ self.port = port
+ self.space = space
+ self.session_pool_size = session_pool_size
+
+ self.session_pool = self._get_session_pool()
+ self.schema = ""
+ # Set schema
+ try:
+ self.refresh_schema()
+ except Exception as e:
+ raise ValueError(f"Could not refresh schema. Error: {e}")
+
+ def _get_session_pool(self) -> Any:
+ assert all(
+ [self.username, self.password, self.address, self.port, self.space]
+ ), (
+ "Please provide all of the following parameters: "
+ "username, password, address, port, space"
+ )
+
+ from nebula3.Config import SessionPoolConfig
+ from nebula3.Exception import AuthFailedException, InValidHostname
+ from nebula3.gclient.net.SessionPool import SessionPool
+
+ config = SessionPoolConfig()
+ config.max_size = self.session_pool_size
+
+ try:
+ session_pool = SessionPool(
+ self.username,
+ self.password,
+ self.space,
+ [(self.address, self.port)],
+ )
+ except InValidHostname:
+ raise ValueError(
+ "Could not connect to NebulaGraph database. "
+ "Please ensure that the address and port are correct"
+ )
+
+ try:
+ session_pool.init(config)
+ except AuthFailedException:
+ raise ValueError(
+ "Could not connect to NebulaGraph database. "
+ "Please ensure that the username and password are correct"
+ )
+ except RuntimeError as e:
+ raise ValueError(f"Error initializing session pool. Error: {e}")
+
+ return session_pool
+
+ def __del__(self) -> None:
+ try:
+ self.session_pool.close()
+ except Exception as e:
+ logger.warning(f"Could not close session pool. Error: {e}")
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the NebulaGraph database"""
+ return self.schema
+
+ def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
+ """Query NebulaGraph database."""
+ from nebula3.Exception import IOErrorException, NoValidSessionException
+ from nebula3.fbthrift.transport.TTransport import TTransportException
+
+ params = params or {}
+ try:
+ result = self.session_pool.execute_parameter(query, params)
+ if not result.is_succeeded():
+ logger.warning(
+ f"Error executing query to NebulaGraph. "
+ f"Error: {result.error_msg()}\n"
+ f"Query: {query} \n"
+ )
+ return result
+
+ except NoValidSessionException:
+ logger.warning(
+ f"No valid session found in session pool. "
+ f"Please consider increasing the session pool size. "
+ f"Current size: {self.session_pool_size}"
+ )
+ raise ValueError(
+ f"No valid session found in session pool. "
+ f"Please consider increasing the session pool size. "
+ f"Current size: {self.session_pool_size}"
+ )
+
+ except RuntimeError as e:
+ if retry < RETRY_TIMES:
+ retry += 1
+ logger.warning(
+ f"Error executing query to NebulaGraph. "
+ f"Retrying ({retry}/{RETRY_TIMES})...\n"
+ f"query: {query} \n"
+ f"Error: {e}"
+ )
+ return self.execute(query, params, retry)
+ else:
+ raise ValueError(f"Error executing query to NebulaGraph. Error: {e}")
+
+ except (TTransportException, IOErrorException):
+ # connection issue, try to recreate session pool
+ if retry < RETRY_TIMES:
+ retry += 1
+ logger.warning(
+ f"Connection issue with NebulaGraph. "
+ f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool"
+ )
+ self.session_pool = self._get_session_pool()
+ return self.execute(query, params, retry)
+
+ def refresh_schema(self) -> None:
+ """
+ Refreshes the NebulaGraph schema information.
+ """
+ tags_schema, edge_types_schema, relationships = [], [], []
+ for tag in self.execute("SHOW TAGS").column_values("Name"):
+ tag_name = tag.cast()
+ tag_schema = {"tag": tag_name, "properties": []}
+ r = self.execute(f"DESCRIBE TAG `{tag_name}`")
+ props, types = r.column_values("Field"), r.column_values("Type")
+ for i in range(r.row_size()):
+ tag_schema["properties"].append((props[i].cast(), types[i].cast()))
+ tags_schema.append(tag_schema)
+ for edge_type in self.execute("SHOW EDGES").column_values("Name"):
+ edge_type_name = edge_type.cast()
+ edge_schema = {"edge": edge_type_name, "properties": []}
+ r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`")
+ props, types = r.column_values("Field"), r.column_values("Type")
+ for i in range(r.row_size()):
+ edge_schema["properties"].append((props[i].cast(), types[i].cast()))
+ edge_types_schema.append(edge_schema)
+
+ # build relationships types
+ r = self.execute(
+ rel_query.substitute(edge_type=edge_type_name)
+ ).column_values("rels")
+ if len(r) > 0:
+ relationships.append(r[0].cast())
+
+ self.schema = (
+ f"Node properties: {tags_schema}\n"
+ f"Edge properties: {edge_types_schema}\n"
+ f"Relationships: {relationships}\n"
+ )
+
+ def query(self, query: str, retry: int = 0) -> Dict[str, Any]:
+ result = self.execute(query, retry=retry)
+ columns = result.keys()
+ d: Dict[str, list] = {}
+ for col_num in range(result.col_size()):
+ col_name = columns[col_num]
+ col_list = result.column_values(col_name)
+ d[col_name] = [x.cast() for x in col_list]
+ return d
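A usage sketch for `NebulaGraph`; it assumes a reachable NebulaGraph cluster and an existing space named as below (the space name and credentials are illustrative):

```python
from langchain_community.graphs import NebulaGraph

graph = NebulaGraph(
    space="langchain",
    username="root",
    password="nebula",
    address="127.0.0.1",
    port=9669,
)
print(graph.get_schema)
result = graph.query("MATCH (v) RETURN v LIMIT 5")
```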
diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py
new file mode 100644
index 00000000000..af416b29bc5
--- /dev/null
+++ b/libs/community/langchain_community/graphs/neo4j_graph.py
@@ -0,0 +1,215 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.utils import get_from_env
+
+from langchain_community.graphs.graph_document import GraphDocument
+from langchain_community.graphs.graph_store import GraphStore
+
+node_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {labels: nodeLabels, properties: properties} AS output
+
+"""
+
+rel_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {type: nodeLabels, properties: properties} AS output
+"""
+
+rel_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE type = "RELATIONSHIP" AND elementType = "node"
+UNWIND other AS other_node
+RETURN {start: label, type: property, end: toString(other_node)} AS output
+"""
+
+
+class Neo4jGraph(GraphStore):
+ """Neo4j wrapper for graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ url: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ database: str = "neo4j",
+ ) -> None:
+ """Create a new Neo4j graph wrapper instance."""
+ try:
+ import neo4j
+ except ImportError:
+ raise ValueError(
+ "Could not import neo4j python package. "
+ "Please install it with `pip install neo4j`."
+ )
+
+ url = get_from_env("url", "NEO4J_URI", url)
+ username = get_from_env("username", "NEO4J_USERNAME", username)
+ password = get_from_env("password", "NEO4J_PASSWORD", password)
+ database = get_from_env("database", "NEO4J_DATABASE", database)
+
+ self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
+ self._database = database
+ self.schema: str = ""
+ self.structured_schema: Dict[str, Any] = {}
+ # Verify connection
+ try:
+ self._driver.verify_connectivity()
+ except neo4j.exceptions.ServiceUnavailable:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the url is correct"
+ )
+ except neo4j.exceptions.AuthError:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the username and password are correct"
+ )
+ # Set schema
+ try:
+ self.refresh_schema()
+ except neo4j.exceptions.ClientError:
+ raise ValueError(
+ "Could not use APOC procedures. "
+ "Please ensure the APOC plugin is installed in Neo4j and that "
+ "'apoc.meta.data()' is allowed in Neo4j configuration "
+ )
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the Graph"""
+ return self.schema
+
+ @property
+ def get_structured_schema(self) -> Dict[str, Any]:
+ """Returns the structured schema of the Graph"""
+ return self.structured_schema
+
+ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+ """Query Neo4j database."""
+ from neo4j.exceptions import CypherSyntaxError
+
+ with self._driver.session(database=self._database) as session:
+ try:
+ data = session.run(query, params)
+ return [r.data() for r in data]
+ except CypherSyntaxError as e:
+ raise ValueError(f"Generated Cypher Statement is not valid\n{e}")
+
+ def refresh_schema(self) -> None:
+ """
+ Refreshes the Neo4j graph schema information.
+ """
+ node_properties = [el["output"] for el in self.query(node_properties_query)]
+ rel_properties = [el["output"] for el in self.query(rel_properties_query)]
+ relationships = [el["output"] for el in self.query(rel_query)]
+
+ self.structured_schema = {
+ "node_props": {el["labels"]: el["properties"] for el in node_properties},
+ "rel_props": {el["type"]: el["properties"] for el in rel_properties},
+ "relationships": relationships,
+ }
+
+ # Format node properties
+ formatted_node_props = []
+ for el in node_properties:
+ props_str = ", ".join(
+ [f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
+ )
+ formatted_node_props.append(f"{el['labels']} {{{props_str}}}")
+
+ # Format relationship properties
+ formatted_rel_props = []
+ for el in rel_properties:
+ props_str = ", ".join(
+ [f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
+ )
+ formatted_rel_props.append(f"{el['type']} {{{props_str}}}")
+
+ # Format relationships
+ formatted_rels = [
+ f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships
+ ]
+
+ self.schema = "\n".join(
+ [
+ "Node properties are the following:",
+ ",".join(formatted_node_props),
+ "Relationship properties are the following:",
+ ",".join(formatted_rel_props),
+ "The relationships are the following:",
+ ",".join(formatted_rels),
+ ]
+ )
+
+ def add_graph_documents(
+ self, graph_documents: List[GraphDocument], include_source: bool = False
+ ) -> None:
+ """
+ Takes GraphDocument as input and uses it to construct a graph.
+ """
+ for document in graph_documents:
+ include_docs_query = (
+ "CREATE (d:Document) "
+ "SET d.text = $document.page_content "
+ "SET d += $document.metadata "
+ "WITH d "
+ )
+ # Import nodes
+ self.query(
+ (
+ f"{include_docs_query if include_source else ''}"
+ "UNWIND $data AS row "
+ "CALL apoc.merge.node([row.type], {id: row.id}, "
+ "row.properties, {}) YIELD node "
+ f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
+ "RETURN distinct 'done' AS result"
+ ),
+ {
+ "data": [el.__dict__ for el in document.nodes],
+ "document": document.source.__dict__,
+ },
+ )
+ # Import relationships
+ self.query(
+ "UNWIND $data AS row "
+ "CALL apoc.merge.node([row.source_label], {id: row.source},"
+ "{}, {}) YIELD node as source "
+ "CALL apoc.merge.node([row.target_label], {id: row.target},"
+ "{}, {}) YIELD node as target "
+ "CALL apoc.merge.relationship(source, row.type, "
+ "{}, row.properties, target) YIELD rel "
+ "RETURN distinct 'done'",
+ {
+ "data": [
+ {
+ "source": el.source.id,
+ "source_label": el.source.type,
+ "target": el.target.id,
+ "target_label": el.target.type,
+ "type": el.type.replace(" ", "_").upper(),
+ "properties": el.properties,
+ }
+ for el in document.relationships
+ ]
+ },
+ )
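A usage sketch for `Neo4jGraph`, assuming a Neo4j instance with the APOC plugin installed; the URL and credentials are placeholders, and `graph_doc` refers to the `GraphDocument` sketch shown earlier:

```python
from langchain_community.graphs import Neo4jGraph

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
print(graph.get_schema)
rows = graph.query("MATCH (n) RETURN count(n) AS nodes")
# Write extracted graph documents back, linking them to their source document:
# graph.add_graph_documents([graph_doc], include_source=True)
```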
diff --git a/libs/community/langchain_community/graphs/neptune_graph.py b/libs/community/langchain_community/graphs/neptune_graph.py
new file mode 100644
index 00000000000..6dc45d12e94
--- /dev/null
+++ b/libs/community/langchain_community/graphs/neptune_graph.py
@@ -0,0 +1,270 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+
+class NeptuneQueryException(Exception):
+ """A class to handle queries that fail to execute"""
+
+ def __init__(self, exception: Union[str, Dict]):
+ if isinstance(exception, dict):
+ self.message = exception["message"] if "message" in exception else "unknown"
+ self.details = exception["details"] if "details" in exception else "unknown"
+ else:
+ self.message = exception
+ self.details = "unknown"
+
+ def get_message(self) -> str:
+ return self.message
+
+ def get_details(self) -> Any:
+ return self.details
+
+
+class NeptuneGraph:
+ """Neptune wrapper for graph operations.
+
+ Args:
+ host: endpoint for the database instance
+ port: port number for the database instance, default is 8182
+ use_https: whether to use secure connection, default is True
+ client: optional boto3 Neptune client
+ credentials_profile_name: optional AWS profile name
+ region_name: optional AWS region, e.g., us-west-2
+ service: optional service name, default is neptunedata
+ sign: optional, whether to sign the request payload, default is True
+
+ Example:
+ .. code-block:: python
+
+ graph = NeptuneGraph(
+ host='',
+ port=8182
+ )
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion, mutation
+ of data if appropriately prompted or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ host: str,
+ port: int = 8182,
+ use_https: bool = True,
+ client: Any = None,
+ credentials_profile_name: Optional[str] = None,
+ region_name: Optional[str] = None,
+ service: str = "neptunedata",
+ sign: bool = True,
+ ) -> None:
+ """Create a new Neptune graph wrapper instance."""
+
+ try:
+ if client is not None:
+ self.client = client
+ else:
+ import boto3
+
+ if credentials_profile_name is not None:
+ session = boto3.Session(profile_name=credentials_profile_name)
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {}
+ if region_name:
+ client_params["region_name"] = region_name
+
+ protocol = "https" if use_https else "http"
+
+ client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
+
+ if sign:
+ self.client = session.client(service, **client_params)
+ else:
+ from botocore import UNSIGNED
+ from botocore.config import Config
+
+ self.client = session.client(
+ service,
+ **client_params,
+ config=Config(signature_version=UNSIGNED),
+ )
+
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ if type(e).__name__ == "UnknownServiceError":
+ raise ModuleNotFoundError(
+ "NeptuneGraph requires a boto3 version 1.28.38 or greater."
+ "Please install it with `pip install -U boto3`."
+ ) from e
+ else:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ try:
+ self._refresh_schema()
+ except Exception as e:
+ raise NeptuneQueryException(
+ {
+ "message": "Could not get schema for Neptune database",
+ "detail": str(e),
+ }
+ )
+
+ @property
+ def get_schema(self) -> str:
+ """Returns the schema of the Neptune database"""
+ return self.schema
+
+ def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
+ """Query Neptune database."""
+ try:
+ return self.client.execute_open_cypher_query(openCypherQuery=query)
+ except Exception as e:
+ raise NeptuneQueryException(
+ {
+ "message": "An error occurred while executing the query.",
+ "details": str(e),
+ }
+ )
+
+ def _get_summary(self) -> Dict:
+ try:
+ response = self.client.get_propertygraph_summary()
+ except Exception as e:
+ raise NeptuneQueryException(
+ {
+ "message": (
+ "Summary API is not available for this instance of Neptune,"
+ "ensure the engine version is >=1.2.1.0"
+ ),
+ "details": str(e),
+ }
+ )
+
+ try:
+ summary = response["payload"]["graphSummary"]
+ except Exception:
+ raise NeptuneQueryException(
+ {
+ "message": "Summary API did not return a valid response.",
+ "details": response.content.decode(),
+ }
+ )
+ else:
+ return summary
+
+ def _get_labels(self) -> Tuple[List[str], List[str]]:
+ """Get node and edge labels from the Neptune statistics summary"""
+ summary = self._get_summary()
+ n_labels = summary["nodeLabels"]
+ e_labels = summary["edgeLabels"]
+ return n_labels, e_labels
+
+ def _get_triples(self, e_labels: List[str]) -> List[str]:
+ triple_query = """
+ MATCH (a)-[e:`{e_label}`]->(b)
+ WITH a,e,b LIMIT 3000
+ RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
+ LIMIT 10
+ """
+
+ triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
+ triple_schema = []
+ for label in e_labels:
+ q = triple_query.format(e_label=label)
+ data = self.query(q)
+ for d in data["results"]:
+ triple = triple_template.format(
+ a=d["from"][0], e=d["edge"], b=d["to"][0]
+ )
+ triple_schema.append(triple)
+
+ return triple_schema
+
+ def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
+ node_properties_query = """
+ MATCH (a:`{n_label}`)
+ RETURN properties(a) AS props
+ LIMIT 100
+ """
+ node_properties = []
+ for label in n_labels:
+ q = node_properties_query.format(n_label=label)
+ data = {"label": label, "properties": self.query(q)["results"]}
+ s = set({})
+ for p in data["properties"]:
+ for k, v in p["props"].items():
+ s.add((k, types[type(v).__name__]))
+
+ np = {
+ "properties": [{"property": k, "type": v} for k, v in s],
+ "labels": label,
+ }
+ node_properties.append(np)
+
+ return node_properties
+
+ def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
+ edge_properties_query = """
+ MATCH ()-[e:`{e_label}`]->()
+ RETURN properties(e) AS props
+ LIMIT 100
+ """
+ edge_properties = []
+ for label in e_labels:
+ q = edge_properties_query.format(e_label=label)
+ data = {"label": label, "properties": self.query(q)["results"]}
+ s = set({})
+ for p in data["properties"]:
+ for k, v in p["props"].items():
+ s.add((k, types[type(v).__name__]))
+
+ ep = {
+ "type": label,
+ "properties": [{"property": k, "type": v} for k, v in s],
+ }
+ edge_properties.append(ep)
+
+ return edge_properties
+
+ def _refresh_schema(self) -> None:
+ """
+ Refreshes the Neptune graph schema information.
+ """
+
+ types = {
+ "str": "STRING",
+ "float": "DOUBLE",
+ "int": "INTEGER",
+ "list": "LIST",
+ "dict": "MAP",
+ "bool": "BOOLEAN",
+ }
+ n_labels, e_labels = self._get_labels()
+ triple_schema = self._get_triples(e_labels)
+ node_properties = self._get_node_properties(n_labels, types)
+ edge_properties = self._get_edge_properties(e_labels, types)
+
+ self.schema = f"""
+ Node properties are the following:
+ {node_properties}
+ Relationship properties are the following:
+ {edge_properties}
+ The relationships are the following:
+ {triple_schema}
+ """
diff --git a/libs/community/langchain_community/graphs/networkx_graph.py b/libs/community/langchain_community/graphs/networkx_graph.py
new file mode 100644
index 00000000000..81b7862fab2
--- /dev/null
+++ b/libs/community/langchain_community/graphs/networkx_graph.py
@@ -0,0 +1,181 @@
+"""Networkx wrapper for graph operations."""
+from __future__ import annotations
+
+from typing import Any, List, NamedTuple, Optional, Tuple
+
+KG_TRIPLE_DELIMITER = "<|>"
+
+
+class KnowledgeTriple(NamedTuple):
+ """A triple in the graph."""
+
+ subject: str
+ predicate: str
+ object_: str
+
+ @classmethod
+ def from_string(cls, triple_string: str) -> "KnowledgeTriple":
+ """Create a KnowledgeTriple from a string."""
+ subject, predicate, object_ = triple_string.strip().split(", ")
+ subject = subject[1:]
+ object_ = object_[:-1]
+ return cls(subject, predicate, object_)
+
+
+def parse_triples(knowledge_str: str) -> List[KnowledgeTriple]:
+ """Parse knowledge triples from the knowledge string."""
+ knowledge_str = knowledge_str.strip()
+ if not knowledge_str or knowledge_str == "NONE":
+ return []
+ triple_strs = knowledge_str.split(KG_TRIPLE_DELIMITER)
+ results = []
+ for triple_str in triple_strs:
+ try:
+ kg_triple = KnowledgeTriple.from_string(triple_str)
+ except ValueError:
+ continue
+ results.append(kg_triple)
+ return results
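+
+
+# Illustrative sketch of the parsing behavior above (example values only):
+#   parse_triples("(Paris, is capital of, France)<|>(France, is in, Europe)")
+#   -> [KnowledgeTriple("Paris", "is capital of", "France"),
+#       KnowledgeTriple("France", "is in", "Europe")]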
+
+
+def get_entities(entity_str: str) -> List[str]:
+ """Extract entities from entity string."""
+ if entity_str.strip() == "NONE":
+ return []
+ else:
+ return [w.strip() for w in entity_str.split(",")]
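+
+
+# For example (illustrative): get_entities("Paris, France") -> ["Paris", "France"]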
+
+
+class NetworkxEntityGraph:
+ """Networkx wrapper for entity graph operations.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion or mutation
+ of data if appropriately prompted, or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(self, graph: Optional[Any] = None) -> None:
+ """Create a new graph."""
+ try:
+ import networkx as nx
+ except ImportError:
+ raise ImportError(
+ "Could not import networkx python package. "
+ "Please install it with `pip install networkx`."
+ )
+ if graph is not None:
+ if not isinstance(graph, nx.DiGraph):
+ raise ValueError("Passed in graph is not of correct shape")
+ self._graph = graph
+ else:
+ self._graph = nx.DiGraph()
+
+ @classmethod
+ def from_gml(cls, gml_path: str) -> NetworkxEntityGraph:
+ try:
+ import networkx as nx
+ except ImportError:
+ raise ImportError(
+ "Could not import networkx python package. "
+ "Please install it with `pip install networkx`."
+ )
+ graph = nx.read_gml(gml_path)
+ return cls(graph)
+
+ def add_triple(self, knowledge_triple: KnowledgeTriple) -> None:
+ """Add a triple to the graph."""
+ # Creates nodes if they don't exist
+ # Overwrites existing edges
+ if not self._graph.has_node(knowledge_triple.subject):
+ self._graph.add_node(knowledge_triple.subject)
+ if not self._graph.has_node(knowledge_triple.object_):
+ self._graph.add_node(knowledge_triple.object_)
+ self._graph.add_edge(
+ knowledge_triple.subject,
+ knowledge_triple.object_,
+ relation=knowledge_triple.predicate,
+ )
+
+ def delete_triple(self, knowledge_triple: KnowledgeTriple) -> None:
+ """Delete a triple from the graph."""
+ if self._graph.has_edge(knowledge_triple.subject, knowledge_triple.object_):
+ self._graph.remove_edge(knowledge_triple.subject, knowledge_triple.object_)
+
+ def get_triples(self) -> List[Tuple[str, str, str]]:
+ """Get all triples in the graph."""
+ return [(u, v, d["relation"]) for u, v, d in self._graph.edges(data=True)]
+
+ def get_entity_knowledge(self, entity: str, depth: int = 1) -> List[str]:
+ """Get information about an entity."""
+ import networkx as nx
+
+ # TODO: Have more information-specific retrieval methods
+ if not self._graph.has_node(entity):
+ return []
+
+ results = []
+ for src, sink in nx.dfs_edges(self._graph, entity, depth_limit=depth):
+ relation = self._graph[src][sink]["relation"]
+ results.append(f"{src} {relation} {sink}")
+ return results
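+
+ # Illustrative usage sketch (entity names are examples only):
+ #   g = NetworkxEntityGraph()
+ #   g.add_triple(KnowledgeTriple("Paris", "is capital of", "France"))
+ #   g.get_entity_knowledge("Paris")  # -> ["Paris is capital of France"]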
+
+ def write_to_gml(self, path: str) -> None:
+ import networkx as nx
+
+ nx.write_gml(self._graph, path)
+
+ def clear(self) -> None:
+ """Clear the graph."""
+ self._graph.clear()
+
+ def get_topological_sort(self) -> List[str]:
+ """Get a list of entity names in the graph sorted by causal dependence."""
+ import networkx as nx
+
+ return list(nx.topological_sort(self._graph))
+
+ def draw_graphviz(self, **kwargs: Any) -> None:
+ """
+ Provides better drawing of the graph by rendering it with pygraphviz.
+
+ Usage in a jupyter notebook:
+
+ >>> from IPython.display import SVG
+ >>> self.draw_graphviz(prog="dot", path="web.svg")
+ >>> SVG('web.svg')
+ """
+ from networkx.drawing.nx_agraph import to_agraph
+
+ try:
+ import pygraphviz # noqa: F401
+
+ except ImportError as e:
+ if e.name == "_graphviz":
+ """
+ >>> e.msg # pygraphviz throws this error
+ ImportError: libcgraph.so.6: cannot open shared object file
+ """
+ raise ImportError(
+ "Could not import graphviz debian package. "
+ "Please install it with:"
+ "`sudo apt-get update`"
+ "`sudo apt-get install graphviz graphviz-dev`"
+ )
+ else:
+ raise ImportError(
+ "Could not import pygraphviz python package. "
+ "Please install it with:"
+ "`pip install pygraphviz`."
+ )
+
+ graph = to_agraph(self._graph) # --> pygraphviz.agraph.AGraph
+ # pygraphviz.github.io/documentation/stable/tutorial.html#layout-and-drawing
+ graph.layout(prog=kwargs.get("prog", "dot"))
+ graph.draw(kwargs.get("path", "graph.svg"))
diff --git a/libs/community/langchain_community/graphs/rdf_graph.py b/libs/community/langchain_community/graphs/rdf_graph.py
new file mode 100644
index 00000000000..1a2b89ba87b
--- /dev/null
+++ b/libs/community/langchain_community/graphs/rdf_graph.py
@@ -0,0 +1,297 @@
+from __future__ import annotations
+
+from typing import (
+ TYPE_CHECKING,
+ List,
+ Optional,
+)
+
+if TYPE_CHECKING:
+ import rdflib
+
+prefixes = {
+ "owl": """PREFIX owl: \n""",
+ "rdf": """PREFIX rdf: \n""",
+ "rdfs": """PREFIX rdfs: \n""",
+ "xsd": """PREFIX xsd: \n""",
+}
+
+cls_query_rdf = prefixes["rdfs"] + (
+ """SELECT DISTINCT ?cls ?com\n"""
+ """WHERE { \n"""
+ """ ?instance a ?cls . \n"""
+ """ OPTIONAL { ?cls rdfs:comment ?com } \n"""
+ """}"""
+)
+
+cls_query_rdfs = prefixes["rdfs"] + (
+ """SELECT DISTINCT ?cls ?com\n"""
+ """WHERE { \n"""
+ """ ?instance a/rdfs:subClassOf* ?cls . \n"""
+ """ OPTIONAL { ?cls rdfs:comment ?com } \n"""
+ """}"""
+)
+
+cls_query_owl = prefixes["rdfs"] + (
+ """SELECT DISTINCT ?cls ?com\n"""
+ """WHERE { \n"""
+ """ ?instance a/rdfs:subClassOf* ?cls . \n"""
+ """ FILTER (isIRI(?cls)) . \n"""
+ """ OPTIONAL { ?cls rdfs:comment ?com } \n"""
+ """}"""
+)
+
+rel_query_rdf = prefixes["rdfs"] + (
+ """SELECT DISTINCT ?rel ?com\n"""
+ """WHERE { \n"""
+ """ ?subj ?rel ?obj . \n"""
+ """ OPTIONAL { ?rel rdfs:comment ?com } \n"""
+ """}"""
+)
+
+rel_query_rdfs = (
+ prefixes["rdf"]
+ + prefixes["rdfs"]
+ + (
+ """SELECT DISTINCT ?rel ?com\n"""
+ """WHERE { \n"""
+ """ ?rel a/rdfs:subPropertyOf* rdf:Property . \n"""
+ """ OPTIONAL { ?rel rdfs:comment ?com } \n"""
+ """}"""
+ )
+)
+
+op_query_owl = (
+ prefixes["rdfs"]
+ + prefixes["owl"]
+ + (
+ """SELECT DISTINCT ?op ?com\n"""
+ """WHERE { \n"""
+ """ ?op a/rdfs:subPropertyOf* owl:ObjectProperty . \n"""
+ """ OPTIONAL { ?op rdfs:comment ?com } \n"""
+ """}"""
+ )
+)
+
+dp_query_owl = (
+ prefixes["rdfs"]
+ + prefixes["owl"]
+ + (
+ """SELECT DISTINCT ?dp ?com\n"""
+ """WHERE { \n"""
+ """ ?dp a/rdfs:subPropertyOf* owl:DatatypeProperty . \n"""
+ """ OPTIONAL { ?dp rdfs:comment ?com } \n"""
+ """}"""
+ )
+)
+
+
+class RdfGraph:
+ """RDFlib wrapper for graph operations.
+
+ Modes:
+ * local: Local file - can be queried and changed
+ * online: Online file - can only be queried, changes can be stored locally
+ * store: Triple store - can be queried and changed if update_endpoint available
+ Together with a source file, the serialization should be specified.
+
+ *Security note*: Make sure that the database connection uses credentials
+ that are narrowly-scoped to only include necessary permissions.
+ Failure to do so may result in data corruption or loss, since the calling
+ code may attempt commands that would result in deletion or mutation
+ of data if appropriately prompted, or reading sensitive data if such
+ data is present in the database.
+ The best way to guard against such negative outcomes is to (as appropriate)
+ limit the permissions granted to the credentials used with this tool.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ def __init__(
+ self,
+ source_file: Optional[str] = None,
+ serialization: Optional[str] = "ttl",
+ query_endpoint: Optional[str] = None,
+ update_endpoint: Optional[str] = None,
+ standard: Optional[str] = "rdf",
+ local_copy: Optional[str] = None,
+ ) -> None:
+ """
+ Set up the RDFlib graph
+
+ :param source_file: either a path for a local file or a URL
+ :param serialization: serialization of the input
+ :param query_endpoint: SPARQL endpoint for queries, read access
+ :param update_endpoint: SPARQL endpoint for UPDATE queries, write access
+ :param standard: RDF, RDFS, or OWL
+ :param local_copy: new local copy for storing changes
+ """
+ self.source_file = source_file
+ self.serialization = serialization
+ self.query_endpoint = query_endpoint
+ self.update_endpoint = update_endpoint
+ self.standard = standard
+ self.local_copy = local_copy
+
+ try:
+ import rdflib
+ from rdflib.graph import DATASET_DEFAULT_GRAPH_ID as default
+ from rdflib.plugins.stores import sparqlstore
+ except ImportError:
+ raise ValueError(
+ "Could not import rdflib python package. "
+ "Please install it with `pip install rdflib`."
+ )
+ if self.standard not in (supported_standards := ("rdf", "rdfs", "owl")):
+ raise ValueError(
+ f"Invalid standard. Supported standards are: {supported_standards}."
+ )
+
+ if (
+ not source_file
+ and not query_endpoint
+ or source_file
+ and (query_endpoint or update_endpoint)
+ ):
+ raise ValueError(
+ "Could not unambiguously initialize the graph wrapper. "
+ "Specify either a file (local or online) via the source_file "
+ "or a triple store via the endpoints."
+ )
+
+ if source_file:
+ if source_file.startswith("http"):
+ self.mode = "online"
+ else:
+ self.mode = "local"
+ if self.local_copy is None:
+ self.local_copy = self.source_file
+ self.graph = rdflib.Graph()
+ self.graph.parse(source_file, format=self.serialization)
+
+ if query_endpoint:
+ self.mode = "store"
+ if not update_endpoint:
+ self._store = sparqlstore.SPARQLStore()
+ self._store.open(query_endpoint)
+ else:
+ self._store = sparqlstore.SPARQLUpdateStore()
+ self._store.open((query_endpoint, update_endpoint))
+ self.graph = rdflib.Graph(self._store, identifier=default)
+
+ # Verify that the graph was loaded
+ if not len(self.graph):
+ raise AssertionError("The graph is empty.")
+
+ # Set schema
+ self.schema = ""
+ self.load_schema()
+
+ @property
+ def get_schema(self) -> str:
+ """
+ Returns the schema of the graph database.
+ """
+ return self.schema
+
+ def query(
+ self,
+ query: str,
+ ) -> List[rdflib.query.ResultRow]:
+ """
+ Query the graph.
+ """
+ from rdflib.exceptions import ParserError
+ from rdflib.query import ResultRow
+
+ try:
+ res = self.graph.query(query)
+ except ParserError as e:
+ raise ValueError("Generated SPARQL statement is invalid\n" f"{e}")
+ return [r for r in res if isinstance(r, ResultRow)]
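+
+ # Illustrative usage sketch (assumes a local Turtle file "schema.ttl" exists):
+ #   graph = RdfGraph(source_file="schema.ttl", serialization="ttl", standard="rdfs")
+ #   rows = graph.query("SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 5")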
+
+ def update(
+ self,
+ query: str,
+ ) -> None:
+ """
+ Update the graph.
+ """
+ from rdflib.exceptions import ParserError
+
+ try:
+ self.graph.update(query)
+ except ParserError as e:
+ raise ValueError("Generated SPARQL statement is invalid\n" f"{e}")
+ if self.local_copy:
+ self.graph.serialize(
+ destination=self.local_copy, format=self.local_copy.split(".")[-1]
+ )
+ else:
+ raise ValueError("No target file specified for saving the updated file.")
+
+ @staticmethod
+ def _get_local_name(iri: str) -> str:
+ if "#" in iri:
+ local_name = iri.split("#")[-1]
+ elif "/" in iri:
+ local_name = iri.split("/")[-1]
+ else:
+ raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")
+ return local_name
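+
+ # e.g. _get_local_name("http://www.w3.org/2000/01/rdf-schema#label") -> "label"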
+
+ def _res_to_str(self, res: rdflib.query.ResultRow, var: str) -> str:
+ return (
+ "<"
+ + str(res[var])
+ + "> ("
+ + self._get_local_name(res[var])
+ + ", "
+ + str(res["com"])
+ + ")"
+ )
+
+ def load_schema(self) -> None:
+ """
+ Load the graph schema information.
+ """
+
+ def _rdf_s_schema(
+ classes: List[rdflib.query.ResultRow],
+ relationships: List[rdflib.query.ResultRow],
+ ) -> str:
+ return (
+ f"In the following, each IRI is followed by the local name and "
+ f"optionally its description in parentheses. \n"
+ f"The RDF graph supports the following node types:\n"
+ f'{", ".join([self._res_to_str(r, "cls") for r in classes])}\n'
+ f"The RDF graph supports the following relationships:\n"
+ f'{", ".join([self._res_to_str(r, "rel") for r in relationships])}\n'
+ )
+
+ if self.standard == "rdf":
+ clss = self.query(cls_query_rdf)
+ rels = self.query(rel_query_rdf)
+ self.schema = _rdf_s_schema(clss, rels)
+ elif self.standard == "rdfs":
+ clss = self.query(cls_query_rdfs)
+ rels = self.query(rel_query_rdfs)
+ self.schema = _rdf_s_schema(clss, rels)
+ elif self.standard == "owl":
+ clss = self.query(cls_query_owl)
+ ops = self.query(op_query_owl)
+ dps = self.query(dp_query_owl)
+ self.schema = (
+ f"In the following, each IRI is followed by the local name and "
+ f"optionally its description in parentheses. \n"
+ f"The OWL graph supports the following node types:\n"
+ f'{", ".join([self._res_to_str(r, "cls") for r in clss])}\n'
+ f"The OWL graph supports the following object properties, "
+ f"i.e., relationships between objects:\n"
+ f'{", ".join([self._res_to_str(r, "op") for r in ops])}\n'
+ f"The OWL graph supports the following data properties, "
+ f"i.e., relationships between objects and literals:\n"
+ f'{", ".join([self._res_to_str(r, "dp") for r in dps])}\n'
+ )
+ else:
+ raise ValueError(f"Mode '{self.standard}' is currently not supported.")
diff --git a/libs/langchain/tests/integration_tests/graphs/__init__.py b/libs/community/langchain_community/indexes/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/graphs/__init__.py
rename to libs/community/langchain_community/indexes/__init__.py
diff --git a/libs/community/langchain_community/indexes/_sql_record_manager.py b/libs/community/langchain_community/indexes/_sql_record_manager.py
new file mode 100644
index 00000000000..544e828df2a
--- /dev/null
+++ b/libs/community/langchain_community/indexes/_sql_record_manager.py
@@ -0,0 +1,522 @@
+"""Implementation of a record management layer in SQLAlchemy.
+
+The management layer uses SQLAlchemy to track upserted records.
+
+Currently, this layer only works with SQLite; however, it should be adaptable
+to other SQL implementations with minimal effort.
+
+Currently, it includes an implementation that uses SQLAlchemy, which should
+allow it to work with a variety of SQL backends.
+
+* Each key is associated with an updated_at field.
+* This field is updated whenever the key is updated.
+* Keys can be listed based on the updated_at field.
+* Keys can be deleted.
+"""
+import contextlib
+import decimal
+import uuid
+from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union
+
+from sqlalchemy import (
+ URL,
+ Column,
+ Engine,
+ Float,
+ Index,
+ String,
+ UniqueConstraint,
+ and_,
+ create_engine,
+ delete,
+ select,
+ text,
+)
+from sqlalchemy.ext.asyncio import (
+ AsyncEngine,
+ AsyncSession,
+ async_sessionmaker,
+ create_async_engine,
+)
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session, sessionmaker
+
+from langchain_community.indexes.base import RecordManager
+
+Base = declarative_base()
+
+
+class UpsertionRecord(Base): # type: ignore[valid-type,misc]
+ """Table used to keep track of when a key was last updated."""
+
+ # ATTENTION:
+ # Prior to modifying this table, please determine whether
+ # we should create migrations for this table to make sure
+ # users do not experience data loss.
+ __tablename__ = "upsertion_record"
+
+ uuid = Column(
+ String,
+ index=True,
+ default=lambda: str(uuid.uuid4()),
+ primary_key=True,
+ nullable=False,
+ )
+ key = Column(String, index=True)
+ # Using a non-normalized representation to handle `namespace` attribute.
+ # If the need arises, this attribute can be pulled into a separate Collection
+ # table at some time later.
+ namespace = Column(String, index=True, nullable=False)
+ group_id = Column(String, index=True, nullable=True)
+
+ # The timestamp associated with the last record upsertion.
+ updated_at = Column(Float, index=True)
+
+ __table_args__ = (
+ UniqueConstraint("key", "namespace", name="uix_key_namespace"),
+ Index("ix_key_namespace", "key", "namespace"),
+ )
+
+
+class SQLRecordManager(RecordManager):
+ """A SQL Alchemy based implementation of the record manager."""
+
+ def __init__(
+ self,
+ namespace: str,
+ *,
+ engine: Optional[Union[Engine, AsyncEngine]] = None,
+ db_url: Union[None, str, URL] = None,
+ engine_kwargs: Optional[Dict[str, Any]] = None,
+ async_mode: bool = False,
+ ) -> None:
+ """Initialize the SQLRecordManager.
+
+ This class serves as a persistence layer that uses a SQL
+ backend to track upserted records. You should specify either a db_url
+ to create an engine or provide an existing engine.
+
+ Args:
+ namespace: The namespace associated with this record manager.
+ engine: An already existing SQL Alchemy engine.
+ Default is None.
+ db_url: A database connection string used to create
+ an SQL Alchemy engine. Default is None.
+ engine_kwargs: Additional keyword arguments
+ to be passed when creating the engine. Default is an empty dictionary.
+ async_mode: Whether to create an async engine.
+ Driver should support async operations.
+ It only applies if db_url is provided.
+ Default is False.
+
+ Raises:
+ ValueError: If both db_url and engine are provided or neither.
+ AssertionError: If something unexpected happens during engine configuration.
+ """
+ super().__init__(namespace=namespace)
+ if db_url is None and engine is None:
+ raise ValueError("Must specify either db_url or engine")
+
+ if db_url is not None and engine is not None:
+ raise ValueError("Must specify either db_url or engine, not both")
+
+ _engine: Union[Engine, AsyncEngine]
+ if db_url:
+ if async_mode:
+ _engine = create_async_engine(db_url, **(engine_kwargs or {}))
+ else:
+ _engine = create_engine(db_url, **(engine_kwargs or {}))
+ elif engine:
+ _engine = engine
+
+ else:
+ raise AssertionError("Something went wrong with configuration of engine.")
+
+ _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
+ if isinstance(_engine, AsyncEngine):
+ _session_factory = async_sessionmaker(bind=_engine)
+ else:
+ _session_factory = sessionmaker(bind=_engine)
+
+ self.engine = _engine
+ self.dialect = _engine.dialect.name
+ self.session_factory = _session_factory
+
+ def create_schema(self) -> None:
+ """Create the database schema."""
+ if isinstance(self.engine, AsyncEngine):
+ raise AssertionError("This method is not supported for async engines.")
+
+ Base.metadata.create_all(self.engine)
+
+ async def acreate_schema(self) -> None:
+ """Create the database schema."""
+
+ if not isinstance(self.engine, AsyncEngine):
+ raise AssertionError("This method is not supported for sync engines.")
+
+ async with self.engine.begin() as session:
+ await session.run_sync(Base.metadata.create_all)
+
+ @contextlib.contextmanager
+ def _make_session(self) -> Generator[Session, None, None]:
+ """Create a session and close it after use."""
+
+ if isinstance(self.session_factory, async_sessionmaker):
+ raise AssertionError("This method is not supported for async engines.")
+
+ session = self.session_factory()
+ try:
+ yield session
+ finally:
+ session.close()
+
+ @contextlib.asynccontextmanager
+ async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
+ """Create a session and close it after use."""
+
+ if not isinstance(self.session_factory, async_sessionmaker):
+ raise AssertionError("This method is not supported for sync engines.")
+
+ async with self.session_factory() as session:
+ yield session
+
+ def get_time(self) -> float:
+ """Get the current server time as a timestamp.
+
+ Please note it's critical that time is obtained from the server since
+ we want a monotonic clock.
+ """
+ with self._make_session() as session:
+ # * SQLite specific implementation, can be changed based on dialect.
+ # * For SQLite: unlike unixepoch(), this works with older versions of SQLite.
+ # ----
+ # julianday('now'): Julian day number for the current date and time.
+ # The Julian day is a continuous count of days, starting from a
+ # reference date (Julian day number 0).
+ # 2440587.5 - constant represents the Julian day number for January 1, 1970
+ # 86400.0 - constant represents the number of seconds
+ # in a day (24 hours * 60 minutes * 60 seconds)
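+ # For example, at the Unix epoch julianday('now') would be 2440587.5,
+ # so (2440587.5 - 2440587.5) * 86400.0 = 0 seconds, as expected.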
+ if self.dialect == "sqlite":
+ query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
+ elif self.dialect == "postgresql":
+ query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
+ else:
+ raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
+
+ dt = session.execute(query).scalar()
+ if isinstance(dt, decimal.Decimal):
+ dt = float(dt)
+ if not isinstance(dt, float):
+ raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
+ return dt
+
+ async def aget_time(self) -> float:
+ """Get the current server time as a timestamp.
+
+ Please note it's critical that time is obtained from the server since
+ we want a monotonic clock.
+ """
+ async with self._amake_session() as session:
+ # * SQLite specific implementation, can be changed based on dialect.
+ # * For SQLite: unlike unixepoch(), this works with older versions of SQLite.
+ # ----
+ # julianday('now'): Julian day number for the current date and time.
+ # The Julian day is a continuous count of days, starting from a
+ # reference date (Julian day number 0).
+ # 2440587.5 - constant represents the Julian day number for January 1, 1970
+ # 86400.0 - constant represents the number of seconds
+ # in a day (24 hours * 60 minutes * 60 seconds)
+ if self.dialect == "sqlite":
+ query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
+ elif self.dialect == "postgresql":
+ query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
+ else:
+ raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
+
+ dt = (await session.execute(query)).scalar_one_or_none()
+
+ if isinstance(dt, decimal.Decimal):
+ dt = float(dt)
+ if not isinstance(dt, float):
+ raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
+ return dt
+
+ def update(
+ self,
+ keys: Sequence[str],
+ *,
+ group_ids: Optional[Sequence[Optional[str]]] = None,
+ time_at_least: Optional[float] = None,
+ ) -> None:
+ """Upsert records into the SQLite database."""
+ if group_ids is None:
+ group_ids = [None] * len(keys)
+
+ if len(keys) != len(group_ids):
+ raise ValueError(
+ f"Number of keys ({len(keys)}) does not match number of "
+ f"group_ids ({len(group_ids)})"
+ )
+
+ # Get the current time from the server.
+ # This makes an extra round trip to the server, should not be a big deal
+ # if the batch size is large enough.
+ # Getting the time here helps us compare it against the time_at_least
+ # and raise an error if there is a time sync issue.
+ # Here, we're just being extra careful to minimize the chance of
+ # data loss due to incorrectly deleting records.
+ update_time = self.get_time()
+
+ if time_at_least and update_time < time_at_least:
+ # Safeguard against time sync issues
+ raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
+
+ records_to_upsert = [
+ {
+ "key": key,
+ "namespace": self.namespace,
+ "updated_at": update_time,
+ "group_id": group_id,
+ }
+ for key, group_id in zip(keys, group_ids)
+ ]
+
+ with self._make_session() as session:
+ if self.dialect == "sqlite":
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+ # Note: uses SQLite insert to make on_conflict_do_update work.
+ # This code needs to be generalized a bit to work with more dialects.
+ insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
+ stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
+ [UpsertionRecord.key, UpsertionRecord.namespace],
+ set_=dict(
+ # attr-defined type ignore
+ updated_at=insert_stmt.excluded.updated_at, # type: ignore
+ group_id=insert_stmt.excluded.group_id, # type: ignore
+ ),
+ )
+ elif self.dialect == "postgresql":
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+ # Note: uses PostgreSQL insert to make on_conflict_do_update work.
+ # This code needs to be generalized a bit to work with more dialects.
+ insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
+ stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
+ "uix_key_namespace", # Name of constraint
+ set_=dict(
+ # attr-defined type ignore
+ updated_at=insert_stmt.excluded.updated_at, # type: ignore
+ group_id=insert_stmt.excluded.group_id, # type: ignore
+ ),
+ )
+ else:
+ raise NotImplementedError(f"Unsupported dialect {self.dialect}")
+
+ session.execute(stmt)
+ session.commit()
+
+ async def aupdate(
+ self,
+ keys: Sequence[str],
+ *,
+ group_ids: Optional[Sequence[Optional[str]]] = None,
+ time_at_least: Optional[float] = None,
+ ) -> None:
+ """Upsert records into the SQLite database."""
+ if group_ids is None:
+ group_ids = [None] * len(keys)
+
+ if len(keys) != len(group_ids):
+ raise ValueError(
+ f"Number of keys ({len(keys)}) does not match number of "
+ f"group_ids ({len(group_ids)})"
+ )
+
+ # Get the current time from the server.
+ # This makes an extra round trip to the server, should not be a big deal
+ # if the batch size is large enough.
+ # Getting the time here helps us compare it against the time_at_least
+ # and raise an error if there is a time sync issue.
+ # Here, we're just being extra careful to minimize the chance of
+ # data loss due to incorrectly deleting records.
+ update_time = await self.aget_time()
+
+ if time_at_least and update_time < time_at_least:
+ # Safeguard against time sync issues
+ raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
+
+ records_to_upsert = [
+ {
+ "key": key,
+ "namespace": self.namespace,
+ "updated_at": update_time,
+ "group_id": group_id,
+ }
+ for key, group_id in zip(keys, group_ids)
+ ]
+
+ async with self._amake_session() as session:
+ if self.dialect == "sqlite":
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+ # Note: uses SQLite insert to make on_conflict_do_update work.
+ # This code needs to be generalized a bit to work with more dialects.
+ insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
+ stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
+ [UpsertionRecord.key, UpsertionRecord.namespace],
+ set_=dict(
+ # attr-defined type ignore
+ updated_at=insert_stmt.excluded.updated_at, # type: ignore
+ group_id=insert_stmt.excluded.group_id, # type: ignore
+ ),
+ )
+ elif self.dialect == "postgresql":
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+ # Note: uses PostgreSQL insert to make on_conflict_do_update work.
+ # This code needs to be generalized a bit to work with more dialects.
+ insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
+ stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
+ "uix_key_namespace", # Name of constraint
+ set_=dict(
+ # attr-defined type ignore
+ updated_at=insert_stmt.excluded.updated_at, # type: ignore
+ group_id=insert_stmt.excluded.group_id, # type: ignore
+ ),
+ )
+ else:
+ raise NotImplementedError(f"Unsupported dialect {self.dialect}")
+
+ await session.execute(stmt)
+ await session.commit()
+
+ def exists(self, keys: Sequence[str]) -> List[bool]:
+ """Check if the given keys exist in the SQLite database."""
+ with self._make_session() as session:
+ records = (
+ # mypy does not recognize .all()
+ session.query(UpsertionRecord.key) # type: ignore[attr-defined]
+ .filter(
+ and_(
+ UpsertionRecord.key.in_(keys),
+ UpsertionRecord.namespace == self.namespace,
+ )
+ )
+ .all()
+ )
+ found_keys = set(r.key for r in records)
+ return [k in found_keys for k in keys]
+
+ async def aexists(self, keys: Sequence[str]) -> List[bool]:
+ """Check if the given keys exist in the SQLite database."""
+ async with self._amake_session() as session:
+ records = (
+ (
+ await session.execute(
+ select(UpsertionRecord.key).where(
+ and_(
+ UpsertionRecord.key.in_(keys),
+ UpsertionRecord.namespace == self.namespace,
+ )
+ )
+ )
+ )
+ .scalars()
+ .all()
+ )
+ found_keys = set(records)
+ return [k in found_keys for k in keys]
+
+ def list_keys(
+ self,
+ *,
+ before: Optional[float] = None,
+ after: Optional[float] = None,
+ group_ids: Optional[Sequence[str]] = None,
+ limit: Optional[int] = None,
+ ) -> List[str]:
+ """List records in the SQLite database based on the provided date range."""
+ with self._make_session() as session:
+ query = session.query(UpsertionRecord).filter(
+ UpsertionRecord.namespace == self.namespace
+ )
+
+ # mypy does not recognize .all() or .filter()
+ if after:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.updated_at > after
+ )
+ if before:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.updated_at < before
+ )
+ if group_ids:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.group_id.in_(group_ids)
+ )
+
+ if limit:
+ query = query.limit(limit) # type: ignore[attr-defined]
+ records = query.all() # type: ignore[attr-defined]
+ return [r.key for r in records]
+
+ async def alist_keys(
+ self,
+ *,
+ before: Optional[float] = None,
+ after: Optional[float] = None,
+ group_ids: Optional[Sequence[str]] = None,
+ limit: Optional[int] = None,
+ ) -> List[str]:
+ """List records in the SQLite database based on the provided date range."""
+ async with self._amake_session() as session:
+ query = select(UpsertionRecord.key).filter(
+ UpsertionRecord.namespace == self.namespace
+ )
+
+ # mypy does not recognize .all() or .filter()
+ if after:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.updated_at > after
+ )
+ if before:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.updated_at < before
+ )
+ if group_ids:
+ query = query.filter( # type: ignore[attr-defined]
+ UpsertionRecord.group_id.in_(group_ids)
+ )
+
+ if limit:
+ query = query.limit(limit) # type: ignore[attr-defined]
+ records = (await session.execute(query)).scalars().all()
+ return list(records)
+
+ def delete_keys(self, keys: Sequence[str]) -> None:
+ """Delete records from the SQLite database."""
+ with self._make_session() as session:
+ # mypy does not recognize .delete()
+ session.query(UpsertionRecord).filter(
+ and_(
+ UpsertionRecord.key.in_(keys),
+ UpsertionRecord.namespace == self.namespace,
+ )
+ ).delete() # type: ignore[attr-defined]
+ session.commit()
+
+ async def adelete_keys(self, keys: Sequence[str]) -> None:
+ """Delete records from the SQLite database."""
+ async with self._amake_session() as session:
+ await session.execute(
+ delete(UpsertionRecord).where(
+ and_(
+ UpsertionRecord.key.in_(keys),
+ UpsertionRecord.namespace == self.namespace,
+ )
+ )
+ )
+
+ await session.commit()
diff --git a/libs/community/langchain_community/indexes/base.py b/libs/community/langchain_community/indexes/base.py
new file mode 100644
index 00000000000..46ef5bf2efa
--- /dev/null
+++ b/libs/community/langchain_community/indexes/base.py
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+import uuid
+from abc import ABC, abstractmethod
+from typing import List, Optional, Sequence
+
+NAMESPACE_UUID = uuid.UUID(int=1984)
+
+
+class RecordManager(ABC):
+ """An abstract base class representing the interface for a record manager."""
+
+ def __init__(
+ self,
+ namespace: str,
+ ) -> None:
+ """Initialize the record manager.
+
+ Args:
+ namespace (str): The namespace for the record manager.
+ """
+ self.namespace = namespace
+
+ @abstractmethod
+ def create_schema(self) -> None:
+ """Create the database schema for the record manager."""
+
+ @abstractmethod
+ async def acreate_schema(self) -> None:
+ """Create the database schema for the record manager."""
+
+ @abstractmethod
+ def get_time(self) -> float:
+ """Get the current server time as a high resolution timestamp!
+
+ It's important to get this from the server to ensure a monotonic clock,
+ otherwise there may be data loss when cleaning up old documents!
+
+ Returns:
+ The current server time as a float timestamp.
+ """
+
+ @abstractmethod
+ async def aget_time(self) -> float:
+ """Get the current server time as a high resolution timestamp!
+
+ It's important to get this from the server to ensure a monotonic clock,
+ otherwise there may be data loss when cleaning up old documents!
+
+ Returns:
+ The current server time as a float timestamp.
+ """
+
+ @abstractmethod
+ def update(
+ self,
+ keys: Sequence[str],
+ *,
+ group_ids: Optional[Sequence[Optional[str]]] = None,
+ time_at_least: Optional[float] = None,
+ ) -> None:
+ """Upsert records into the database.
+
+ Args:
+ keys: A list of record keys to upsert.
+ group_ids: A list of group IDs corresponding to the keys.
+ time_at_least: if provided, updates should only happen if the
+ updated_at field is at least this time.
+
+ Raises:
+ ValueError: If the length of keys doesn't match the length of group_ids.
+ """
+
+ @abstractmethod
+ async def aupdate(
+ self,
+ keys: Sequence[str],
+ *,
+ group_ids: Optional[Sequence[Optional[str]]] = None,
+ time_at_least: Optional[float] = None,
+ ) -> None:
+ """Upsert records into the database.
+
+ Args:
+ keys: A list of record keys to upsert.
+ group_ids: A list of group IDs corresponding to the keys.
+ time_at_least: if provided, updates should only happen if the
+ updated_at field is at least this time.
+
+ Raises:
+ ValueError: If the length of keys doesn't match the length of group_ids.
+ """
+
+ @abstractmethod
+ def exists(self, keys: Sequence[str]) -> List[bool]:
+ """Check if the provided keys exist in the database.
+
+ Args:
+ keys: A list of keys to check.
+
+ Returns:
+ A list of boolean values indicating the existence of each key.
+ """
+
+ @abstractmethod
+ async def aexists(self, keys: Sequence[str]) -> List[bool]:
+ """Check if the provided keys exist in the database.
+
+ Args:
+ keys: A list of keys to check.
+
+ Returns:
+ A list of boolean values indicating the existence of each key.
+ """
+
+ @abstractmethod
+ def list_keys(
+ self,
+ *,
+ before: Optional[float] = None,
+ after: Optional[float] = None,
+ group_ids: Optional[Sequence[str]] = None,
+ limit: Optional[int] = None,
+ ) -> List[str]:
+ """List records in the database based on the provided filters.
+
+ Args:
+ before: Filter to list records updated before this time.
+ after: Filter to list records updated after this time.
+ group_ids: Filter to list records with specific group IDs.
+ limit: optional limit on the number of records to return.
+
+ Returns:
+ A list of keys for the matching records.
+ """
+
+ @abstractmethod
+ async def alist_keys(
+ self,
+ *,
+ before: Optional[float] = None,
+ after: Optional[float] = None,
+ group_ids: Optional[Sequence[str]] = None,
+ limit: Optional[int] = None,
+ ) -> List[str]:
+ """List records in the database based on the provided filters.
+
+ Args:
+ before: Filter to list records updated before this time.
+ after: Filter to list records updated after this time.
+ group_ids: Filter to list records with specific group IDs.
+ limit: optional limit on the number of records to return.
+
+ Returns:
+ A list of keys for the matching records.
+ """
+
+ @abstractmethod
+ def delete_keys(self, keys: Sequence[str]) -> None:
+ """Delete specified records from the database.
+
+ Args:
+ keys: A list of keys to delete.
+ """
+
+ @abstractmethod
+ async def adelete_keys(self, keys: Sequence[str]) -> None:
+ """Delete specified records from the database.
+
+ Args:
+ keys: A list of keys to delete.
+ """
diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py
new file mode 100644
index 00000000000..729198aaa05
--- /dev/null
+++ b/libs/community/langchain_community/llms/__init__.py
@@ -0,0 +1,882 @@
+"""
+**LLM** classes provide
+access to the large language model (**LLM**) APIs and services.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseLanguageModel --> BaseLLM --> LLM --> <name>  # Examples: AI21, HuggingFaceHub, OpenAI
+
+**Main helpers:**
+
+.. code-block::
+
+ LLMResult, PromptValue,
+ CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
+ CallbackManager, AsyncCallbackManager,
+ AIMessage, BaseMessage
+""" # noqa: E501
+from typing import Any, Callable, Dict, Type
+
+from langchain_core.language_models.llms import BaseLLM
+
+
+def _import_ai21() -> Any:
+ from langchain_community.llms.ai21 import AI21
+
+ return AI21
+
+
+def _import_aleph_alpha() -> Any:
+ from langchain_community.llms.aleph_alpha import AlephAlpha
+
+ return AlephAlpha
+
+
+def _import_amazon_api_gateway() -> Any:
+ from langchain_community.llms.amazon_api_gateway import AmazonAPIGateway
+
+ return AmazonAPIGateway
+
+
+def _import_anthropic() -> Any:
+ from langchain_community.llms.anthropic import Anthropic
+
+ return Anthropic
+
+
+def _import_anyscale() -> Any:
+ from langchain_community.llms.anyscale import Anyscale
+
+ return Anyscale
+
+
+def _import_arcee() -> Any:
+ from langchain_community.llms.arcee import Arcee
+
+ return Arcee
+
+
+def _import_aviary() -> Any:
+ from langchain_community.llms.aviary import Aviary
+
+ return Aviary
+
+
+def _import_azureml_endpoint() -> Any:
+ from langchain_community.llms.azureml_endpoint import AzureMLOnlineEndpoint
+
+ return AzureMLOnlineEndpoint
+
+
+def _import_baidu_qianfan_endpoint() -> Any:
+ from langchain_community.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
+
+ return QianfanLLMEndpoint
+
+
+def _import_bananadev() -> Any:
+ from langchain_community.llms.bananadev import Banana
+
+ return Banana
+
+
+def _import_baseten() -> Any:
+ from langchain_community.llms.baseten import Baseten
+
+ return Baseten
+
+
+def _import_beam() -> Any:
+ from langchain_community.llms.beam import Beam
+
+ return Beam
+
+
+def _import_bedrock() -> Any:
+ from langchain_community.llms.bedrock import Bedrock
+
+ return Bedrock
+
+
+def _import_bittensor() -> Any:
+ from langchain_community.llms.bittensor import NIBittensorLLM
+
+ return NIBittensorLLM
+
+
+def _import_cerebriumai() -> Any:
+ from langchain_community.llms.cerebriumai import CerebriumAI
+
+ return CerebriumAI
+
+
+def _import_chatglm() -> Any:
+ from langchain_community.llms.chatglm import ChatGLM
+
+ return ChatGLM
+
+
+def _import_clarifai() -> Any:
+ from langchain_community.llms.clarifai import Clarifai
+
+ return Clarifai
+
+
+def _import_cohere() -> Any:
+ from langchain_community.llms.cohere import Cohere
+
+ return Cohere
+
+
+def _import_ctransformers() -> Any:
+ from langchain_community.llms.ctransformers import CTransformers
+
+ return CTransformers
+
+
+def _import_ctranslate2() -> Any:
+ from langchain_community.llms.ctranslate2 import CTranslate2
+
+ return CTranslate2
+
+
+def _import_databricks() -> Any:
+ from langchain_community.llms.databricks import Databricks
+
+ return Databricks
+
+
+def _import_databricks_chat() -> Any:
+ from langchain_community.chat_models.databricks import ChatDatabricks
+
+ return ChatDatabricks
+
+
+def _import_deepinfra() -> Any:
+ from langchain_community.llms.deepinfra import DeepInfra
+
+ return DeepInfra
+
+
+def _import_deepsparse() -> Any:
+ from langchain_community.llms.deepsparse import DeepSparse
+
+ return DeepSparse
+
+
+def _import_edenai() -> Any:
+ from langchain_community.llms.edenai import EdenAI
+
+ return EdenAI
+
+
+def _import_fake() -> Any:
+ from langchain_community.llms.fake import FakeListLLM
+
+ return FakeListLLM
+
+
+def _import_fireworks() -> Any:
+ from langchain_community.llms.fireworks import Fireworks
+
+ return Fireworks
+
+
+def _import_forefrontai() -> Any:
+ from langchain_community.llms.forefrontai import ForefrontAI
+
+ return ForefrontAI
+
+
+def _import_gigachat() -> Any:
+ from langchain_community.llms.gigachat import GigaChat
+
+ return GigaChat
+
+
+def _import_google_palm() -> Any:
+ from langchain_community.llms.google_palm import GooglePalm
+
+ return GooglePalm
+
+
+def _import_gooseai() -> Any:
+ from langchain_community.llms.gooseai import GooseAI
+
+ return GooseAI
+
+
+def _import_gpt4all() -> Any:
+ from langchain_community.llms.gpt4all import GPT4All
+
+ return GPT4All
+
+
+def _import_gradient_ai() -> Any:
+ from langchain_community.llms.gradient_ai import GradientLLM
+
+ return GradientLLM
+
+
+def _import_huggingface_endpoint() -> Any:
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+
+ return HuggingFaceEndpoint
+
+
+def _import_huggingface_hub() -> Any:
+ from langchain_community.llms.huggingface_hub import HuggingFaceHub
+
+ return HuggingFaceHub
+
+
+def _import_huggingface_pipeline() -> Any:
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+
+ return HuggingFacePipeline
+
+
+def _import_huggingface_text_gen_inference() -> Any:
+ from langchain_community.llms.huggingface_text_gen_inference import (
+ HuggingFaceTextGenInference,
+ )
+
+ return HuggingFaceTextGenInference
+
+
+def _import_human() -> Any:
+ from langchain_community.llms.human import HumanInputLLM
+
+ return HumanInputLLM
+
+
+def _import_javelin_ai_gateway() -> Any:
+ from langchain_community.llms.javelin_ai_gateway import JavelinAIGateway
+
+ return JavelinAIGateway
+
+
+def _import_koboldai() -> Any:
+ from langchain_community.llms.koboldai import KoboldApiLLM
+
+ return KoboldApiLLM
+
+
+def _import_llamacpp() -> Any:
+ from langchain_community.llms.llamacpp import LlamaCpp
+
+ return LlamaCpp
+
+
+def _import_manifest() -> Any:
+ from langchain_community.llms.manifest import ManifestWrapper
+
+ return ManifestWrapper
+
+
+def _import_minimax() -> Any:
+ from langchain_community.llms.minimax import Minimax
+
+ return Minimax
+
+
+def _import_mlflow() -> Any:
+ from langchain_community.llms.mlflow import Mlflow
+
+ return Mlflow
+
+
+def _import_mlflow_chat() -> Any:
+ from langchain_community.chat_models.mlflow import ChatMlflow
+
+ return ChatMlflow
+
+
+def _import_mlflow_ai_gateway() -> Any:
+ from langchain_community.llms.mlflow_ai_gateway import MlflowAIGateway
+
+ return MlflowAIGateway
+
+
+def _import_modal() -> Any:
+ from langchain_community.llms.modal import Modal
+
+ return Modal
+
+
+def _import_mosaicml() -> Any:
+ from langchain_community.llms.mosaicml import MosaicML
+
+ return MosaicML
+
+
+def _import_nlpcloud() -> Any:
+ from langchain_community.llms.nlpcloud import NLPCloud
+
+ return NLPCloud
+
+
+def _import_octoai_endpoint() -> Any:
+ from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
+
+ return OctoAIEndpoint
+
+
+def _import_ollama() -> Any:
+ from langchain_community.llms.ollama import Ollama
+
+ return Ollama
+
+
+def _import_opaqueprompts() -> Any:
+ from langchain_community.llms.opaqueprompts import OpaquePrompts
+
+ return OpaquePrompts
+
+
+def _import_azure_openai() -> Any:
+ from langchain_community.llms.openai import AzureOpenAI
+
+ return AzureOpenAI
+
+
+def _import_openai() -> Any:
+ from langchain_community.llms.openai import OpenAI
+
+ return OpenAI
+
+
+def _import_openai_chat() -> Any:
+ from langchain_community.llms.openai import OpenAIChat
+
+ return OpenAIChat
+
+
+def _import_openllm() -> Any:
+ from langchain_community.llms.openllm import OpenLLM
+
+ return OpenLLM
+
+
+def _import_openlm() -> Any:
+ from langchain_community.llms.openlm import OpenLM
+
+ return OpenLM
+
+
+def _import_pai_eas_endpoint() -> Any:
+ from langchain_community.llms.pai_eas_endpoint import PaiEasEndpoint
+
+ return PaiEasEndpoint
+
+
+def _import_petals() -> Any:
+ from langchain_community.llms.petals import Petals
+
+ return Petals
+
+
+def _import_pipelineai() -> Any:
+ from langchain_community.llms.pipelineai import PipelineAI
+
+ return PipelineAI
+
+
+def _import_predibase() -> Any:
+ from langchain_community.llms.predibase import Predibase
+
+ return Predibase
+
+
+def _import_predictionguard() -> Any:
+ from langchain_community.llms.predictionguard import PredictionGuard
+
+ return PredictionGuard
+
+
+def _import_promptlayer() -> Any:
+ from langchain_community.llms.promptlayer_openai import PromptLayerOpenAI
+
+ return PromptLayerOpenAI
+
+
+def _import_promptlayer_chat() -> Any:
+ from langchain_community.llms.promptlayer_openai import PromptLayerOpenAIChat
+
+ return PromptLayerOpenAIChat
+
+
+def _import_replicate() -> Any:
+ from langchain_community.llms.replicate import Replicate
+
+ return Replicate
+
+
+def _import_rwkv() -> Any:
+ from langchain_community.llms.rwkv import RWKV
+
+ return RWKV
+
+
+def _import_sagemaker_endpoint() -> Any:
+ from langchain_community.llms.sagemaker_endpoint import SagemakerEndpoint
+
+ return SagemakerEndpoint
+
+
+def _import_self_hosted() -> Any:
+ from langchain_community.llms.self_hosted import SelfHostedPipeline
+
+ return SelfHostedPipeline
+
+
+def _import_self_hosted_hugging_face() -> Any:
+ from langchain_community.llms.self_hosted_hugging_face import (
+ SelfHostedHuggingFaceLLM,
+ )
+
+ return SelfHostedHuggingFaceLLM
+
+
+def _import_stochasticai() -> Any:
+ from langchain_community.llms.stochasticai import StochasticAI
+
+ return StochasticAI
+
+
+def _import_symblai_nebula() -> Any:
+ from langchain_community.llms.symblai_nebula import Nebula
+
+ return Nebula
+
+
+def _import_textgen() -> Any:
+ from langchain_community.llms.textgen import TextGen
+
+ return TextGen
+
+
+def _import_titan_takeoff() -> Any:
+ from langchain_community.llms.titan_takeoff import TitanTakeoff
+
+ return TitanTakeoff
+
+
+def _import_titan_takeoff_pro() -> Any:
+ from langchain_community.llms.titan_takeoff_pro import TitanTakeoffPro
+
+ return TitanTakeoffPro
+
+
+def _import_together() -> Any:
+ from langchain_community.llms.together import Together
+
+ return Together
+
+
+def _import_tongyi() -> Any:
+ from langchain_community.llms.tongyi import Tongyi
+
+ return Tongyi
+
+
+def _import_vertex() -> Any:
+ from langchain_community.llms.vertexai import VertexAI
+
+ return VertexAI
+
+
+def _import_vertex_model_garden() -> Any:
+ from langchain_community.llms.vertexai import VertexAIModelGarden
+
+ return VertexAIModelGarden
+
+
+def _import_vllm() -> Any:
+ from langchain_community.llms.vllm import VLLM
+
+ return VLLM
+
+
+def _import_vllm_openai() -> Any:
+ from langchain_community.llms.vllm import VLLMOpenAI
+
+ return VLLMOpenAI
+
+
+def _import_watsonxllm() -> Any:
+ from langchain_community.llms.watsonxllm import WatsonxLLM
+
+ return WatsonxLLM
+
+
+def _import_writer() -> Any:
+ from langchain_community.llms.writer import Writer
+
+ return Writer
+
+
+def _import_xinference() -> Any:
+ from langchain_community.llms.xinference import Xinference
+
+ return Xinference
+
+
+def _import_yandex_gpt() -> Any:
+ from langchain_community.llms.yandex import YandexGPT
+
+ return YandexGPT
+
+
+def _import_volcengine_maas() -> Any:
+ from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
+
+ return VolcEngineMaasLLM
+
+
+def __getattr__(name: str) -> Any:
+ if name == "AI21":
+ return _import_ai21()
+ elif name == "AlephAlpha":
+ return _import_aleph_alpha()
+ elif name == "AmazonAPIGateway":
+ return _import_amazon_api_gateway()
+ elif name == "Anthropic":
+ return _import_anthropic()
+ elif name == "Anyscale":
+ return _import_anyscale()
+ elif name == "Arcee":
+ return _import_arcee()
+ elif name == "Aviary":
+ return _import_aviary()
+ elif name == "AzureMLOnlineEndpoint":
+ return _import_azureml_endpoint()
+ elif name == "QianfanLLMEndpoint":
+ return _import_baidu_qianfan_endpoint()
+ elif name == "Banana":
+ return _import_bananadev()
+ elif name == "Baseten":
+ return _import_baseten()
+ elif name == "Beam":
+ return _import_beam()
+ elif name == "Bedrock":
+ return _import_bedrock()
+ elif name == "NIBittensorLLM":
+ return _import_bittensor()
+ elif name == "CerebriumAI":
+ return _import_cerebriumai()
+ elif name == "ChatGLM":
+ return _import_chatglm()
+ elif name == "Clarifai":
+ return _import_clarifai()
+ elif name == "Cohere":
+ return _import_cohere()
+ elif name == "CTransformers":
+ return _import_ctransformers()
+ elif name == "CTranslate2":
+ return _import_ctranslate2()
+ elif name == "Databricks":
+ return _import_databricks()
+ elif name == "DeepInfra":
+ return _import_deepinfra()
+ elif name == "DeepSparse":
+ return _import_deepsparse()
+ elif name == "EdenAI":
+ return _import_edenai()
+ elif name == "FakeListLLM":
+ return _import_fake()
+ elif name == "Fireworks":
+ return _import_fireworks()
+ elif name == "ForefrontAI":
+ return _import_forefrontai()
+ elif name == "GigaChat":
+ return _import_gigachat()
+ elif name == "GooglePalm":
+ return _import_google_palm()
+ elif name == "GooseAI":
+ return _import_gooseai()
+ elif name == "GPT4All":
+ return _import_gpt4all()
+ elif name == "GradientLLM":
+ return _import_gradient_ai()
+ elif name == "HuggingFaceEndpoint":
+ return _import_huggingface_endpoint()
+ elif name == "HuggingFaceHub":
+ return _import_huggingface_hub()
+ elif name == "HuggingFacePipeline":
+ return _import_huggingface_pipeline()
+ elif name == "HuggingFaceTextGenInference":
+ return _import_huggingface_text_gen_inference()
+ elif name == "HumanInputLLM":
+ return _import_human()
+ elif name == "JavelinAIGateway":
+ return _import_javelin_ai_gateway()
+ elif name == "KoboldApiLLM":
+ return _import_koboldai()
+ elif name == "LlamaCpp":
+ return _import_llamacpp()
+ elif name == "ManifestWrapper":
+ return _import_manifest()
+ elif name == "Minimax":
+ return _import_minimax()
+ elif name == "Mlflow":
+ return _import_mlflow()
+ elif name == "MlflowAIGateway":
+ return _import_mlflow_ai_gateway()
+ elif name == "Modal":
+ return _import_modal()
+ elif name == "MosaicML":
+ return _import_mosaicml()
+ elif name == "NLPCloud":
+ return _import_nlpcloud()
+ elif name == "OctoAIEndpoint":
+ return _import_octoai_endpoint()
+ elif name == "Ollama":
+ return _import_ollama()
+ elif name == "OpaquePrompts":
+ return _import_opaqueprompts()
+ elif name == "AzureOpenAI":
+ return _import_azure_openai()
+ elif name == "OpenAI":
+ return _import_openai()
+ elif name == "OpenAIChat":
+ return _import_openai_chat()
+ elif name == "OpenLLM":
+ return _import_openllm()
+ elif name == "OpenLM":
+ return _import_openlm()
+ elif name == "PaiEasEndpoint":
+ return _import_pai_eas_endpoint()
+ elif name == "Petals":
+ return _import_petals()
+ elif name == "PipelineAI":
+ return _import_pipelineai()
+ elif name == "Predibase":
+ return _import_predibase()
+ elif name == "PredictionGuard":
+ return _import_predictionguard()
+ elif name == "PromptLayerOpenAI":
+ return _import_promptlayer()
+ elif name == "PromptLayerOpenAIChat":
+ return _import_promptlayer_chat()
+ elif name == "Replicate":
+ return _import_replicate()
+ elif name == "RWKV":
+ return _import_rwkv()
+ elif name == "SagemakerEndpoint":
+ return _import_sagemaker_endpoint()
+ elif name == "SelfHostedPipeline":
+ return _import_self_hosted()
+ elif name == "SelfHostedHuggingFaceLLM":
+ return _import_self_hosted_hugging_face()
+ elif name == "StochasticAI":
+ return _import_stochasticai()
+ elif name == "Nebula":
+ return _import_symblai_nebula()
+ elif name == "TextGen":
+ return _import_textgen()
+ elif name == "TitanTakeoff":
+ return _import_titan_takeoff()
+ elif name == "TitanTakeoffPro":
+ return _import_titan_takeoff_pro()
+ elif name == "Together":
+ return _import_together()
+ elif name == "Tongyi":
+ return _import_tongyi()
+ elif name == "VertexAI":
+ return _import_vertex()
+ elif name == "VertexAIModelGarden":
+ return _import_vertex_model_garden()
+ elif name == "VLLM":
+ return _import_vllm()
+ elif name == "VLLMOpenAI":
+ return _import_vllm_openai()
+ elif name == "WatsonxLLM":
+ return _import_watsonxllm()
+ elif name == "Writer":
+ return _import_writer()
+ elif name == "Xinference":
+ return _import_xinference()
+ elif name == "YandexGPT":
+ return _import_yandex_gpt()
+ elif name == "VolcEngineMaasLLM":
+ return _import_volcengine_maas()
+ elif name == "type_to_cls_dict":
+ # for backwards compatibility
+ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
+ k: v() for k, v in get_type_to_cls_dict().items()
+ }
+ return type_to_cls_dict
+ else:
+ raise AttributeError(f"Could not find: {name}")
+
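+# Note (sketch): attribute access on this module is resolved lazily via
+# __getattr__ above, so e.g. `from langchain_community.llms import OpenAI`
+# only runs _import_openai() (and its optional dependency import) at lookup time.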
+
+__all__ = [
+ "AI21",
+ "AlephAlpha",
+ "AmazonAPIGateway",
+ "Anthropic",
+ "Anyscale",
+ "Arcee",
+ "Aviary",
+ "AzureMLOnlineEndpoint",
+ "AzureOpenAI",
+ "Banana",
+ "Baseten",
+ "Beam",
+ "Bedrock",
+ "CTransformers",
+ "CTranslate2",
+ "CerebriumAI",
+ "ChatGLM",
+ "Clarifai",
+ "Cohere",
+ "Databricks",
+ "DeepInfra",
+ "DeepSparse",
+ "EdenAI",
+ "FakeListLLM",
+ "Fireworks",
+ "ForefrontAI",
+ "GigaChat",
+ "GPT4All",
+ "GooglePalm",
+ "GooseAI",
+ "GradientLLM",
+ "HuggingFaceEndpoint",
+ "HuggingFaceHub",
+ "HuggingFacePipeline",
+ "HuggingFaceTextGenInference",
+ "HumanInputLLM",
+ "KoboldApiLLM",
+ "LlamaCpp",
+ "TextGen",
+ "ManifestWrapper",
+ "Minimax",
+ "MlflowAIGateway",
+ "Modal",
+ "MosaicML",
+ "Nebula",
+ "NIBittensorLLM",
+ "NLPCloud",
+ "Ollama",
+ "OpenAI",
+ "OpenAIChat",
+ "OpenLLM",
+ "OpenLM",
+ "PaiEasEndpoint",
+ "Petals",
+ "PipelineAI",
+ "Predibase",
+ "PredictionGuard",
+ "PromptLayerOpenAI",
+ "PromptLayerOpenAIChat",
+ "OpaquePrompts",
+ "RWKV",
+ "Replicate",
+ "SagemakerEndpoint",
+ "SelfHostedHuggingFaceLLM",
+ "SelfHostedPipeline",
+ "StochasticAI",
+ "TitanTakeoff",
+ "TitanTakeoffPro",
+ "Tongyi",
+ "VertexAI",
+ "VertexAIModelGarden",
+ "VLLM",
+ "VLLMOpenAI",
+ "WatsonxLLM",
+ "Writer",
+ "OctoAIEndpoint",
+ "Xinference",
+ "JavelinAIGateway",
+ "QianfanLLMEndpoint",
+ "YandexGPT",
+ "VolcEngineMaasLLM",
+]
+
+
+def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
+ return {
+ "ai21": _import_ai21,
+ "aleph_alpha": _import_aleph_alpha,
+ "amazon_api_gateway": _import_amazon_api_gateway,
+ "amazon_bedrock": _import_bedrock,
+ "anthropic": _import_anthropic,
+ "anyscale": _import_anyscale,
+ "arcee": _import_arcee,
+ "aviary": _import_aviary,
+ "azure": _import_azure_openai,
+ "azureml_endpoint": _import_azureml_endpoint,
+ "bananadev": _import_bananadev,
+ "baseten": _import_baseten,
+ "beam": _import_beam,
+ "cerebriumai": _import_cerebriumai,
+ "chat_glm": _import_chatglm,
+ "clarifai": _import_clarifai,
+ "cohere": _import_cohere,
+ "ctransformers": _import_ctransformers,
+ "ctranslate2": _import_ctranslate2,
+ "databricks": _import_databricks,
+ "databricks-chat": _import_databricks_chat,
+ "deepinfra": _import_deepinfra,
+ "deepsparse": _import_deepsparse,
+ "edenai": _import_edenai,
+ "fake-list": _import_fake,
+ "forefrontai": _import_forefrontai,
+ "giga-chat-model": _import_gigachat,
+ "google_palm": _import_google_palm,
+ "gooseai": _import_gooseai,
+ "gradient": _import_gradient_ai,
+ "gpt4all": _import_gpt4all,
+ "huggingface_endpoint": _import_huggingface_endpoint,
+ "huggingface_hub": _import_huggingface_hub,
+ "huggingface_pipeline": _import_huggingface_pipeline,
+ "huggingface_textgen_inference": _import_huggingface_text_gen_inference,
+ "human-input": _import_human,
+ "koboldai": _import_koboldai,
+ "llamacpp": _import_llamacpp,
+ "textgen": _import_textgen,
+ "minimax": _import_minimax,
+ "mlflow": _import_mlflow,
+ "mlflow-chat": _import_mlflow_chat,
+ "mlflow-ai-gateway": _import_mlflow_ai_gateway,
+ "modal": _import_modal,
+ "mosaic": _import_mosaicml,
+ "nebula": _import_symblai_nebula,
+ "nibittensor": _import_bittensor,
+ "nlpcloud": _import_nlpcloud,
+ "ollama": _import_ollama,
+ "openai": _import_openai,
+ "openlm": _import_openlm,
+ "pai_eas_endpoint": _import_pai_eas_endpoint,
+ "petals": _import_petals,
+ "pipelineai": _import_pipelineai,
+ "predibase": _import_predibase,
+ "opaqueprompts": _import_opaqueprompts,
+ "replicate": _import_replicate,
+ "rwkv": _import_rwkv,
+ "sagemaker_endpoint": _import_sagemaker_endpoint,
+ "self_hosted": _import_self_hosted,
+ "self_hosted_hugging_face": _import_self_hosted_hugging_face,
+ "stochasticai": _import_stochasticai,
+ "together": _import_together,
+ "tongyi": _import_tongyi,
+ "titan_takeoff": _import_titan_takeoff,
+ "titan_takeoff_pro": _import_titan_takeoff_pro,
+ "vertexai": _import_vertex,
+ "vertexai_model_garden": _import_vertex_model_garden,
+ "openllm": _import_openllm,
+ "openllm_client": _import_openllm,
+ "vllm": _import_vllm,
+ "vllm_openai": _import_vllm_openai,
+ "watsonxllm": _import_watsonxllm,
+ "writer": _import_writer,
+ "xinference": _import_xinference,
+ "javelin-ai-gateway": _import_javelin_ai_gateway,
+ "qianfan_endpoint": _import_baidu_qianfan_endpoint,
+ "yandex_gpt": _import_yandex_gpt,
+ "VolcEngineMaasLLM": _import_volcengine_maas,
+ }
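+
+
+# Usage sketch: each entry maps a provider key to a zero-argument import callable,
+# so the underlying class is only imported when requested, e.g.
+#   llm_cls = get_type_to_cls_dict()["openai"]()  # imports and returns the OpenAI class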
diff --git a/libs/community/langchain_community/llms/ai21.py b/libs/community/langchain_community/llms/ai21.py
new file mode 100644
index 00000000000..dd86ba516ae
--- /dev/null
+++ b/libs/community/langchain_community/llms/ai21.py
@@ -0,0 +1,158 @@
+from typing import Any, Dict, List, Optional, cast
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class AI21PenaltyData(BaseModel):
+ """Parameters for AI21 penalty data."""
+
+ scale: int = 0
+ applyToWhitespaces: bool = True
+ applyToPunctuations: bool = True
+ applyToNumbers: bool = True
+ applyToStopwords: bool = True
+ applyToEmojis: bool = True
+
+
+class AI21(LLM):
+ """AI21 large language models.
+
+ To use, you should have the environment variable ``AI21_API_KEY``
+ set with your API key or pass it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import AI21
+ ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
+ """
+
+ model: str = "j2-jumbo-instruct"
+ """Model name to use."""
+
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+
+ maxTokens: int = 256
+ """The maximum number of tokens to generate in the completion."""
+
+ minTokens: int = 0
+ """The minimum number of tokens to generate in the completion."""
+
+ topP: float = 1.0
+ """Total probability mass of tokens to consider at each step."""
+
+ presencePenalty: AI21PenaltyData = AI21PenaltyData()
+ """Penalizes repeated tokens."""
+
+ countPenalty: AI21PenaltyData = AI21PenaltyData()
+ """Penalizes repeated tokens according to count."""
+
+ frequencyPenalty: AI21PenaltyData = AI21PenaltyData()
+ """Penalizes repeated tokens according to frequency."""
+
+ numResults: int = 1
+ """How many completions to generate for each prompt."""
+
+ logitBias: Optional[Dict[str, float]] = None
+ """Adjust the probability of specific tokens being generated."""
+
+ ai21_api_key: Optional[SecretStr] = None
+
+ stop: Optional[List[str]] = None
+
+ base_url: Optional[str] = None
+ """Base url to use, if None decides based on model name."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ ai21_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
+ )
+ values["ai21_api_key"] = ai21_api_key
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling AI21 API."""
+ return {
+ "temperature": self.temperature,
+ "maxTokens": self.maxTokens,
+ "minTokens": self.minTokens,
+ "topP": self.topP,
+ "presencePenalty": self.presencePenalty.dict(),
+ "countPenalty": self.countPenalty.dict(),
+ "frequencyPenalty": self.frequencyPenalty.dict(),
+ "numResults": self.numResults,
+ "logitBias": self.logitBias,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "ai21"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to AI21's complete endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = ai21("Tell me a joke.")
+ """
+ if self.stop is not None and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop is not None:
+ stop = self.stop
+ elif stop is None:
+ stop = []
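+        # Note: j1-grande-instruct is served from the experimental base path; other
+        # models use the standard Studio v1 endpoint unless base_url is overridden.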
+ if self.base_url is not None:
+ base_url = self.base_url
+ else:
+ if self.model in ("j1-grande-instruct",):
+ base_url = "https://api.ai21.com/studio/v1/experimental"
+ else:
+ base_url = "https://api.ai21.com/studio/v1"
+ params = {**self._default_params, **kwargs}
+ self.ai21_api_key = cast(SecretStr, self.ai21_api_key)
+ response = requests.post(
+ url=f"{base_url}/{self.model}/complete",
+ headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"},
+ json={"prompt": prompt, "stopSequences": stop, **params},
+ )
+ if response.status_code != 200:
+ optional_detail = response.json().get("error")
+ raise ValueError(
+ f"AI21 /complete call failed with status code {response.status_code}."
+ f" Details: {optional_detail}"
+ )
+ response_json = response.json()
+ return response_json["completions"][0]["data"]["text"]
diff --git a/libs/community/langchain_community/llms/aleph_alpha.py b/libs/community/langchain_community/llms/aleph_alpha.py
new file mode 100644
index 00000000000..8ae891024b8
--- /dev/null
+++ b/libs/community/langchain_community/llms/aleph_alpha.py
@@ -0,0 +1,287 @@
+from typing import Any, Dict, List, Optional, Sequence
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class AlephAlpha(LLM):
+ """Aleph Alpha large language models.
+
+ To use, you should have the ``aleph_alpha_client`` python package installed, and the
+ environment variable ``ALEPH_ALPHA_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Parameters are explained more in depth here:
+ https://github.com/Aleph-Alpha/aleph-alpha-client/blob/c14b7dd2b4325c7da0d6a119f6e76385800e097b/aleph_alpha_client/completion.py#L10
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import AlephAlpha
+ aleph_alpha = AlephAlpha(aleph_alpha_api_key="my-api-key")
+ """
+
+ client: Any #: :meta private:
+ model: Optional[str] = "luminous-base"
+ """Model name to use."""
+
+ maximum_tokens: int = 64
+ """The maximum number of tokens to be generated."""
+
+ temperature: float = 0.0
+ """A non-negative float that tunes the degree of randomness in generation."""
+
+ top_k: int = 0
+ """Number of most likely tokens to consider at each step."""
+
+ top_p: float = 0.0
+ """Total probability mass of tokens to consider at each step."""
+
+ presence_penalty: float = 0.0
+ """Penalizes repeated tokens."""
+
+ frequency_penalty: float = 0.0
+ """Penalizes repeated tokens according to frequency."""
+
+ repetition_penalties_include_prompt: Optional[bool] = False
+ """Flag deciding whether presence penalty or frequency penalty are
+ updated from the prompt."""
+
+ use_multiplicative_presence_penalty: Optional[bool] = False
+ """Flag deciding whether presence penalty is applied
+ multiplicatively (True) or additively (False)."""
+
+ penalty_bias: Optional[str] = None
+ """Penalty bias for the completion."""
+
+ penalty_exceptions: Optional[List[str]] = None
+ """List of strings that may be generated without penalty,
+ regardless of other penalty settings"""
+
+ penalty_exceptions_include_stop_sequences: Optional[bool] = None
+ """Should stop_sequences be included in penalty_exceptions."""
+
+ best_of: Optional[int] = None
+ """returns the one with the "best of" results
+ (highest log probability per token)
+ """
+
+ n: int = 1
+ """How many completions to generate for each prompt."""
+
+ logit_bias: Optional[Dict[int, float]] = None
+ """The logit bias allows to influence the likelihood of generating tokens."""
+
+ log_probs: Optional[int] = None
+ """Number of top log probabilities to be returned for each generated token."""
+
+ tokens: Optional[bool] = False
+ """return tokens of completion."""
+
+ disable_optimizations: Optional[bool] = False
+
+ minimum_tokens: Optional[int] = 0
+ """Generate at least this number of tokens."""
+
+ echo: bool = False
+ """Echo the prompt in the completion."""
+
+ use_multiplicative_frequency_penalty: bool = False
+
+ sequence_penalty: float = 0.0
+
+ sequence_penalty_min_length: int = 2
+
+ use_multiplicative_sequence_penalty: bool = False
+
+ completion_bias_inclusion: Optional[Sequence[str]] = None
+
+ completion_bias_inclusion_first_token_only: bool = False
+
+ completion_bias_exclusion: Optional[Sequence[str]] = None
+
+ completion_bias_exclusion_first_token_only: bool = False
+ """Only consider the first token for the completion_bias_exclusion."""
+
+ contextual_control_threshold: Optional[float] = None
+ """If set to None, attention control parameters only apply to those tokens that have
+ explicitly been set in the request.
+ If set to a non-None value, control parameters are also applied to similar tokens.
+ """
+
+ control_log_additive: Optional[bool] = True
+ """True: apply control by adding the log(control_factor) to attention scores.
+    False: (attention_scores - attention_scores.min(-1)) * control_factor
+ """
+
+ repetition_penalties_include_completion: bool = True
+ """Flag deciding whether presence penalty or frequency penalty
+ are updated from the completion."""
+
+ raw_completion: bool = False
+ """Force the raw completion of the model to be returned."""
+
+ stop_sequences: Optional[List[str]] = None
+ """Stop sequences to use."""
+
+ # Client params
+ aleph_alpha_api_key: Optional[str] = None
+ """API key for Aleph Alpha API."""
+ host: str = "https://api.aleph-alpha.com"
+ """The hostname of the API host.
+    The default one is "https://api.aleph-alpha.com"."""
+ hosting: Optional[str] = None
+ """Determines in which datacenters the request may be processed.
+ You can either set the parameter to "aleph-alpha" or omit it (defaulting to None).
+ Not setting this value, or setting it to None, gives us maximal
+ flexibility in processing your request in our
+ own datacenters and on servers hosted with other providers.
+ Choose this option for maximal availability.
+ Setting it to "aleph-alpha" allows us to only process the
+ request in our own datacenters.
+ Choose this option for maximal data privacy."""
+ request_timeout_seconds: int = 305
+ """Client timeout that will be set for HTTP requests in the
+ `requests` library's API calls.
+ Server will close all requests after 300 seconds with an internal server error."""
+ total_retries: int = 8
+ """The number of retries made in case requests fail with certain retryable
+ status codes. If the last
+ retry fails a corresponding exception is raised. Note, that between retries
+ an exponential backoff
+ is applied, starting with 0.5 s after the first retry and doubling for
+ each retry made. So with the
+ default setting of 8 retries a total wait time of 63.5 s is added
+ between the retries."""
+ nice: bool = False
+ """Setting this to True, will signal to the API that you intend to be
+ nice to other users
+ by de-prioritizing your request below concurrent ones."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["aleph_alpha_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY")
+ )
+ try:
+ from aleph_alpha_client import Client
+
+ values["client"] = Client(
+ token=values["aleph_alpha_api_key"].get_secret_value(),
+ host=values["host"],
+ hosting=values["hosting"],
+ request_timeout_seconds=values["request_timeout_seconds"],
+ total_retries=values["total_retries"],
+ nice=values["nice"],
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import aleph_alpha_client python package. "
+ "Please install it with `pip install aleph_alpha_client`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling the Aleph Alpha API."""
+ return {
+ "maximum_tokens": self.maximum_tokens,
+ "temperature": self.temperature,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "presence_penalty": self.presence_penalty,
+ "frequency_penalty": self.frequency_penalty,
+ "n": self.n,
+ "repetition_penalties_include_prompt": self.repetition_penalties_include_prompt, # noqa: E501
+ "use_multiplicative_presence_penalty": self.use_multiplicative_presence_penalty, # noqa: E501
+ "penalty_bias": self.penalty_bias,
+ "penalty_exceptions": self.penalty_exceptions,
+ "penalty_exceptions_include_stop_sequences": self.penalty_exceptions_include_stop_sequences, # noqa: E501
+ "best_of": self.best_of,
+ "logit_bias": self.logit_bias,
+ "log_probs": self.log_probs,
+ "tokens": self.tokens,
+ "disable_optimizations": self.disable_optimizations,
+ "minimum_tokens": self.minimum_tokens,
+ "echo": self.echo,
+ "use_multiplicative_frequency_penalty": self.use_multiplicative_frequency_penalty, # noqa: E501
+ "sequence_penalty": self.sequence_penalty,
+ "sequence_penalty_min_length": self.sequence_penalty_min_length,
+ "use_multiplicative_sequence_penalty": self.use_multiplicative_sequence_penalty, # noqa: E501
+ "completion_bias_inclusion": self.completion_bias_inclusion,
+ "completion_bias_inclusion_first_token_only": self.completion_bias_inclusion_first_token_only, # noqa: E501
+ "completion_bias_exclusion": self.completion_bias_exclusion,
+ "completion_bias_exclusion_first_token_only": self.completion_bias_exclusion_first_token_only, # noqa: E501
+ "contextual_control_threshold": self.contextual_control_threshold,
+ "control_log_additive": self.control_log_additive,
+ "repetition_penalties_include_completion": self.repetition_penalties_include_completion, # noqa: E501
+ "raw_completion": self.raw_completion,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "aleph_alpha"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Aleph Alpha's completion endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = aleph_alpha("Tell me a joke.")
+ """
+ from aleph_alpha_client import CompletionRequest, Prompt
+
+ params = self._default_params
+ if self.stop_sequences is not None and stop is not None:
+ raise ValueError(
+ "stop sequences found in both the input and default params."
+ )
+ elif self.stop_sequences is not None:
+ params["stop_sequences"] = self.stop_sequences
+ else:
+ params["stop_sequences"] = stop
+ params = {**params, **kwargs}
+ request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
+ response = self.client.complete(model=self.model, request=request)
+ text = response.completions[0].completion
+ # If stop tokens are provided, Aleph Alpha's endpoint returns them.
+ # In order to make this consistent with other endpoints, we strip them.
+ if stop is not None or self.stop_sequences is not None:
+ text = enforce_stop_tokens(text, params["stop_sequences"])
+ return text
+
+
+if __name__ == "__main__":
+ aa = AlephAlpha()
+
+ print(aa("How are you?"))
diff --git a/libs/community/langchain_community/llms/amazon_api_gateway.py b/libs/community/langchain_community/llms/amazon_api_gateway.py
new file mode 100644
index 00000000000..be2266a331c
--- /dev/null
+++ b/libs/community/langchain_community/llms/amazon_api_gateway.py
@@ -0,0 +1,104 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class ContentHandlerAmazonAPIGateway:
+ """Adapter to prepare the inputs from Langchain to a format
+ that LLM model expects.
+
+ It also provides helper function to extract
+ the generated text from the model response."""
+
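+    # Sketch of a custom handler (hypothetical payload shape): subclasses can
+    # override the two hooks below for an endpoint that expects {"prompt": ...}
+    # and returns {"generated_text": ...} at the top level of its JSON body.
+    #
+    #     class JSONPromptHandler(ContentHandlerAmazonAPIGateway):
+    #         @classmethod
+    #         def transform_input(cls, prompt, model_kwargs):
+    #             return {"prompt": prompt, **model_kwargs}
+    #
+    #         @classmethod
+    #         def transform_output(cls, response):
+    #             return response.json()["generated_text"]
+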
+ @classmethod
+ def transform_input(
+ cls, prompt: str, model_kwargs: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ return {"inputs": prompt, "parameters": model_kwargs}
+
+ @classmethod
+ def transform_output(cls, response: Any) -> str:
+ return response.json()[0]["generated_text"]
+
+
+class AmazonAPIGateway(LLM):
+ """Amazon API Gateway to access LLM models hosted on AWS."""
+
+ api_url: str
+ """API Gateway URL"""
+
+ headers: Optional[Dict] = None
+ """API Gateway HTTP Headers to send, e.g. for authentication"""
+
+ model_kwargs: Optional[Dict] = None
+ """Keyword arguments to pass to the model."""
+
+ content_handler: ContentHandlerAmazonAPIGateway = ContentHandlerAmazonAPIGateway()
+ """The content handler class that provides an input and
+ output transform functions to handle formats between LLM
+ and the endpoint.
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"api_url": self.api_url, "headers": self.headers},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "amazon_api_gateway"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Amazon API Gateway model.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+                response = amazon_api_gateway("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+ payload = self.content_handler.transform_input(prompt, _model_kwargs)
+
+ try:
+ response = requests.post(
+ self.api_url,
+ headers=self.headers,
+ json=payload,
+ )
+ text = self.content_handler.transform_output(response)
+
+ except Exception as error:
+ raise ValueError(f"Error raised by the service: {error}")
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
diff --git a/libs/community/langchain_community/llms/anthropic.py b/libs/community/langchain_community/llms/anthropic.py
new file mode 100644
index 00000000000..be832cf1368
--- /dev/null
+++ b/libs/community/langchain_community/llms/anthropic.py
@@ -0,0 +1,351 @@
+import re
+import warnings
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.prompt_values import PromptValue
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import (
+ check_package_version,
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
+
+
+class _AnthropicCommon(BaseLanguageModel):
+ client: Any = None #: :meta private:
+ async_client: Any = None #: :meta private:
+ model: str = Field(default="claude-2", alias="model_name")
+ """Model name to use."""
+
+ max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
+ """Denotes the number of tokens to predict per generation."""
+
+ temperature: Optional[float] = None
+ """A non-negative float that tunes the degree of randomness in generation."""
+
+ top_k: Optional[int] = None
+ """Number of most likely tokens to consider at each step."""
+
+ top_p: Optional[float] = None
+ """Total probability mass of tokens to consider at each step."""
+
+ streaming: bool = False
+ """Whether to stream the results."""
+
+ default_request_timeout: Optional[float] = None
+ """Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
+
+ anthropic_api_url: Optional[str] = None
+
+ anthropic_api_key: Optional[SecretStr] = None
+
+ HUMAN_PROMPT: Optional[str] = None
+ AI_PROMPT: Optional[str] = None
+ count_tokens: Optional[Callable[[str], int]] = None
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict) -> Dict:
+ extra = values.get("model_kwargs", {})
+ all_required_field_names = get_pydantic_field_names(cls)
+ values["model_kwargs"] = build_extra_kwargs(
+ extra, values, all_required_field_names
+ )
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["anthropic_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
+ )
+ # Get custom api url from environment.
+ values["anthropic_api_url"] = get_from_dict_or_env(
+ values,
+ "anthropic_api_url",
+ "ANTHROPIC_API_URL",
+ default="https://api.anthropic.com",
+ )
+
+ try:
+ import anthropic
+
+ check_package_version("anthropic", gte_version="0.3")
+ values["client"] = anthropic.Anthropic(
+ base_url=values["anthropic_api_url"],
+ api_key=values["anthropic_api_key"].get_secret_value(),
+ timeout=values["default_request_timeout"],
+ )
+ values["async_client"] = anthropic.AsyncAnthropic(
+ base_url=values["anthropic_api_url"],
+ api_key=values["anthropic_api_key"].get_secret_value(),
+ timeout=values["default_request_timeout"],
+ )
+ values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
+ values["AI_PROMPT"] = anthropic.AI_PROMPT
+ values["count_tokens"] = values["client"].count_tokens
+
+ except ImportError:
+ raise ImportError(
+ "Could not import anthropic python package. "
+ "Please it install it with `pip install anthropic`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling Anthropic API."""
+ d = {
+ "max_tokens_to_sample": self.max_tokens_to_sample,
+ "model": self.model,
+ }
+ if self.temperature is not None:
+ d["temperature"] = self.temperature
+ if self.top_k is not None:
+ d["top_k"] = self.top_k
+ if self.top_p is not None:
+ d["top_p"] = self.top_p
+ return {**d, **self.model_kwargs}
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{}, **self._default_params}
+
+ def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
+ if not self.HUMAN_PROMPT or not self.AI_PROMPT:
+ raise NameError("Please ensure the anthropic package is loaded")
+
+ if stop is None:
+ stop = []
+
+ # Never want model to invent new turns of Human / Assistant dialog.
+ stop.extend([self.HUMAN_PROMPT])
+
+ return stop
+
+
+class Anthropic(LLM, _AnthropicCommon):
+ """Anthropic large language models.
+
+ To use, you should have the ``anthropic`` python package installed, and the
+ environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ import anthropic
+ from langchain_community.llms import Anthropic
+
+ model = Anthropic(model="", anthropic_api_key="my-api-key")
+
+ # Simplest invocation, automatically wrapped with HUMAN_PROMPT
+ # and AI_PROMPT.
+ response = model("What are the biggest risks facing humanity?")
+
+ # Or if you want to use the chat mode, build a few-shot-prompt, or
+ # put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
+ raw_prompt = "What are the biggest risks facing humanity?"
+ prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
+ response = model(prompt)
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def raise_warning(cls, values: Dict) -> Dict:
+ """Raise warning that this class is deprecated."""
+ warnings.warn(
+ "This Anthropic LLM is deprecated. "
+ "Please use `from langchain_community.chat_models import ChatAnthropic` "
+ "instead"
+ )
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "anthropic-llm"
+
+ def _wrap_prompt(self, prompt: str) -> str:
+ if not self.HUMAN_PROMPT or not self.AI_PROMPT:
+ raise NameError("Please ensure the anthropic package is loaded")
+
+ if prompt.startswith(self.HUMAN_PROMPT):
+ return prompt # Already wrapped.
+
+ # Guard against common errors in specifying wrong number of newlines.
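+        # e.g. "Human: hi" or "\nHuman: hi" is normalized to f"{self.HUMAN_PROMPT} hi".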
+ corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
+ if n_subs == 1:
+ return corrected_prompt
+
+ # As a last resort, wrap the prompt ourselves to emulate instruct-style.
+ return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ r"""Call out to Anthropic's completion endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "What are the biggest risks facing humanity?"
+ prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
+ response = model(prompt)
+
+ """
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ completion += chunk.text
+ return completion
+
+ stop = self._get_anthropic_stop(stop)
+ params = {**self._default_params, **kwargs}
+ response = self.client.completions.create(
+ prompt=self._wrap_prompt(prompt),
+ stop_sequences=stop,
+ **params,
+ )
+ return response.completion
+
+ def convert_prompt(self, prompt: PromptValue) -> str:
+ return self._wrap_prompt(prompt.to_string())
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Anthropic's completion endpoint asynchronously."""
+ if self.streaming:
+ completion = ""
+ async for chunk in self._astream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ completion += chunk.text
+ return completion
+
+ stop = self._get_anthropic_stop(stop)
+ params = {**self._default_params, **kwargs}
+
+ response = await self.async_client.completions.create(
+ prompt=self._wrap_prompt(prompt),
+ stop_sequences=stop,
+ **params,
+ )
+ return response.completion
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ r"""Call Anthropic completion_stream and return the resulting generator.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ A generator representing the stream of tokens from Anthropic.
+ Example:
+ .. code-block:: python
+
+ prompt = "Write a poem about a stream."
+ prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
+ generator = anthropic.stream(prompt)
+ for token in generator:
+ yield token
+ """
+ stop = self._get_anthropic_stop(stop)
+ params = {**self._default_params, **kwargs}
+
+ for token in self.client.completions.create(
+ prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
+ ):
+ chunk = GenerationChunk(text=token.completion)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ r"""Call Anthropic completion_stream and return the resulting generator.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ A generator representing the stream of tokens from Anthropic.
+ Example:
+ .. code-block:: python
+ prompt = "Write a poem about a stream."
+ prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
+ generator = anthropic.stream(prompt)
+ for token in generator:
+ yield token
+ """
+ stop = self._get_anthropic_stop(stop)
+ params = {**self._default_params, **kwargs}
+
+ async for token in await self.async_client.completions.create(
+ prompt=self._wrap_prompt(prompt),
+ stop_sequences=stop,
+ stream=True,
+ **params,
+ ):
+ chunk = GenerationChunk(text=token.completion)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ def get_num_tokens(self, text: str) -> int:
+ """Calculate number of tokens."""
+ if not self.count_tokens:
+ raise NameError("Please ensure the anthropic package is loaded")
+ return self.count_tokens(text)
diff --git a/libs/community/langchain_community/llms/anyscale.py b/libs/community/langchain_community/llms/anyscale.py
new file mode 100644
index 00000000000..b994c425ab9
--- /dev/null
+++ b/libs/community/langchain_community/llms/anyscale.py
@@ -0,0 +1,279 @@
+"""Wrapper around Anyscale Endpoint"""
+from typing import (
+ Any,
+ AsyncIterator,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.openai import (
+ BaseOpenAI,
+ acompletion_with_retry,
+ completion_with_retry,
+)
+
+
+def update_token_usage(
+ keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
+) -> None:
+ """Update token usage."""
+ _keys_to_use = keys.intersection(response["usage"])
+ for _key in _keys_to_use:
+ if _key not in token_usage:
+ token_usage[_key] = response["usage"][_key]
+ else:
+ token_usage[_key] += response["usage"][_key]
+
+
+def create_llm_result(
+ choices: Any, prompts: List[str], token_usage: Dict[str, int], model_name: str
+) -> LLMResult:
+ """Create the LLMResult from the choices and prompts."""
+ generations = []
+ for i, _ in enumerate(prompts):
+ choice = choices[i]
+ generations.append(
+ [
+ Generation(
+ text=choice["message"]["content"],
+ generation_info=dict(
+ finish_reason=choice.get("finish_reason"),
+ logprobs=choice.get("logprobs"),
+ ),
+ )
+ ]
+ )
+ llm_output = {"token_usage": token_usage, "model_name": model_name}
+ return LLMResult(generations=generations, llm_output=llm_output)
+
+
+class Anyscale(BaseOpenAI):
+ """Anyscale large language models.
+
+    To use, you should have the environment variables ``ANYSCALE_API_BASE`` and
+    ``ANYSCALE_API_KEY`` set with your Anyscale Endpoint, or pass them as named
+    parameters to the constructor.
+
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import Anyscale
+ anyscalellm = Anyscale(anyscale_api_base="ANYSCALE_API_BASE",
+ anyscale_api_key="ANYSCALE_API_KEY",
+ model_name="meta-llama/Llama-2-7b-chat-hf")
+ # To leverage Ray for parallel processing
+ @ray.remote(num_cpus=1)
+ def send_query(llm, text):
+ resp = llm(text)
+ return resp
+ futures = [send_query.remote(anyscalellm, text) for text in texts]
+ results = ray.get(futures)
+ """
+
+ """Key word arguments to pass to the model."""
+ anyscale_api_base: Optional[str] = None
+ anyscale_api_key: Optional[SecretStr] = None
+
+ prefix_messages: List = Field(default_factory=list)
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["anyscale_api_base"] = get_from_dict_or_env(
+ values, "anyscale_api_base", "ANYSCALE_API_BASE"
+ )
+ values["anyscale_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
+ )
+
+ try:
+ import openai
+
+            ## Always create the ChatCompletion client, replacing the legacy Completion client
+ values["client"] = openai.ChatCompletion
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ if values["streaming"] and values["n"] > 1:
+ raise ValueError("Cannot stream results when n > 1.")
+ if values["streaming"] and values["best_of"] > 1:
+ raise ValueError("Cannot stream results when best_of > 1.")
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_name": self.model_name},
+ **super()._identifying_params,
+ }
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model."""
+ openai_creds: Dict[str, Any] = {
+ "api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(),
+ "api_base": self.anyscale_api_base,
+ }
+ return {**openai_creds, **{"model": self.model_name}, **super()._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "Anyscale LLM"
+
+ def _get_chat_messages(
+ self, prompts: List[str], stop: Optional[List[str]] = None
+ ) -> Tuple:
+ if len(prompts) > 1:
+ raise ValueError(
+ f"Anyscale currently only supports single prompt, got {prompts}"
+ )
+ messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
+ params: Dict[str, Any] = self._invocation_params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ if params.get("max_tokens") == -1:
+ # for Chat api, omitting max_tokens is equivalent to having no limit
+ del params["max_tokens"]
+ return messages, params
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ messages, params = self._get_chat_messages([prompt], stop)
+ params = {**params, **kwargs, "stream": True}
+ for stream_resp in completion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ ):
+ token = stream_resp["choices"][0]["delta"].get("content", "")
+ chunk = GenerationChunk(text=token)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(token, chunk=chunk)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ messages, params = self._get_chat_messages([prompt], stop)
+ params = {**params, **kwargs, "stream": True}
+ async for stream_resp in await acompletion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ ):
+ token = stream_resp["choices"][0]["delta"].get("content", "")
+ chunk = GenerationChunk(text=token)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(token, chunk=chunk)
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ choices = []
+ token_usage: Dict[str, int] = {}
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+ for prompt in prompts:
+ if self.streaming:
+ generation: Optional[GenerationChunk] = None
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ choices.append(
+ {
+ "message": {"content": generation.text},
+ "finish_reason": generation.generation_info.get("finish_reason")
+ if generation.generation_info
+ else None,
+ "logprobs": generation.generation_info.get("logprobs")
+ if generation.generation_info
+ else None,
+ }
+ )
+
+ else:
+ messages, params = self._get_chat_messages([prompt], stop)
+ params = {**params, **kwargs}
+ response = completion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ )
+ choices.extend(response["choices"])
+ update_token_usage(_keys, response, token_usage)
+ return create_llm_result(choices, prompts, token_usage, self.model_name)
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ choices = []
+ token_usage: Dict[str, int] = {}
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+ for prompt in prompts:
+ messages = self.prefix_messages + [{"role": "user", "content": prompt}]
+ if self.streaming:
+ generation: Optional[GenerationChunk] = None
+ async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ choices.append(
+ {
+ "message": {"content": generation.text},
+ "finish_reason": generation.generation_info.get("finish_reason")
+ if generation.generation_info
+ else None,
+ "logprobs": generation.generation_info.get("logprobs")
+ if generation.generation_info
+ else None,
+ }
+ )
+ else:
+ messages, params = self._get_chat_messages([prompt], stop)
+ params = {**params, **kwargs}
+ response = await acompletion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ )
+ choices.extend(response["choices"])
+ update_token_usage(_keys, response, token_usage)
+ return create_llm_result(choices, prompts, token_usage, self.model_name)
diff --git a/libs/community/langchain_community/llms/arcee.py b/libs/community/langchain_community/llms/arcee.py
new file mode 100644
index 00000000000..cab21c60e68
--- /dev/null
+++ b/libs/community/langchain_community/llms/arcee.py
@@ -0,0 +1,147 @@
+from typing import Any, Dict, List, Optional, Union, cast
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
+
+
+class Arcee(LLM):
+ """Arcee's Domain Adapted Language Models (DALMs).
+
+ To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
+ or pass ``arcee_api_key`` as a named parameter.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Arcee
+
+ arcee = Arcee(
+ model="DALM-PubMed",
+ arcee_api_key="ARCEE-API-KEY"
+ )
+
+ response = arcee("AI-driven music therapy")
+ """
+
+ _client: Optional[ArceeWrapper] = None #: :meta private:
+ """Arcee _client."""
+
+ arcee_api_key: Union[SecretStr, str, None] = None
+ """Arcee API Key"""
+
+ model: str
+ """Arcee DALM name"""
+
+ arcee_api_url: str = "https://api.arcee.ai"
+ """Arcee API URL"""
+
+ arcee_api_version: str = "v2"
+ """Arcee API Version"""
+
+ arcee_app_url: str = "https://app.arcee.ai"
+ """Arcee App URL"""
+
+ model_id: str = ""
+ """Arcee Model ID"""
+
+ model_kwargs: Optional[Dict[str, Any]] = None
+ """Keyword arguments to pass to the model."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ underscore_attrs_are_private = True
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "arcee"
+
+ def __init__(self, **data: Any) -> None:
+ """Initializes private fields."""
+
+ super().__init__(**data)
+ api_key = cast(SecretStr, self.arcee_api_key)
+ self._client = ArceeWrapper(
+ arcee_api_key=api_key,
+ arcee_api_url=self.arcee_api_url,
+ arcee_api_version=self.arcee_api_version,
+ model_kwargs=self.model_kwargs,
+ model_name=self.model,
+ )
+
+ @root_validator(pre=False)
+ def validate_environments(cls, values: Dict) -> Dict:
+ """Validate Arcee environment variables."""
+
+ # validate env vars
+ values["arcee_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "arcee_api_key",
+ "ARCEE_API_KEY",
+ )
+ )
+
+ values["arcee_api_url"] = get_from_dict_or_env(
+ values,
+ "arcee_api_url",
+ "ARCEE_API_URL",
+ )
+
+ values["arcee_app_url"] = get_from_dict_or_env(
+ values,
+ "arcee_app_url",
+ "ARCEE_APP_URL",
+ )
+
+ values["arcee_api_version"] = get_from_dict_or_env(
+ values,
+ "arcee_api_version",
+ "ARCEE_API_VERSION",
+ )
+
+ # validate model kwargs
+ if values.get("model_kwargs"):
+ kw = values["model_kwargs"]
+
+ # validate size
+ if kw.get("size") is not None:
+ if not kw.get("size") >= 0:
+ raise ValueError("`size` must be positive")
+
+ # validate filters
+ if kw.get("filters") is not None:
+ if not isinstance(kw.get("filters"), List):
+ raise ValueError("`filters` must be a list")
+ for f in kw.get("filters"):
+ DALMFilter(**f)
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Generate text from Arcee DALM.
+
+ Args:
+ prompt: Prompt to generate text from.
+ size: The max number of context results to retrieve.
+ Defaults to 3. (Can be less if filters are provided).
+ filters: Filters to apply to the context dataset.
+ """
+
+ try:
+ if not self._client:
+ raise ValueError("Client is not initialized.")
+ return self._client.generate(prompt=prompt, **kwargs)
+ except Exception as e:
+ raise Exception(f"Failed to generate text: {e}") from e
diff --git a/libs/community/langchain_community/llms/aviary.py b/libs/community/langchain_community/llms/aviary.py
new file mode 100644
index 00000000000..60c3794422d
--- /dev/null
+++ b/libs/community/langchain_community/llms/aviary.py
@@ -0,0 +1,197 @@
+import dataclasses
+import os
+from typing import Any, Dict, List, Mapping, Optional, Union, cast
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+TIMEOUT = 60
+
+
+@dataclasses.dataclass
+class AviaryBackend:
+ """Aviary backend.
+
+ Attributes:
+ backend_url: The URL for the Aviary backend.
+ bearer: The bearer token for the Aviary backend.
+ """
+
+ backend_url: str
+ bearer: str
+
+ def __post_init__(self) -> None:
+ self.header = {"Authorization": self.bearer}
+
+ @classmethod
+ def from_env(cls) -> "AviaryBackend":
+ aviary_url = os.getenv("AVIARY_URL")
+ assert aviary_url, "AVIARY_URL must be set"
+
+ aviary_token = os.getenv("AVIARY_TOKEN", "")
+
+ bearer = f"Bearer {aviary_token}" if aviary_token else ""
+ aviary_url += "/" if not aviary_url.endswith("/") else ""
+
+ return cls(aviary_url, bearer)
+
+
+def get_models() -> List[str]:
+ """List available models"""
+ backend = AviaryBackend.from_env()
+ request_url = backend.backend_url + "-/routes"
+ response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
+ try:
+ result = response.json()
+ except requests.JSONDecodeError as e:
+ raise RuntimeError(
+ f"Error decoding JSON from {request_url}. Text response: {response.text}"
+ ) from e
+ result = sorted(
+ [k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
+ )
+ return result
+
+
+def get_completions(
+ model: str,
+ prompt: str,
+ use_prompt_format: bool = True,
+ version: str = "",
+) -> Dict[str, Union[str, float, int]]:
+ """Get completions from Aviary models."""
+
+ backend = AviaryBackend.from_env()
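+    # e.g. model "amazon/LightGPT" with an empty version yields
+    # "<backend_url>amazon--LightGPT/query" (backend_url already ends with "/").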
+ url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
+ response = requests.post(
+ url,
+ headers=backend.header,
+ json={"prompt": prompt, "use_prompt_format": use_prompt_format},
+ timeout=TIMEOUT,
+ )
+ try:
+ return response.json()
+ except requests.JSONDecodeError as e:
+ raise RuntimeError(
+ f"Error decoding JSON from {url}. Text response: {response.text}"
+ ) from e
+
+
+class Aviary(LLM):
+ """Aviary hosted models.
+
+ Aviary is a backend for hosted models. You can
+ find out more about aviary at
+ http://github.com/ray-project/aviary
+
+ To get a list of the models supported on an
+ aviary, follow the instructions on the website to
+ install the aviary CLI and then use:
+ `aviary models`
+
+ AVIARY_URL and AVIARY_TOKEN environment variables must be set.
+
+ Attributes:
+ model: The name of the model to use. Defaults to "amazon/LightGPT".
+ aviary_url: The URL for the Aviary backend. Defaults to None.
+ aviary_token: The bearer token for the Aviary backend. Defaults to None.
+ use_prompt_format: If True, the prompt template for the model will be ignored.
+ Defaults to True.
+ version: API version to use for Aviary. Defaults to None.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Aviary
+ os.environ["AVIARY_URL"] = ""
+ os.environ["AVIARY_TOKEN"] = ""
+ light = Aviary(model='amazon/LightGPT')
+ output = light('How do you make fried rice?')
+ """
+
+ model: str = "amazon/LightGPT"
+ aviary_url: Optional[str] = None
+ aviary_token: Optional[str] = None
+ # If True the prompt template for the model will be ignored.
+ use_prompt_format: bool = True
+ # API version to use for Aviary
+ version: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
+ aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
+
+        # Set env variables for aviary sdk
+ os.environ["AVIARY_URL"] = aviary_url
+ os.environ["AVIARY_TOKEN"] = aviary_token
+
+ try:
+ aviary_models = get_models()
+ except requests.exceptions.RequestException as e:
+ raise ValueError(e)
+
+ model = values.get("model")
+ if model and model not in aviary_models:
+ raise ValueError(f"{aviary_url} does not support model {values['model']}.")
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_name": self.model,
+ "aviary_url": self.aviary_url,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return f"aviary-{self.model.replace('/', '-')}"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Aviary
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = aviary("Tell me a joke.")
+ """
+ kwargs = {"use_prompt_format": self.use_prompt_format}
+ if self.version:
+ kwargs["version"] = self.version
+
+ output = get_completions(
+ model=self.model,
+ prompt=prompt,
+ **kwargs,
+ )
+
+ text = cast(str, output["generated_text"])
+ if stop:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
diff --git a/libs/community/langchain_community/llms/azureml_endpoint.py b/libs/community/langchain_community/llms/azureml_endpoint.py
new file mode 100644
index 00000000000..c9e73df6c63
--- /dev/null
+++ b/libs/community/langchain_community/llms/azureml_endpoint.py
@@ -0,0 +1,291 @@
+import json
+import urllib.request
+import warnings
+from abc import abstractmethod
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class AzureMLEndpointClient(object):
+ """AzureML Managed Endpoint client."""
+
+ def __init__(
+ self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
+ ) -> None:
+ """Initialize the class."""
+ if not endpoint_api_key or not endpoint_url:
+ raise ValueError(
+ """A key/token and REST endpoint should
+ be provided to invoke the endpoint"""
+ )
+ self.endpoint_url = endpoint_url
+ self.endpoint_api_key = endpoint_api_key
+ self.deployment_name = deployment_name
+
+ def call(self, body: bytes, **kwargs: Any) -> bytes:
+ """call."""
+
+ # The azureml-model-deployment header will force the request to go to a
+ # specific deployment. Remove this header to have the request observe the
+ # endpoint traffic rules.
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": ("Bearer " + self.endpoint_api_key),
+ }
+ if self.deployment_name != "":
+ headers["azureml-model-deployment"] = self.deployment_name
+
+ req = urllib.request.Request(self.endpoint_url, body, headers)
+ response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
+ result = response.read()
+ return result
+
+
+class ContentFormatterBase:
+ """Transform request and response of AzureML endpoint to match with
+ required schema.
+ """
+
+ """
+ Example:
+ .. code-block:: python
+
+ class ContentFormatter(ContentFormatterBase):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def format_request_payload(
+ self,
+ prompt: str,
+ model_kwargs: Dict
+ ) -> bytes:
+ input_str = json.dumps(
+ {
+ "inputs": {"input_string": [prompt]},
+ "parameters": model_kwargs,
+ }
+ )
+ return str.encode(input_str)
+
+ def format_response_payload(self, output: str) -> str:
+ response_json = json.loads(output)
+ return response_json[0]["0"]
+ """
+ content_type: Optional[str] = "application/json"
+ """The MIME type of the input data passed to the endpoint"""
+
+ accepts: Optional[str] = "application/json"
+ """The MIME type of the response data returned from the endpoint"""
+
+ @staticmethod
+ def escape_special_characters(prompt: str) -> str:
+ """Escapes any special characters in `prompt`"""
+ escape_map = {
+ "\\": "\\\\",
+ '"': '\\"',
+ "\b": "\\b",
+ "\f": "\\f",
+ "\n": "\\n",
+ "\r": "\\r",
+ "\t": "\\t",
+ }
+
+ # Replace each occurrence of the specified characters with escaped versions
+ for escape_sequence, escaped_sequence in escape_map.items():
+ prompt = prompt.replace(escape_sequence, escaped_sequence)
+
+ return prompt
+
+ @abstractmethod
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+ """Formats the request body according to the input schema of
+ the model. Returns bytes or seekable file like object in the
+ format specified in the content_type request header.
+ """
+
+ @abstractmethod
+ def format_response_payload(self, output: bytes) -> str:
+ """Formats the response body according to the output
+ schema of the model. Returns the data type that is
+ received from the response.
+ """
+
+
+class GPT2ContentFormatter(ContentFormatterBase):
+ """Content handler for GPT2"""
+
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+ prompt = ContentFormatterBase.escape_special_characters(prompt)
+ request_payload = json.dumps(
+ {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
+ )
+ return str.encode(request_payload)
+
+ def format_response_payload(self, output: bytes) -> str:
+ return json.loads(output)[0]["0"]
+
+
+class OSSContentFormatter(GPT2ContentFormatter):
+ """Deprecated: Kept for backwards compatibility
+
+ Content handler for LLMs from the OSS catalog."""
+
+ content_formatter: Any = None
+
+ def __init__(self) -> None:
+ super().__init__()
+ warnings.warn(
+ """`OSSContentFormatter` will be deprecated in the future.
+ Please use `GPT2ContentFormatter` instead.
+ """
+ )
+
+
+class HFContentFormatter(ContentFormatterBase):
+ """Content handler for LLMs from the HuggingFace catalog."""
+
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        prompt = ContentFormatterBase.escape_special_characters(prompt)
+ request_payload = json.dumps(
+ {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
+ )
+ return str.encode(request_payload)
+
+ def format_response_payload(self, output: bytes) -> str:
+ return json.loads(output)[0]["generated_text"]
+
+
+class DollyContentFormatter(ContentFormatterBase):
+ """Content handler for the Dolly-v2-12b model"""
+
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+ prompt = ContentFormatterBase.escape_special_characters(prompt)
+ request_payload = json.dumps(
+ {
+ "input_data": {"input_string": [f'"{prompt}"']},
+ "parameters": model_kwargs,
+ }
+ )
+ return str.encode(request_payload)
+
+ def format_response_payload(self, output: bytes) -> str:
+ return json.loads(output)[0]
+
+
+class LlamaContentFormatter(ContentFormatterBase):
+ """Content formatter for LLaMa"""
+
+ def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        """Formats the request according to the chosen API."""
+ prompt = ContentFormatterBase.escape_special_characters(prompt)
+ request_payload = json.dumps(
+ {
+ "input_data": {
+ "input_string": [f'"{prompt}"'],
+ "parameters": model_kwargs,
+ }
+ }
+ )
+ return str.encode(request_payload)
+
+ def format_response_payload(self, output: bytes) -> str:
+ """Formats response"""
+ return json.loads(output)[0]["0"]
+
+
+class AzureMLOnlineEndpoint(LLM, BaseModel):
+ """Azure ML Online Endpoint models.
+
+ Example:
+ .. code-block:: python
+
+ azure_llm = AzureMLOnlineEndpoint(
+ endpoint_url="https://..inference.ml.azure.com/score",
+ endpoint_api_key="my-api-key",
+ content_formatter=content_formatter,
+ )
+ """ # noqa: E501
+
+ endpoint_url: str = ""
+ """URL of pre-existing Endpoint. Should be passed to constructor or specified as
+ env var `AZUREML_ENDPOINT_URL`."""
+
+ endpoint_api_key: str = ""
+ """Authentication Key for Endpoint. Should be passed to constructor or specified as
+ env var `AZUREML_ENDPOINT_API_KEY`."""
+
+ deployment_name: str = ""
+ """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
+ to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
+
+ http_client: Any = None #: :meta private:
+
+ content_formatter: Any = None
+ """The content formatter that provides an input and output
+ transform function to handle formats between the LLM and
+ the endpoint"""
+
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ @validator("http_client", always=True, allow_reuse=True)
+ @classmethod
+ def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+        """Validate that the API key and endpoint URL are set and build the client."""
+ endpoint_key = get_from_dict_or_env(
+ values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
+ )
+ endpoint_url = get_from_dict_or_env(
+ values, "endpoint_url", "AZUREML_ENDPOINT_URL"
+ )
+ deployment_name = get_from_dict_or_env(
+ values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
+ )
+ http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
+ return http_client
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"deployment_name": self.deployment_name},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "azureml_endpoint"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to an AzureML Managed Online endpoint.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+ Example:
+            .. code-block:: python
+
+                response = azureml_model("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+
+ request_payload = self.content_formatter.format_request_payload(
+ prompt, _model_kwargs
+ )
+ response_payload = self.http_client.call(request_payload, **kwargs)
+ generated_text = self.content_formatter.format_response_payload(
+ response_payload
+ )
+ return generated_text
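+
+
+# Usage sketch (illustrative only, not executed): with a deployed AzureML online
+# endpoint and a valid key, the classes above can be wired together roughly as
+# follows. The URL and key below are placeholders.
+#
+#     llm = AzureMLOnlineEndpoint(
+#         endpoint_url="https://<endpoint>.<region>.inference.ml.azure.com/score",
+#         endpoint_api_key="<api-key>",
+#         content_formatter=GPT2ContentFormatter(),
+#     )
+#     print(llm("Tell me a joke."))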
diff --git a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
new file mode 100644
index 00000000000..09d765de9d4
--- /dev/null
+++ b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ AsyncIterator,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class QianfanLLMEndpoint(LLM):
+ """Baidu Qianfan hosted open source or customized models.
+
+    To use, you should have the ``qianfan`` python package installed, and
+    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with
+    your API key and secret key.
+
+    ak and sk are required parameters which you can get from
+    https://cloud.baidu.com/product/wenxinworkshop
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import QianfanLLMEndpoint
+ qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
+ endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
+ """
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ client: Any
+
+ qianfan_ak: Optional[str] = None
+ qianfan_sk: Optional[str] = None
+
+ streaming: Optional[bool] = False
+ """Whether to stream the results or not."""
+
+ model: str = "ERNIE-Bot-turbo"
+    """Model name.
+    You can find available models at
+    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu.
+
+    Each preset model maps to an endpoint.
+    `model` will be ignored if `endpoint` is set.
+    """
+
+ endpoint: Optional[str] = None
+    """Endpoint of the Qianfan LLM, required if a custom model is used."""
+
+ request_timeout: Optional[int] = 60
+    """Request timeout for Qianfan HTTP requests."""
+
+ top_p: Optional[float] = 0.8
+ temperature: Optional[float] = 0.95
+ penalty_score: Optional[float] = 1
+    """Model params, only supported by ERNIE-Bot and ERNIE-Bot-turbo.
+    For other models, passing these params has no effect on the result.
+ """
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["qianfan_ak"] = get_from_dict_or_env(
+ values,
+ "qianfan_ak",
+ "QIANFAN_AK",
+ )
+ values["qianfan_sk"] = get_from_dict_or_env(
+ values,
+ "qianfan_sk",
+ "QIANFAN_SK",
+ )
+
+ params = {
+ "ak": values["qianfan_ak"],
+ "sk": values["qianfan_sk"],
+ "model": values["model"],
+ }
+ if values["endpoint"] is not None and values["endpoint"] != "":
+ params["endpoint"] = values["endpoint"]
+ try:
+ import qianfan
+
+ values["client"] = qianfan.Completion(**params)
+ except ImportError:
+ raise ImportError(
+ "qianfan package not found, please install it with "
+ "`pip install qianfan`"
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ **{"endpoint": self.endpoint, "model": self.model},
+ **super()._identifying_params,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "baidu-qianfan-endpoint"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Qianfan API."""
+ normal_params = {
+ "model": self.model,
+ "endpoint": self.endpoint,
+ "stream": self.streaming,
+ "request_timeout": self.request_timeout,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "penalty_score": self.penalty_score,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _convert_prompt_msg_params(
+ self,
+ prompt: str,
+ **kwargs: Any,
+ ) -> dict:
+ if "streaming" in kwargs:
+ kwargs["stream"] = kwargs.pop("streaming")
+ return {
+ **{"prompt": prompt, "model": self.model},
+ **self._default_params,
+ **kwargs,
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+        """Call out to a Qianfan model endpoint for each generation with a prompt.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+
+ Example:
+            .. code-block:: python
+
+                response = qianfan_model("Tell me a joke.")
+ """
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ response_payload = self.client.do(**params)
+
+ return response_payload["result"]
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if self.streaming:
+ completion = ""
+ async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ response_payload = await self.client.ado(**params)
+
+ return response_payload["result"]
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+ for res in self.client.do(**params):
+ if res:
+ chunk = GenerationChunk(text=res["result"])
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+ async for res in await self.client.ado(**params):
+ if res:
+ chunk = GenerationChunk(text=res["result"])
+
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text)
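+
+
+# Usage sketch (illustrative only, not executed): assuming the ``qianfan``
+# package is installed and QIANFAN_AK / QIANFAN_SK are set in the environment,
+# streaming output can be consumed chunk by chunk:
+#
+#     llm = QianfanLLMEndpoint(model="ERNIE-Bot-turbo", streaming=True)
+#     for chunk in llm.stream("Write a haiku about rivers."):
+#         print(chunk, end="", flush=True)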
diff --git a/libs/community/langchain_community/llms/bananadev.py b/libs/community/langchain_community/llms/bananadev.py
new file mode 100644
index 00000000000..88ab7f5e58c
--- /dev/null
+++ b/libs/community/langchain_community/llms/bananadev.py
@@ -0,0 +1,136 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class Banana(LLM):
+ """Banana large language models.
+
+ To use, you should have the ``banana-dev`` python package installed,
+ and the environment variable ``BANANA_API_KEY`` set with your API key.
+ This is the team API key available in the Banana dashboard.
+
+ Any parameters that are valid to be passed to the call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Banana
+ banana = Banana(model_key="", model_url_slug="")
+ """
+
+ model_key: str = ""
+ """model key to use"""
+
+ model_url_slug: str = ""
+ """model endpoint to use"""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not
+ explicitly specified."""
+
+ banana_api_key: Optional[str] = None
+
+ class Config:
+        """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the API key exists in the environment."""
+ banana_api_key = get_from_dict_or_env(
+ values, "banana_api_key", "BANANA_API_KEY"
+ )
+ values["banana_api_key"] = banana_api_key
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_key": self.model_key},
+ **{"model_url_slug": self.model_url_slug},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "bananadev"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Banana endpoint."""
+ try:
+ from banana_dev import Client
+ except ImportError:
+ raise ImportError(
+ "Could not import banana-dev python package. "
+ "Please install it with `pip install banana-dev`."
+ )
+ params = self.model_kwargs or {}
+ params = {**params, **kwargs}
+ api_key = self.banana_api_key
+ model_key = self.model_key
+ model_url_slug = self.model_url_slug
+ model_inputs = {
+ # a json specific to your model.
+ "prompt": prompt,
+ **params,
+ }
+ model = Client(
+ # Found in main dashboard
+ api_key=api_key,
+ # Both found in model details page
+ model_key=model_key,
+ url=f"https://{model_url_slug}.run.banana.dev",
+ )
+ response, meta = model.call("/", model_inputs)
+ try:
+ text = response["outputs"]
+ except (KeyError, TypeError):
+ raise ValueError(
+ "Response should be of schema: {'outputs': 'text'}."
+ "\nTo fix this:"
+ "\n- fork the source repo of the Banana model"
+ "\n- modify app.py to return the above schema"
+ "\n- deploy that as a custom repo"
+ )
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
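+
+
+# Usage sketch (illustrative only, not executed): assuming ``banana-dev`` is
+# installed, BANANA_API_KEY is set, and the deployed model returns a payload of
+# the form {"outputs": "..."}; the key and slug below are placeholders.
+#
+#     llm = Banana(model_key="<model-key>", model_url_slug="<url-slug>")
+#     print(llm("What is the capital of France?"))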
diff --git a/libs/community/langchain_community/llms/baseten.py b/libs/community/langchain_community/llms/baseten.py
new file mode 100644
index 00000000000..9b3f70b7744
--- /dev/null
+++ b/libs/community/langchain_community/llms/baseten.py
@@ -0,0 +1,73 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field
+
+logger = logging.getLogger(__name__)
+
+
+class Baseten(LLM):
+ """Baseten models.
+
+ To use, you should have the ``baseten`` python package installed,
+ and run ``baseten.login()`` with your Baseten API key.
+
+ The required ``model`` param can be either a model id or model
+ version id. Using a model version ID will result in
+ slightly faster invocation.
+ Any other model parameters can also
+ be passed in with the format input={model_param: value, ...}
+
+ The Baseten model must accept a dictionary of input with the key
+ "prompt" and return a dictionary with a key "data" which maps
+ to a list of response strings.
+
+ Example:
+        .. code-block:: python
+
+            from langchain_community.llms import Baseten
+ my_model = Baseten(model="MODEL_ID")
+ output = my_model("prompt")
+ """
+
+ model: str
+ input: Dict[str, Any] = Field(default_factory=dict)
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of model."""
+ return "baseten"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Baseten deployed model endpoint."""
+ try:
+ import baseten
+ except ImportError as exc:
+ raise ImportError(
+ "Could not import Baseten Python package. "
+ "Please install it with `pip install baseten`."
+ ) from exc
+
+ # get the model and version
+ try:
+ model = baseten.deployed_model_version_id(self.model)
+ response = model.predict({"prompt": prompt, **kwargs})
+ except baseten.common.core.ApiError:
+ model = baseten.deployed_model_id(self.model)
+ response = model.predict({"prompt": prompt, **kwargs})
+ return "".join(response)
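+
+
+# Usage sketch (illustrative only, not executed): assuming the ``baseten``
+# package is installed and ``baseten.login()`` has already been run with a valid
+# API key; the model id below is a placeholder.
+#
+#     llm = Baseten(model="<model-or-version-id>")
+#     print(llm("Summarize the plot of Hamlet in one sentence."))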
diff --git a/libs/community/langchain_community/llms/beam.py b/libs/community/langchain_community/llms/beam.py
new file mode 100644
index 00000000000..dfdb4375359
--- /dev/null
+++ b/libs/community/langchain_community/llms/beam.py
@@ -0,0 +1,272 @@
+import base64
+import json
+import logging
+import subprocess
+import textwrap
+import time
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_NUM_TRIES = 10
+DEFAULT_SLEEP_TIME = 4
+
+
+class Beam(LLM):
+    """Beam API for the GPT-2 large language model.
+
+ To use, you should have the ``beam-sdk`` python package installed,
+ and the environment variable ``BEAM_CLIENT_ID`` set with your client id
+ and ``BEAM_CLIENT_SECRET`` set with your client secret. Information on how
+ to get this is available here: https://docs.beam.cloud/account/api-keys.
+
+ The wrapper can then be called as follows, where the name, cpu, memory, gpu,
+ python version, and python packages can be updated accordingly. Once deployed,
+ the instance can be called.
+
+ Example:
+ .. code-block:: python
+
+ llm = Beam(model_name="gpt2",
+ name="langchain-gpt2",
+ cpu=8,
+ memory="32Gi",
+ gpu="A10G",
+ python_version="python3.8",
+ python_packages=[
+ "diffusers[torch]>=0.10",
+ "transformers",
+ "torch",
+ "pillow",
+ "accelerate",
+ "safetensors",
+ "xformers",],
+ max_length=50)
+ llm._deploy()
+ call_result = llm._call(input)
+
+ """
+
+ model_name: str = ""
+ name: str = ""
+ cpu: str = ""
+ memory: str = ""
+ gpu: str = ""
+ python_version: str = ""
+ python_packages: List[str] = []
+ max_length: str = ""
+ url: str = ""
+ """model endpoint to use"""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not
+ explicitly specified."""
+
+ beam_client_id: str = ""
+ beam_client_secret: str = ""
+ app_id: Optional[str] = None
+
+ class Config:
+        """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the client id and client secret exist in the environment."""
+ beam_client_id = get_from_dict_or_env(
+ values, "beam_client_id", "BEAM_CLIENT_ID"
+ )
+ beam_client_secret = get_from_dict_or_env(
+ values, "beam_client_secret", "BEAM_CLIENT_SECRET"
+ )
+ values["beam_client_id"] = beam_client_id
+ values["beam_client_secret"] = beam_client_secret
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "name": self.name,
+ "cpu": self.cpu,
+ "memory": self.memory,
+ "gpu": self.gpu,
+ "python_version": self.python_version,
+ "python_packages": self.python_packages,
+ "max_length": self.max_length,
+ "model_kwargs": self.model_kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "beam"
+
+ def app_creation(self) -> None:
+ """Creates a Python file which will contain your Beam app definition."""
+ script = textwrap.dedent(
+ """\
+ import beam
+
+ # The environment your code will run on
+ app = beam.App(
+ name="{name}",
+ cpu={cpu},
+ memory="{memory}",
+ gpu="{gpu}",
+ python_version="{python_version}",
+ python_packages={python_packages},
+ )
+
+ app.Trigger.RestAPI(
+ inputs={{"prompt": beam.Types.String(), "max_length": beam.Types.String()}},
+ outputs={{"text": beam.Types.String()}},
+ handler="run.py:beam_langchain",
+ )
+
+ """
+ )
+
+ script_name = "app.py"
+ with open(script_name, "w") as file:
+ file.write(
+ script.format(
+ name=self.name,
+ cpu=self.cpu,
+ memory=self.memory,
+ gpu=self.gpu,
+ python_version=self.python_version,
+ python_packages=self.python_packages,
+ )
+ )
+
+ def run_creation(self) -> None:
+        """Creates a Python file which will be deployed on Beam."""
+ script = textwrap.dedent(
+ """
+ import os
+ import transformers
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+ model_name = "{model_name}"
+
+ def beam_langchain(**inputs):
+ prompt = inputs["prompt"]
+ length = inputs["max_length"]
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+ model = GPT2LMHeadModel.from_pretrained(model_name)
+ encodedPrompt = tokenizer.encode(prompt, return_tensors='pt')
+ outputs = model.generate(encodedPrompt, max_length=int(length),
+ do_sample=True, pad_token_id=tokenizer.eos_token_id)
+ output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ print(output)
+ return {{"text": output}}
+
+ """
+ )
+
+ script_name = "run.py"
+ with open(script_name, "w") as file:
+ file.write(script.format(model_name=self.model_name))
+
+ def _deploy(self) -> str:
+ """Call to Beam."""
+ try:
+ import beam # type: ignore
+
+ if beam.__path__ == "":
+ raise ImportError
+ except ImportError:
+ raise ImportError(
+ "Could not import beam python package. "
+ "Please install it with `curl "
+ "https://raw.githubusercontent.com/slai-labs"
+ "/get-beam/main/get-beam.sh -sSfL | sh`."
+ )
+ self.app_creation()
+ self.run_creation()
+
+ process = subprocess.run(
+ "beam deploy app.py", shell=True, capture_output=True, text=True
+ )
+
+ if process.returncode == 0:
+ output = process.stdout
+ logger.info(output)
+ lines = output.split("\n")
+
+ for line in lines:
+ if line.startswith(" i Send requests to: https://apps.beam.cloud/"):
+ self.app_id = line.split("/")[-1]
+ self.url = line.split(":")[1].strip()
+ return self.app_id
+
+ raise ValueError(
+ f"""Failed to retrieve the appID from the deployment output.
+ Deployment output: {output}"""
+ )
+ else:
+ raise ValueError(f"Deployment failed. Error: {process.stderr}")
+
+ @property
+ def authorization(self) -> str:
+ if self.beam_client_id:
+ credential_str = self.beam_client_id + ":" + self.beam_client_secret
+ else:
+ credential_str = self.beam_client_secret
+ return base64.b64encode(credential_str.encode()).decode()
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[list] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Beam."""
+ url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
+ payload = {"prompt": prompt, "max_length": self.max_length}
+ payload.update(kwargs)
+ headers = {
+ "Accept": "*/*",
+ "Accept-Encoding": "gzip, deflate",
+ "Authorization": "Basic " + self.authorization,
+ "Connection": "keep-alive",
+ "Content-Type": "application/json",
+ }
+
+ for _ in range(DEFAULT_NUM_TRIES):
+ request = requests.post(url, headers=headers, data=json.dumps(payload))
+ if request.status_code == 200:
+ return request.json()["text"]
+ time.sleep(DEFAULT_SLEEP_TIME)
+ logger.warning("Unable to successfully call model.")
+ return ""
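+
+
+# Usage sketch (illustrative only, not executed): assuming the Beam CLI is
+# installed and BEAM_CLIENT_ID / BEAM_CLIENT_SECRET are set, a deploy-then-call
+# flow looks roughly like this. Deployment provisions cloud resources and may
+# take several minutes.
+#
+#     llm = Beam(
+#         model_name="gpt2",
+#         name="langchain-gpt2",
+#         cpu=8,
+#         memory="32Gi",
+#         gpu="A10G",
+#         python_version="python3.8",
+#         python_packages=["transformers", "torch"],
+#         max_length="50",
+#     )
+#     llm._deploy()
+#     print(llm._call("Once upon a time"))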
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
new file mode 100644
index 00000000000..dca8cd5f8fd
--- /dev/null
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -0,0 +1,448 @@
+import json
+import warnings
+from abc import ABC
+from typing import Any, Dict, Iterator, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+from langchain_community.utilities.anthropic import (
+ get_num_tokens_anthropic,
+ get_token_ids_anthropic,
+)
+
+HUMAN_PROMPT = "\n\nHuman:"
+ASSISTANT_PROMPT = "\n\nAssistant:"
+ALTERNATION_ERROR = (
+ "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
+)
+
+
+def _add_newlines_before_ha(input_text: str) -> str:
+ new_text = input_text
+ for word in ["Human:", "Assistant:"]:
+ new_text = new_text.replace(word, "\n\n" + word)
+ for i in range(2):
+ new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
+ return new_text
+
+
+def _human_assistant_format(input_text: str) -> str:
+ if input_text.count("Human:") == 0 or (
+ input_text.find("Human:") > input_text.find("Assistant:")
+ and "Assistant:" in input_text
+ ):
+ input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION
+ if input_text.count("Assistant:") == 0:
+ input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
+ if input_text[: len("Human:")] == "Human:":
+ input_text = "\n\n" + input_text
+ input_text = _add_newlines_before_ha(input_text)
+ count = 0
+ # track alternation
+ for i in range(len(input_text)):
+ if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
+ if count % 2 == 0:
+ count += 1
+ else:
+ warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
+ if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
+ if count % 2 == 1:
+ count += 1
+ else:
+ warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
+
+ if count % 2 == 1: # Only saw Human, no Assistant
+ input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
+
+ return input_text
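+
+# For example, _human_assistant_format("Hello") returns
+# "\n\nHuman: Hello\n\nAssistant:", while a prompt that already alternates
+# between the two markers passes through with only newline normalization.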
+
+
+class LLMInputOutputAdapter:
+ """Adapter class to prepare the inputs from Langchain to a format
+    """Adapter class to prepare the inputs from LangChain to a format
+    that the LLM model expects.
+
+    It also provides a helper function to extract
+    the generated text from the model response."""
+ provider_to_output_key_map = {
+ "anthropic": "completion",
+ "amazon": "outputText",
+ "cohere": "text",
+ "meta": "generation",
+ }
+
+ @classmethod
+ def prepare_input(
+ cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ input_body = {**model_kwargs}
+ if provider == "anthropic":
+ input_body["prompt"] = _human_assistant_format(prompt)
+ elif provider in ("ai21", "cohere", "meta"):
+ input_body["prompt"] = prompt
+ elif provider == "amazon":
+ input_body = dict()
+ input_body["inputText"] = prompt
+ input_body["textGenerationConfig"] = {**model_kwargs}
+ else:
+ input_body["inputText"] = prompt
+
+ if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
+ input_body["max_tokens_to_sample"] = 256
+
+ return input_body
+
+ @classmethod
+ def prepare_output(cls, provider: str, response: Any) -> str:
+ if provider == "anthropic":
+ response_body = json.loads(response.get("body").read().decode())
+ return response_body.get("completion")
+ else:
+ response_body = json.loads(response.get("body").read())
+
+ if provider == "ai21":
+ return response_body.get("completions")[0].get("data").get("text")
+ elif provider == "cohere":
+ return response_body.get("generations")[0].get("text")
+ elif provider == "meta":
+ return response_body.get("generation")
+ else:
+ return response_body.get("results")[0].get("outputText")
+
+ @classmethod
+ def prepare_output_stream(
+ cls, provider: str, response: Any, stop: Optional[List[str]] = None
+ ) -> Iterator[GenerationChunk]:
+ stream = response.get("body")
+
+ if not stream:
+ return
+
+ if provider not in cls.provider_to_output_key_map:
+ raise ValueError(
+ f"Unknown streaming response output key for provider: {provider}"
+ )
+
+ for event in stream:
+ chunk = event.get("chunk")
+ if chunk:
+ chunk_obj = json.loads(chunk.get("bytes").decode())
+ if provider == "cohere" and (
+ chunk_obj["is_finished"]
+ or chunk_obj[cls.provider_to_output_key_map[provider]]
+ == ""
+ ):
+ return
+
+ # chunk obj format varies with provider
+ yield GenerationChunk(
+ text=chunk_obj[cls.provider_to_output_key_map[provider]]
+ )
+
+
+class BedrockBase(BaseModel, ABC):
+ """Base class for Bedrock models."""
+
+ client: Any = Field(exclude=True) #: :meta private:
+
+ region_name: Optional[str] = None
+    """The AWS region, e.g., `us-west-2`. Falls back to the AWS_DEFAULT_REGION
+    env variable or the region specified in ~/.aws/config if not provided here.
+ """
+
+ credentials_profile_name: Optional[str] = Field(default=None, exclude=True)
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+ has either access keys or role information specified.
+ If not specified, the default credential profile or, if on an EC2 instance,
+ credentials from IMDS will be used.
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ """
+
+ model_id: str
+    """Id of the model to call, e.g., amazon.titan-text-express-v1; this is
+    equivalent to the modelId property in the list-foundation-models API."""
+
+ model_kwargs: Optional[Dict] = None
+ """Keyword arguments to pass to the model."""
+
+ endpoint_url: Optional[str] = None
+    """Needed if you don't want to default to the us-east-1 endpoint."""
+
+ streaming: bool = False
+ """Whether to stream the results."""
+
+ provider_stop_sequence_key_name_map: Mapping[str, str] = {
+ "anthropic": "stop_sequences",
+ "amazon": "stopSequences",
+ "ai21": "stop_sequences",
+ "cohere": "stop_sequences",
+ }
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that AWS credentials and the boto3 package
+        exist in the environment."""
+
+ # Skip creating new client if passed in constructor
+ if values["client"] is not None:
+ return values
+
+ try:
+ import boto3
+
+ if values["credentials_profile_name"] is not None:
+ session = boto3.Session(profile_name=values["credentials_profile_name"])
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ values["region_name"] = get_from_dict_or_env(
+ values,
+ "region_name",
+ "AWS_DEFAULT_REGION",
+ default=session.region_name,
+ )
+
+ client_params = {}
+ if values["region_name"]:
+ client_params["region_name"] = values["region_name"]
+ if values["endpoint_url"]:
+ client_params["endpoint_url"] = values["endpoint_url"]
+
+ values["client"] = session.client("bedrock-runtime", **client_params)
+
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ def _get_provider(self) -> str:
+ return self.model_id.split(".")[0]
+
+ @property
+ def _model_is_anthropic(self) -> bool:
+ return self._get_provider() == "anthropic"
+
+ def _prepare_input_and_invoke(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ _model_kwargs = self.model_kwargs or {}
+
+ provider = self._get_provider()
+ params = {**_model_kwargs, **kwargs}
+ input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+ body = json.dumps(input_body)
+ accept = "application/json"
+ contentType = "application/json"
+
+ try:
+ response = self.client.invoke_model(
+ body=body, modelId=self.model_id, accept=accept, contentType=contentType
+ )
+ text = LLMInputOutputAdapter.prepare_output(provider, response)
+
+ except Exception as e:
+ raise ValueError(f"Error raised by bedrock service: {e}")
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+ def _prepare_input_and_invoke_stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ _model_kwargs = self.model_kwargs or {}
+ provider = self._get_provider()
+
+ if stop:
+ if provider not in self.provider_stop_sequence_key_name_map:
+ raise ValueError(
+ f"Stop sequence key name for {provider} is not supported."
+ )
+
+ # stop sequence from _generate() overrides
+ # stop sequences in the class attribute
+ _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+ if provider == "cohere":
+ _model_kwargs["stream"] = True
+
+ params = {**_model_kwargs, **kwargs}
+ input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+ body = json.dumps(input_body)
+
+ try:
+ response = self.client.invoke_model_with_response_stream(
+ body=body,
+ modelId=self.model_id,
+ accept="application/json",
+ contentType="application/json",
+ )
+ except Exception as e:
+ raise ValueError(f"Error raised by bedrock service: {e}")
+
+ for chunk in LLMInputOutputAdapter.prepare_output_stream(
+ provider, response, stop
+ ):
+ yield chunk
+ if run_manager is not None:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+
+class Bedrock(LLM, BedrockBase):
+ """Bedrock models.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+ access the Bedrock service.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.llms import Bedrock
+
+            llm = Bedrock(
+ credentials_profile_name="default",
+ model_id="amazon.titan-text-express-v1",
+ streaming=True
+ )
+
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "amazon_bedrock"
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this model can be serialized by Langchain."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "bedrock"]
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ attributes: Dict[str, Any] = {}
+
+ if self.region_name:
+ attributes["region_name"] = self.region_name
+
+ return attributes
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Call out to Bedrock service with streaming.
+
+ Args:
+ prompt (str): The prompt to pass into the model
+ stop (Optional[List[str]], optional): Stop sequences. These will
+ override any stop sequences in the `model_kwargs` attribute.
+ Defaults to None.
+ run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+ run managers used to process the output. Defaults to None.
+
+ Returns:
+ Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+ Yields:
+ Iterator[GenerationChunk]: Responses from the model.
+ """
+ return self._prepare_input_and_invoke_stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ )
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Bedrock service model.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = llm("Tell me a joke.")
+ """
+
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ completion += chunk.text
+ return completion
+
+ return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+
+ def get_num_tokens(self, text: str) -> int:
+ if self._model_is_anthropic:
+ return get_num_tokens_anthropic(text)
+ else:
+ return super().get_num_tokens(text)
+
+ def get_token_ids(self, text: str) -> List[int]:
+ if self._model_is_anthropic:
+ return get_token_ids_anthropic(text)
+ else:
+ return super().get_token_ids(text)
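+
+
+# Usage sketch (illustrative only, not executed): assuming boto3 is installed
+# and AWS credentials with Bedrock access are configured, a streaming call to an
+# Anthropic model could look like this; the model id and kwargs are examples.
+#
+#     import boto3
+#
+#     bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")
+#     llm = Bedrock(
+#         client=bedrock_client,
+#         model_id="anthropic.claude-v2",
+#         model_kwargs={"max_tokens_to_sample": 256},
+#         streaming=True,
+#     )
+#     for chunk in llm.stream("Tell me a short story about a robot."):
+#         print(chunk, end="", flush=True)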
diff --git a/libs/community/langchain_community/llms/bittensor.py b/libs/community/langchain_community/llms/bittensor.py
new file mode 100644
index 00000000000..3d28533f514
--- /dev/null
+++ b/libs/community/langchain_community/llms/bittensor.py
@@ -0,0 +1,174 @@
+import http.client
+import json
+import ssl
+from typing import Any, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+
+
+class NIBittensorLLM(LLM):
+ """NIBittensor LLMs
+
+ NIBittensorLLM is created by Neural Internet (https://neuralinternet.ai/),
+ powered by Bittensor, a decentralized network full of different AI models.
+
+    To inspect your API keys and usage logs, visit
+ https://api.neuralinternet.ai/api-keys
+ https://api.neuralinternet.ai/logs
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import NIBittensorLLM
+ llm = NIBittensorLLM()
+ """
+
+ system_prompt: Optional[str]
+    """System prompt to supply to the model before every user prompt."""
+
+ top_responses: Optional[int] = 0
+    """Set top_responses to get the top N miner responses for one request.
+    Responses may be delayed; don't use in production."""
+
+ @property
+ def _llm_type(self) -> str:
+ return "NIBittensorLLM"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """
+        Wrapper around the Bittensor top miner models, built by Neural Internet.
+
+ Call the Neural Internet's BTVEP Server and return the output.
+
+ Parameters (optional):
+ system_prompt(str): A system prompt defining how your model should respond.
+ top_responses(int): Total top miner responses to retrieve from Bittensor
+ protocol.
+
+ Return:
+ The generated response(s).
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import NIBittensorLLM
+ llm = NIBittensorLLM(system_prompt="Act like you are programmer with \
+ 5+ years of experience.")
+ """
+
+ # Creating HTTPS connection with SSL
+ context = ssl.create_default_context()
+ context.check_hostname = True
+ conn = http.client.HTTPSConnection("test.neuralinternet.ai", context=context)
+
+ # Sanitizing User Input before passing to API.
+ if isinstance(self.top_responses, int):
+ top_n = min(100, self.top_responses)
+ else:
+ top_n = 0
+
+ default_prompt = "You are an assistant which is created by Neural Internet(NI) \
+ in decentralized network named as a Bittensor."
+ if self.system_prompt is None:
+ system_prompt = (
+ default_prompt
+ + " Your task is to provide accurate response based on user prompt"
+ )
+ else:
+ system_prompt = default_prompt + str(self.system_prompt)
+
+ # Retrieving API KEY to pass into header of each request
+ conn.request("GET", "/admin/api-keys/")
+ api_key_response = conn.getresponse()
+ api_keys_data = (
+ api_key_response.read().decode("utf-8").replace("\n", "").replace("\t", "")
+ )
+ api_keys_json = json.loads(api_keys_data)
+ api_key = api_keys_json[0]["api_key"]
+
+ # Creating Header and getting top benchmark miner uids
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ "Endpoint-Version": "2023-05-19",
+ }
+ conn.request("GET", "/top_miner_uids", headers=headers)
+ miner_response = conn.getresponse()
+ miner_data = (
+ miner_response.read().decode("utf-8").replace("\n", "").replace("\t", "")
+ )
+ uids = json.loads(miner_data)
+
+ # Condition for benchmark miner response
+ if isinstance(uids, list) and uids and not top_n:
+ for uid in uids:
+ try:
+ payload = json.dumps(
+ {
+ "uids": [uid],
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": prompt},
+ ],
+ }
+ )
+
+ conn.request("POST", "/chat", payload, headers)
+ init_response = conn.getresponse()
+ init_data = (
+ init_response.read()
+ .decode("utf-8")
+ .replace("\n", "")
+ .replace("\t", "")
+ )
+ init_json = json.loads(init_data)
+ if "choices" not in init_json:
+ continue
+ reply = init_json["choices"][0]["message"]["content"]
+ conn.close()
+ return reply
+ except Exception:
+ continue
+
+ # For top miner based on bittensor response
+ try:
+ payload = json.dumps(
+ {
+ "top_n": top_n,
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": prompt},
+ ],
+ }
+ )
+
+ conn.request("POST", "/chat", payload, headers)
+ response = conn.getresponse()
+ utf_string = (
+ response.read().decode("utf-8").replace("\n", "").replace("\t", "")
+ )
+ if top_n:
+ conn.close()
+ return utf_string
+ json_resp = json.loads(utf_string)
+ reply = json_resp["choices"][0]["message"]["content"]
+ conn.close()
+ return reply
+ except Exception as e:
+ conn.request("GET", f"/error_msg?e={e}&p={prompt}", headers=headers)
+ return "Sorry I am unable to provide response now, Please try again later."
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "system_prompt": self.system_prompt,
+ "top_responses": self.top_responses,
+ }
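+
+
+# Usage sketch (illustrative only, not executed): no local API key is required;
+# responses come from miners on the Bittensor network, so latency can vary.
+#
+#     llm = NIBittensorLLM(system_prompt="You are a concise assistant.")
+#     print(llm("Explain what a decentralized network is."))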
diff --git a/libs/community/langchain_community/llms/cerebriumai.py b/libs/community/langchain_community/llms/cerebriumai.py
new file mode 100644
index 00000000000..c9e219995ae
--- /dev/null
+++ b/libs/community/langchain_community/llms/cerebriumai.py
@@ -0,0 +1,113 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional, cast
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class CerebriumAI(LLM):
+ """CerebriumAI large language models.
+
+ To use, you should have the ``cerebrium`` python package installed.
+ You should also have the environment variable ``CEREBRIUMAI_API_KEY``
+ set with your API key or pass it as a named argument in the constructor.
+
+ Any parameters that are valid to be passed to the call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import CerebriumAI
+ cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
+
+ """
+
+ endpoint_url: str = ""
+ """model endpoint to use"""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not
+ explicitly specified."""
+
+ cerebriumai_api_key: Optional[SecretStr] = None
+
+ class Config:
+        """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the API key exists in the environment."""
+ cerebriumai_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
+ )
+ values["cerebriumai_api_key"] = cerebriumai_api_key
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"endpoint_url": self.endpoint_url},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "cerebriumai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ headers: Dict = {
+ "Authorization": cast(
+ SecretStr, self.cerebriumai_api_key
+ ).get_secret_value(),
+ "Content-Type": "application/json",
+ }
+ params = self.model_kwargs or {}
+ payload = {"prompt": prompt, **params, **kwargs}
+ response = requests.post(self.endpoint_url, json=payload, headers=headers)
+ if response.status_code == 200:
+ data = response.json()
+ text = data["result"]
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
+ else:
+ response.raise_for_status()
+ return ""
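+
+
+# Usage sketch (illustrative only, not executed): assuming a deployed Cerebrium
+# endpoint that accepts a JSON body with a "prompt" key and CEREBRIUMAI_API_KEY
+# set in the environment; the URL and parameters below are placeholders.
+#
+#     llm = CerebriumAI(
+#         endpoint_url="https://run.cerebrium.ai/<your-endpoint>/predict",
+#         model_kwargs={"max_length": 100},
+#     )
+#     print(llm("What is the airspeed velocity of an unladen swallow?"))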
diff --git a/libs/community/langchain_community/llms/chatglm.py b/libs/community/langchain_community/llms/chatglm.py
new file mode 100644
index 00000000000..84e2294c05d
--- /dev/null
+++ b/libs/community/langchain_community/llms/chatglm.py
@@ -0,0 +1,129 @@
+import logging
+from typing import Any, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class ChatGLM(LLM):
+ """ChatGLM LLM service.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import ChatGLM
+ endpoint_url = (
+ "http://127.0.0.1:8000"
+ )
+ ChatGLM_llm = ChatGLM(
+ endpoint_url=endpoint_url
+ )
+ """
+
+ endpoint_url: str = "http://127.0.0.1:8000/"
+ """Endpoint URL to use."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+ max_token: int = 20000
+    """Maximum number of tokens to pass to the model."""
+ temperature: float = 0.1
+ """LLM model temperature from 0 to 10."""
+ history: List[List] = []
+ """History of the conversation"""
+ top_p: float = 0.7
+ """Top P for nucleus sampling from 0 to 1"""
+ with_history: bool = False
+ """Whether to use history or not"""
+
+ @property
+ def _llm_type(self) -> str:
+ return "chat_glm"
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"endpoint_url": self.endpoint_url},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to a ChatGLM LLM inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = chatglm_llm("Who are you?")
+ """
+
+ _model_kwargs = self.model_kwargs or {}
+
+ # HTTP headers for authorization
+ headers = {"Content-Type": "application/json"}
+
+ payload = {
+ "prompt": prompt,
+ "temperature": self.temperature,
+ "history": self.history,
+ "max_length": self.max_token,
+ "top_p": self.top_p,
+ }
+ payload.update(_model_kwargs)
+ payload.update(kwargs)
+
+ logger.debug(f"ChatGLM payload: {payload}")
+
+ # call api
+ try:
+ response = requests.post(self.endpoint_url, headers=headers, json=payload)
+ except requests.exceptions.RequestException as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ logger.debug(f"ChatGLM response: {response}")
+
+ if response.status_code != 200:
+ raise ValueError(f"Failed with response: {response}")
+
+ try:
+ parsed_response = response.json()
+
+            # Check if response content exists
+ if isinstance(parsed_response, dict):
+ content_keys = "response"
+ if content_keys in parsed_response:
+ text = parsed_response[content_keys]
+ else:
+ raise ValueError(f"No content in response : {parsed_response}")
+ else:
+ raise ValueError(f"Unexpected response type: {parsed_response}")
+
+ except requests.exceptions.JSONDecodeError as e:
+ raise ValueError(
+ f"Error raised during decoding response from inference endpoint: {e}."
+ f"\nResponse: {response.text}"
+ )
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ if self.with_history:
+ self.history = self.history + [[None, parsed_response["response"]]]
+ return text
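+
+
+# Usage sketch (illustrative only, not executed): assuming a local ChatGLM
+# server exposing its default HTTP API on port 8000:
+#
+#     llm = ChatGLM(
+#         endpoint_url="http://127.0.0.1:8000",
+#         temperature=0.1,
+#         with_history=False,
+#     )
+#     print(llm("What is the capital of France?"))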
diff --git a/libs/community/langchain_community/llms/clarifai.py b/libs/community/langchain_community/llms/clarifai.py
new file mode 100644
index 00000000000..7690bd9bdd2
--- /dev/null
+++ b/libs/community/langchain_community/llms/clarifai.py
@@ -0,0 +1,214 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+EXAMPLE_URL = "https://clarifai.com/openai/chat-completion/models/GPT-4"
+
+
+class Clarifai(LLM):
+ """Clarifai large language models.
+
+ To use, you should have an account on the Clarifai platform,
+ the ``clarifai`` python package installed, and the
+ environment variable ``CLARIFAI_PAT`` set with your PAT key,
+ or pass it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Clarifai
+ clarifai_llm = Clarifai(user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID)
+ (or)
+ clarifai_llm = Clarifai(model_url=EXAMPLE_URL)
+ """
+
+ model_url: Optional[str] = None
+ """Model url to use."""
+ model_id: Optional[str] = None
+ """Model id to use."""
+ model_version_id: Optional[str] = None
+ """Model version id to use."""
+ app_id: Optional[str] = None
+ """Clarifai application id to use."""
+ user_id: Optional[str] = None
+ """Clarifai user id to use."""
+ pat: Optional[str] = None
+ """Clarifai personal access token to use."""
+ api_base: str = "https://api.clarifai.com"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that we have all required info to access Clarifai
+ platform and python package exists in environment."""
+ values["pat"] = get_from_dict_or_env(values, "pat", "CLARIFAI_PAT")
+ user_id = values.get("user_id")
+ app_id = values.get("app_id")
+ model_id = values.get("model_id")
+ model_url = values.get("model_url")
+
+ if model_url is not None and model_id is not None:
+ raise ValueError("Please provide either model_url or model_id, not both.")
+
+ if model_url is None and model_id is None:
+ raise ValueError("Please provide one of model_url or model_id.")
+
+ if model_url is None and model_id is not None:
+ if user_id is None or app_id is None:
+ raise ValueError("Please provide a user_id and app_id.")
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Clarifai API."""
+ return {}
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{
+ "model_url": self.model_url,
+ "user_id": self.user_id,
+ "app_id": self.app_id,
+ "model_id": self.model_id,
+ }
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "clarifai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ inference_params: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> str:
+        """Call out to Clarifai's PostModelOutputs endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = clarifai_llm("Tell me a joke.")
+ """
+        # If model_version_id is None, defaults to the latest model version
+ try:
+ from clarifai.client.model import Model
+ except ImportError:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ )
+ if self.pat is not None:
+ pat = self.pat
+ if self.model_url is not None:
+ _model_init = Model(url=self.model_url, pat=pat)
+ else:
+ _model_init = Model(
+ model_id=self.model_id,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ pat=pat,
+ )
+ try:
+            inference_params = {} if inference_params is None else inference_params
+ predict_response = _model_init.predict_by_bytes(
+ bytes(prompt, "utf-8"),
+ input_type="text",
+ inference_params=inference_params,
+ )
+ text = predict_response.outputs[0].data.text.raw
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ except Exception as e:
+ logger.error(f"Predict failed, exception: {e}")
+
+ return text
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ inference_params: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+
+ # TODO: add caching here.
+ try:
+ from clarifai.client.input import Inputs
+ from clarifai.client.model import Model
+ except ImportError:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ )
+ if self.pat is not None:
+ pat = self.pat
+ if self.model_url is not None:
+ _model_init = Model(url=self.model_url, pat=pat)
+ else:
+ _model_init = Model(
+ model_id=self.model_id,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ pat=pat,
+ )
+
+ generations = []
+ batch_size = 32
+ input_obj = Inputs(pat=pat)
+ try:
+ for i in range(0, len(prompts), batch_size):
+ batch = prompts[i : i + batch_size]
+ input_batch = [
+ input_obj.get_text_input(input_id=str(id), raw_text=inp)
+ for id, inp in enumerate(batch)
+ ]
+                inference_params = (
+                    {} if inference_params is None else inference_params
+                )
+ predict_response = _model_init.predict(
+ inputs=input_batch, inference_params=inference_params
+ )
+
+ for output in predict_response.outputs:
+ if stop is not None:
+ text = enforce_stop_tokens(output.data.text.raw, stop)
+ else:
+ text = output.data.text.raw
+
+ generations.append([Generation(text=text)])
+
+ except Exception as e:
+ logger.error(f"Predict failed, exception: {e}")
+
+ return LLMResult(generations=generations)
diff --git a/libs/community/langchain_community/llms/cloudflare_workersai.py b/libs/community/langchain_community/llms/cloudflare_workersai.py
new file mode 100644
index 00000000000..840acdbdb81
--- /dev/null
+++ b/libs/community/langchain_community/llms/cloudflare_workersai.py
@@ -0,0 +1,126 @@
+import json
+import logging
+from typing import Any, Dict, Iterator, List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+
+logger = logging.getLogger(__name__)
+
+
+class CloudflareWorkersAI(LLM):
+ """Langchain LLM class to help to access Cloudflare Workers AI service.
+
+ To use, you must provide an API token and
+ account ID to access Cloudflare Workers AI, and
+ pass it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI
+
+ my_account_id = "my_account_id"
+ my_api_token = "my_secret_api_token"
+ llm_model = "@cf/meta/llama-2-7b-chat-int8"
+
+ cf_ai = CloudflareWorkersAI(
+ account_id=my_account_id,
+ api_token=my_api_token,
+ model=llm_model
+ )
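+
+            # Illustrative usage (assumes the placeholder credentials above
+            # have been replaced with valid ones):
+            response = cf_ai("Why is the sky blue?")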
+ """ # noqa: E501
+
+ account_id: str
+ api_token: str
+ model: str = "@cf/meta/llama-2-7b-chat-int8"
+ base_url: str = "https://api.cloudflare.com/client/v4/accounts"
+ streaming: bool = False
+ endpoint_url: str = ""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the Cloudflare Workers AI class."""
+ super().__init__(**kwargs)
+
+ self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of LLM."""
+ return "cloudflare"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Default parameters"""
+ return {}
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Identifying parameters"""
+ return {
+ "account_id": self.account_id,
+ "api_token": self.api_token,
+ "model": self.model,
+ "base_url": self.base_url,
+ }
+
+ def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
+ """Call Cloudflare Workers API"""
+ headers = {"Authorization": f"Bearer {self.api_token}"}
+ data = {"prompt": prompt, "stream": self.streaming, **params}
+ response = requests.post(self.endpoint_url, headers=headers, json=data)
+ return response
+
+ def _process_response(self, response: requests.Response) -> str:
+ """Process API response"""
+ if response.ok:
+ data = response.json()
+ return data["result"]["response"]
+ else:
+ raise ValueError(f"Request failed with status {response.status_code}")
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Streaming prediction"""
+        original_streaming: bool = self.streaming
+ self.streaming = True
+ _response_prefix_count = len("data: ")
+ _response_stream_end = b"data: [DONE]"
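+        # The endpoint replies with server-sent events: each line looks like
+        # `data: {"response": "..."}` and the stream ends with `data: [DONE]`.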
+ for chunk in self._call_api(prompt, kwargs).iter_lines():
+ if chunk == _response_stream_end:
+ break
+ if len(chunk) > _response_prefix_count:
+ try:
+ data = json.loads(chunk[_response_prefix_count:])
+ except Exception as e:
+ logger.debug(chunk)
+ raise e
+ if data is not None and "response" in data:
+ yield GenerationChunk(text=data["response"])
+ if run_manager:
+ run_manager.on_llm_new_token(data["response"])
+ logger.debug("stream end")
+        self.streaming = original_streaming
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Regular prediction"""
+ if self.streaming:
+ return "".join(
+ [c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
+ )
+ else:
+ response = self._call_api(prompt, kwargs)
+ return self._process_response(response)
diff --git a/libs/community/langchain_community/llms/cohere.py b/libs/community/langchain_community/llms/cohere.py
new file mode 100644
index 00000000000..18a86c44613
--- /dev/null
+++ b/libs/community/langchain_community/llms/cohere.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.load.serializable import Serializable
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
+ import cohere
+
+ min_seconds = 4
+ max_seconds = 10
+    # Wait 2**x seconds between retries, starting at 4 seconds and
+    # capping at 10 seconds once that ceiling is reached.
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(llm.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(retry_if_exception_type(cohere.error.CohereError)),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return llm.client.generate(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+
+def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ return await llm.async_client.generate(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+
+class BaseCohere(Serializable):
+ """Base class for Cohere models."""
+
+ client: Any #: :meta private:
+ async_client: Any #: :meta private:
+ model: Optional[str] = Field(default=None)
+ """Model name to use."""
+
+ temperature: float = 0.75
+ """A non-negative float that tunes the degree of randomness in generation."""
+
+ cohere_api_key: Optional[str] = None
+
+ stop: Optional[List[str]] = None
+
+ streaming: bool = Field(default=False)
+ """Whether to stream the results."""
+
+ user_agent: str = "langchain"
+ """Identifier for the application making the request."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ try:
+ import cohere
+ except ImportError:
+ raise ImportError(
+ "Could not import cohere python package. "
+ "Please install it with `pip install cohere`."
+ )
+ else:
+ cohere_api_key = get_from_dict_or_env(
+ values, "cohere_api_key", "COHERE_API_KEY"
+ )
+ client_name = values["user_agent"]
+ values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
+ values["async_client"] = cohere.AsyncClient(
+ cohere_api_key, client_name=client_name
+ )
+ return values
+
+
+class Cohere(LLM, BaseCohere):
+ """Cohere large language models.
+
+ To use, you should have the ``cohere`` python package installed, and the
+ environment variable ``COHERE_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Cohere
+
+ cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
+ """
+
+ max_tokens: int = 256
+ """Denotes the number of tokens to predict per generation."""
+
+ k: int = 0
+ """Number of most likely tokens to consider at each step."""
+
+ p: int = 1
+ """Total probability mass of tokens to consider at each step."""
+
+ frequency_penalty: float = 0.0
+ """Penalizes repeated tokens according to frequency. Between 0 and 1."""
+
+ presence_penalty: float = 0.0
+ """Penalizes repeated tokens. Between 0 and 1."""
+
+ truncate: Optional[str] = None
+ """Specify how the client handles inputs longer than the maximum token
+ length: Truncate from START, END or NONE"""
+
+ max_retries: int = 10
+ """Maximum number of retries to make when generating."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Cohere API."""
+ return {
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "k": self.k,
+ "p": self.p,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "truncate": self.truncate,
+ }
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"cohere_api_key": "COHERE_API_KEY"}
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "cohere"
+
+ def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
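+        # Merge the default params with per-call kwargs; `stop` may come either
+        # from the instance default (`self.stop`) or from this call, but not both.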
+ params = self._default_params
+ if self.stop is not None and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop is not None:
+ params["stop_sequences"] = self.stop
+ else:
+ params["stop_sequences"] = stop
+ return {**params, **kwargs}
+
+ def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
+ text = response.generations[0].text
+ # If stop tokens are provided, Cohere's endpoint returns them.
+ # In order to make this consistent with other endpoints, we strip them.
+ if stop:
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Cohere's generate endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = cohere("Tell me a joke.")
+ """
+ params = self._invocation_params(stop, **kwargs)
+ response = completion_with_retry(
+ self, model=self.model, prompt=prompt, **params
+ )
+ _stop = params.get("stop_sequences")
+ return self._process_response(response, _stop)
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Async call out to Cohere's generate endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = await cohere("Tell me a joke.")
+ """
+ params = self._invocation_params(stop, **kwargs)
+ response = await acompletion_with_retry(
+ self, model=self.model, prompt=prompt, **params
+ )
+ _stop = params.get("stop_sequences")
+ return self._process_response(response, _stop)
diff --git a/libs/community/langchain_community/llms/ctransformers.py b/libs/community/langchain_community/llms/ctransformers.py
new file mode 100644
index 00000000000..b532b1585c5
--- /dev/null
+++ b/libs/community/langchain_community/llms/ctransformers.py
@@ -0,0 +1,140 @@
+from functools import partial
+from typing import Any, Dict, List, Optional, Sequence
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import root_validator
+
+
+class CTransformers(LLM):
+ """C Transformers LLM models.
+
+ To use, you should have the ``ctransformers`` python package installed.
+ See https://github.com/marella/ctransformers
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import CTransformers
+
+ llm = CTransformers(model="/path/to/ggml-gpt-2.bin", model_type="gpt2")
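+
+            # Generation settings can be passed via `config`
+            # (illustrative keys; see the ctransformers config docs):
+            llm = CTransformers(
+                model="/path/to/ggml-gpt-2.bin",
+                model_type="gpt2",
+                config={"max_new_tokens": 256, "temperature": 0.8},
+            )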
+ """
+
+ client: Any #: :meta private:
+
+ model: str
+ """The path to a model file or directory or the name of a Hugging Face Hub
+ model repo."""
+
+ model_type: Optional[str] = None
+ """The model type."""
+
+ model_file: Optional[str] = None
+ """The name of the model file in repo or directory."""
+
+ config: Optional[Dict[str, Any]] = None
+ """The config parameters.
+ See https://github.com/marella/ctransformers#config"""
+
+ lib: Optional[str] = None
+ """The path to a shared library or one of `avx2`, `avx`, `basic`."""
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model": self.model,
+ "model_type": self.model_type,
+ "model_file": self.model_file,
+ "config": self.config,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "ctransformers"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that ``ctransformers`` package is installed."""
+ try:
+ from ctransformers import AutoModelForCausalLM
+ except ImportError:
+ raise ImportError(
+ "Could not import `ctransformers` package. "
+ "Please install it with `pip install ctransformers`"
+ )
+
+ config = values["config"] or {}
+ values["client"] = AutoModelForCausalLM.from_pretrained(
+ values["model"],
+ model_type=values["model_type"],
+ model_file=values["model_file"],
+ lib=values["lib"],
+ **config,
+ )
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[Sequence[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Generate text from a prompt.
+
+ Args:
+ prompt: The prompt to generate text from.
+ stop: A list of sequences to stop generation when encountered.
+
+ Returns:
+ The generated text.
+
+ Example:
+ .. code-block:: python
+
+ response = llm("Tell me a joke.")
+ """
+ text = []
+ _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
+ for chunk in self.client(prompt, stop=stop, stream=True):
+ text.append(chunk)
+ _run_manager.on_llm_new_token(chunk, verbose=self.verbose)
+ return "".join(text)
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Asynchronous Call out to CTransformers generate method.
+ Very helpful when streaming (like with websockets!)
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+ response = llm("Once upon a time, ")
+ """
+ text_callback = None
+ if run_manager:
+ text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
+
+ text = ""
+ for token in self.client(prompt, stop=stop, stream=True):
+ if text_callback:
+ await text_callback(token)
+ text += token
+
+ return text
diff --git a/libs/community/langchain_community/llms/ctranslate2.py b/libs/community/langchain_community/llms/ctranslate2.py
new file mode 100644
index 00000000000..84e357aa8e5
--- /dev/null
+++ b/libs/community/langchain_community/llms/ctranslate2.py
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Field, root_validator
+
+
+class CTranslate2(BaseLLM):
+ """CTranslate2 language model."""
+
+ model_path: str = ""
+ """Path to the CTranslate2 model directory."""
+
+ tokenizer_name: str = ""
+ """Name of the original Hugging Face model needed to load the proper tokenizer."""
+
+ device: str = "cpu"
+ """Device to use (possible values are: cpu, cuda, auto)."""
+
+ device_index: Union[int, List[int]] = 0
+ """Device IDs where to place this generator on."""
+
+ compute_type: Union[str, Dict[str, str]] = "default"
+ """
+ Model computation type or a dictionary mapping a device name to the computation type
+ (possible values are: default, auto, int8, int8_float32, int8_float16,
+ int8_bfloat16, int16, float16, bfloat16, float32).
+ """
+
+ max_length: int = 512
+ """Maximum generation length."""
+
+ sampling_topk: int = 1
+ """Randomly sample predictions from the top K candidates."""
+
+ sampling_topp: float = 1
+ """Keep the most probable tokens whose cumulative probability exceeds this value."""
+
+ sampling_temperature: float = 1
+ """Sampling temperature to generate more random samples."""
+
+ client: Any #: :meta private:
+
+ tokenizer: Any #: :meta private:
+
+ ctranslate2_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """
+ Holds any model parameters valid for `ctranslate2.Generator` call not
+ explicitly specified.
+ """
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ import ctranslate2
+ except ImportError:
+ raise ImportError(
+ "Could not import ctranslate2 python package. "
+ "Please install it with `pip install ctranslate2`."
+ )
+
+ try:
+ import transformers
+ except ImportError:
+ raise ImportError(
+ "Could not import transformers python package. "
+ "Please install it with `pip install transformers`."
+ )
+
+ values["client"] = ctranslate2.Generator(
+ model_path=values["model_path"],
+ device=values["device"],
+ device_index=values["device_index"],
+ compute_type=values["compute_type"],
+ **values["ctranslate2_kwargs"],
+ )
+
+ values["tokenizer"] = transformers.AutoTokenizer.from_pretrained(
+ values["tokenizer_name"]
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters."""
+ return {
+ "max_length": self.max_length,
+ "sampling_topk": self.sampling_topk,
+ "sampling_topp": self.sampling_topp,
+ "sampling_temperature": self.sampling_temperature,
+ }
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ # build sampling parameters
+ params = {**self._default_params, **kwargs}
+
+ # call the model
+ encoded_prompts = self.tokenizer(prompts)["input_ids"]
+ tokenized_prompts = [
+ self.tokenizer.convert_ids_to_tokens(encoded_prompt)
+ for encoded_prompt in encoded_prompts
+ ]
+
+ results = self.client.generate_batch(tokenized_prompts, **params)
+
+ sequences = [result.sequences_ids[0] for result in results]
+ decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]
+
+ generations = []
+ for text in decoded_sequences:
+ generations.append([Generation(text=text)])
+
+ return LLMResult(generations=generations)
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "ctranslate2"
diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py
new file mode 100644
index 00000000000..a3f505b5c2e
--- /dev/null
+++ b/libs/community/langchain_community/llms/databricks.py
@@ -0,0 +1,469 @@
+import os
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
+from langchain_core.pydantic_v1 import (
+ BaseModel,
+ Extra,
+ Field,
+ PrivateAttr,
+ root_validator,
+ validator,
+)
+
+__all__ = ["Databricks"]
+
+
+class _DatabricksClientBase(BaseModel, ABC):
+ """A base JSON API client that talks to Databricks."""
+
+ api_url: str
+ api_token: str
+
+ def request(self, method: str, url: str, request: Any) -> Any:
+ headers = {"Authorization": f"Bearer {self.api_token}"}
+ response = requests.request(
+ method=method, url=url, headers=headers, json=request
+ )
+ # TODO: error handling and automatic retries
+ if not response.ok:
+ raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+ return response.json()
+
+ def _get(self, url: str) -> Any:
+ return self.request("GET", url, None)
+
+ def _post(self, url: str, request: Any) -> Any:
+ return self.request("POST", url, request)
+
+ @abstractmethod
+ def post(
+ self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+ ) -> Any:
+ ...
+
+ @property
+ def llm(self) -> bool:
+ return False
+
+
+def _transform_completions(response: Dict[str, Any]) -> str:
+ return response["choices"][0]["text"]
+
+
+def _transform_chat(response: Dict[str, Any]) -> str:
+ return response["choices"][0]["message"]["content"]
+
+
+class _DatabricksServingEndpointClient(_DatabricksClientBase):
+ """An API client that talks to a Databricks serving endpoint."""
+
+ host: str
+ endpoint_name: str
+ databricks_uri: str
+ client: Any = None
+ external_or_foundation: bool = False
+ task: Optional[str] = None
+
+ def __init__(self, **data: Any):
+ super().__init__(**data)
+
+ try:
+ from mlflow.deployments import get_deploy_client
+
+ self.client = get_deploy_client(self.databricks_uri)
+ except ImportError as e:
+ raise ImportError(
+ "Failed to create the client. "
+ "Please install mlflow with `pip install mlflow`."
+ ) from e
+
+ endpoint = self.client.get_endpoint(self.endpoint_name)
+ self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
+ "external_model",
+ "foundation_model_api",
+ )
+ self.task = endpoint.get("task")
+
+ @property
+ def llm(self) -> bool:
+ return self.task in ("llm/v1/chat", "llm/v1/completions")
+
+ @root_validator(pre=True)
+ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if "api_url" not in values:
+ host = values["host"]
+ endpoint_name = values["endpoint_name"]
+ api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
+ values["api_url"] = api_url
+ return values
+
+ def post(
+ self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+ ) -> Any:
+ if self.external_or_foundation:
+ resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
+ if transform_output_fn:
+ return transform_output_fn(resp)
+
+ if self.task == "llm/v1/chat":
+ return _transform_chat(resp)
+ elif self.task == "llm/v1/completions":
+ return _transform_completions(resp)
+
+ return resp
+ else:
+ # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
+ wrapped_request = {"dataframe_records": [request]}
+ response = self.client.predict(
+ endpoint=self.endpoint_name, inputs=wrapped_request
+ )
+ preds = response["predictions"]
+ # For a single-record query, the result is not a list.
+ pred = preds[0] if isinstance(preds, list) else preds
+ return transform_output_fn(pred) if transform_output_fn else pred
+
+
+class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
+ """An API client that talks to a Databricks cluster driver proxy app."""
+
+ host: str
+ cluster_id: str
+ cluster_driver_port: str
+
+ @root_validator(pre=True)
+ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if "api_url" not in values:
+ host = values["host"]
+ cluster_id = values["cluster_id"]
+ port = values["cluster_driver_port"]
+ api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
+ values["api_url"] = api_url
+ return values
+
+ def post(
+ self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+ ) -> Any:
+ resp = self._post(self.api_url, request)
+ return transform_output_fn(resp) if transform_output_fn else resp
+
+
+def get_repl_context() -> Any:
+ """Gets the notebook REPL context if running inside a Databricks notebook.
+ Returns None otherwise.
+ """
+ try:
+ from dbruntime.databricks_repl_context import get_context
+
+ return get_context()
+ except ImportError:
+ raise ImportError(
+ "Cannot access dbruntime, not running inside a Databricks notebook."
+ )
+
+
+def get_default_host() -> str:
+ """Gets the default Databricks workspace hostname.
+ Raises an error if the hostname cannot be automatically determined.
+ """
+ host = os.getenv("DATABRICKS_HOST")
+ if not host:
+ try:
+ host = get_repl_context().browserHostName
+ if not host:
+ raise ValueError("context doesn't contain browserHostName.")
+ except Exception as e:
+ raise ValueError(
+ "host was not set and cannot be automatically inferred. Set "
+ f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
+ )
+ # TODO: support Databricks CLI profile
+    # Strip the URL scheme and any trailing slash (str.lstrip removes
+    # characters, not prefixes, so check the prefixes explicitly).
+    if host.startswith("https://"):
+        host = host[len("https://") :]
+    elif host.startswith("http://"):
+        host = host[len("http://") :]
+    host = host.rstrip("/")
+ return host
+
+
+def get_default_api_token() -> str:
+ """Gets the default Databricks personal access token.
+ Raises an error if the token cannot be automatically determined.
+ """
+ if api_token := os.getenv("DATABRICKS_TOKEN"):
+ return api_token
+ try:
+ api_token = get_repl_context().apiToken
+ if not api_token:
+ raise ValueError("context doesn't contain apiToken.")
+ except Exception as e:
+ raise ValueError(
+ "api_token was not set and cannot be automatically inferred. Set "
+ f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
+ )
+ # TODO: support Databricks CLI profile
+ return api_token
+
+
+class Databricks(LLM):
+
+ """Databricks serving endpoint or a cluster driver proxy app for LLM.
+
+ It supports two endpoint types:
+
+ * **Serving endpoint** (recommended for both production and development).
+ We assume that an LLM was deployed to a serving endpoint.
+ To wrap it as an LLM you must have "Can Query" permission to the endpoint.
+ Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
+ ``cluster_driver_port``.
+
+ If the underlying model is a model registered by MLflow, the expected model
+ signature is:
+
+ * inputs::
+
+ [{"name": "prompt", "type": "string"},
+ {"name": "stop", "type": "list[string]"}]
+
+ * outputs: ``[{"type": "string"}]``
+
+ If the underlying model is an external or foundation model, the response from the
+ endpoint is automatically transformed to the expected format unless
+ ``transform_output_fn`` is provided.
+
+ * **Cluster driver proxy app** (recommended for interactive development).
+ One can load an LLM on a Databricks interactive cluster and start a local HTTP
+ server on the driver node to serve the model at ``/`` using HTTP POST method
+ with JSON input/output.
+ Please use a port number between ``[3000, 8000]`` and let the server listen to
+ the driver IP address or simply ``0.0.0.0`` instead of localhost only.
+ To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
+ Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
+ The expected server schema (using JSON schema) is:
+
+ * inputs::
+
+ {"type": "object",
+ "properties": {
+ "prompt": {"type": "string"},
+ "stop": {"type": "array", "items": {"type": "string"}}},
+ "required": ["prompt"]}`
+
+ * outputs: ``{"type": "string"}``
+
+ If the endpoint model signature is different or you want to set extra params,
+    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply necessary
+ transformations before and after the query.
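+
+    Example (an illustrative sketch; the endpoint name is a placeholder for an
+    existing serving endpoint in your workspace):
+        .. code-block:: python
+
+            from langchain_community.llms import Databricks
+
+            llm = Databricks(endpoint_name="my-llm-endpoint")
+            llm("Tell me a joke.")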
+ """
+
+ host: str = Field(default_factory=get_default_host)
+ """Databricks workspace hostname.
+ If not provided, the default value is determined by
+
+ * the ``DATABRICKS_HOST`` environment variable if present, or
+ * the hostname of the current Databricks workspace if running inside
+ a Databricks notebook attached to an interactive cluster in "single user"
+ or "no isolation shared" mode.
+ """
+
+ api_token: str = Field(default_factory=get_default_api_token)
+ """Databricks personal access token.
+ If not provided, the default value is determined by
+
+ * the ``DATABRICKS_TOKEN`` environment variable if present, or
+ * an automatically generated temporary token if running inside a Databricks
+ notebook attached to an interactive cluster in "single user" or
+ "no isolation shared" mode.
+ """
+
+ endpoint_name: Optional[str] = None
+ """Name of the model serving endpoint.
+ You must specify the endpoint name to connect to a model serving endpoint.
+ You must not set both ``endpoint_name`` and ``cluster_id``.
+ """
+
+ cluster_id: Optional[str] = None
+ """ID of the cluster if connecting to a cluster driver proxy app.
+    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
+ inside a Databricks notebook attached to an interactive cluster in "single user"
+ or "no isolation shared" mode, the current cluster ID is used as default.
+ You must not set both ``endpoint_name`` and ``cluster_id``.
+ """
+
+ cluster_driver_port: Optional[str] = None
+ """The port number used by the HTTP server running on the cluster driver node.
+    The server should listen on the driver IP address, or simply ``0.0.0.0``,
+    so that it is reachable.
+    We recommend using a port number between ``3000`` and ``8000``.
+ """
+
+ model_kwargs: Optional[Dict[str, Any]] = None
+ """
+ Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
+ the endpoint.
+ """
+
+ transform_input_fn: Optional[Callable] = None
+ """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
+ request object that the endpoint accepts.
+ For example, you can apply a prompt template to the input prompt.
+ """
+
+ transform_output_fn: Optional[Callable[..., str]] = None
+ """A function that transforms the output from the endpoint to the generated text.
+ """
+
+ databricks_uri: str = "databricks"
+ """The databricks URI. Only used when using a serving endpoint."""
+
+ temperature: float = 0.0
+ """The sampling temperature."""
+ n: int = 1
+ """The number of completion choices to generate."""
+ stop: Optional[List[str]] = None
+ """The stop sequence."""
+ max_tokens: Optional[int] = None
+ """The maximum number of tokens to generate."""
+ extra_params: Dict[str, Any] = Field(default_factory=dict)
+ """Any extra parameters to pass to the endpoint."""
+
+ _client: _DatabricksClientBase = PrivateAttr()
+
+ class Config:
+ extra = Extra.forbid
+ underscore_attrs_are_private = True
+
+ @property
+ def _llm_params(self) -> Dict[str, Any]:
+ params: Dict[str, Any] = {
+ "temperature": self.temperature,
+ "n": self.n,
+ }
+ if self.stop:
+ params["stop"] = self.stop
+ if self.max_tokens is not None:
+ params["max_tokens"] = self.max_tokens
+ return params
+
+ @validator("cluster_id", always=True)
+ def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
+ if v and values["endpoint_name"]:
+ raise ValueError("Cannot set both endpoint_name and cluster_id.")
+ elif values["endpoint_name"]:
+ return None
+ elif v:
+ return v
+ else:
+ try:
+ if v := get_repl_context().clusterId:
+ return v
+ raise ValueError("Context doesn't contain clusterId.")
+ except Exception as e:
+ raise ValueError(
+ "Neither endpoint_name nor cluster_id was set. "
+ "And the cluster_id cannot be automatically determined. Received"
+ f" error: {e}"
+ )
+
+ @validator("cluster_driver_port", always=True)
+ def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
+ if v and values["endpoint_name"]:
+ raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
+ elif values["endpoint_name"]:
+ return None
+ elif v is None:
+ raise ValueError(
+ "Must set cluster_driver_port to connect to a cluster driver."
+ )
+ elif int(v) <= 0:
+ raise ValueError(f"Invalid cluster_driver_port: {v}")
+ else:
+ return v
+
+ @validator("model_kwargs", always=True)
+ def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ if v:
+ assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
+ assert "stop" not in v, "model_kwargs must not contain key 'stop'"
+ return v
+
+ def __init__(self, **data: Any):
+ super().__init__(**data)
+ if self.model_kwargs is not None and self.extra_params is not None:
+ raise ValueError("Cannot set both extra_params and extra_params.")
+ elif self.model_kwargs is not None:
+ warnings.warn(
+ "model_kwargs is deprecated. Please use extra_params instead.",
+ DeprecationWarning,
+ )
+ if self.endpoint_name:
+ self._client = _DatabricksServingEndpointClient(
+ host=self.host,
+ api_token=self.api_token,
+ endpoint_name=self.endpoint_name,
+ databricks_uri=self.databricks_uri,
+ )
+ elif self.cluster_id and self.cluster_driver_port:
+ self._client = _DatabricksClusterDriverProxyClient(
+ host=self.host,
+ api_token=self.api_token,
+ cluster_id=self.cluster_id,
+ cluster_driver_port=self.cluster_driver_port,
+ )
+ else:
+ raise ValueError(
+ "Must specify either endpoint_name or cluster_id/cluster_driver_port."
+ )
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Return default params."""
+ return {
+ "host": self.host,
+ # "api_token": self.api_token, # Never save the token
+ "endpoint_name": self.endpoint_name,
+ "cluster_id": self.cluster_id,
+ "cluster_driver_port": self.cluster_driver_port,
+ "databricks_uri": self.databricks_uri,
+ "model_kwargs": self.model_kwargs,
+ "temperature": self.temperature,
+ "n": self.n,
+ "stop": self.stop,
+ "max_tokens": self.max_tokens,
+ "extra_params": self.extra_params,
+ # TODO: Support saving transform_input_fn and transform_output_fn
+ # "transform_input_fn": self.transform_input_fn,
+ # "transform_output_fn": self.transform_output_fn,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return self._default_params
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "databricks"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Queries the LLM endpoint with the given prompt and stop sequence."""
+
+ # TODO: support callbacks
+
+ request: Dict[str, Any] = {"prompt": prompt}
+ if self._client.llm:
+ request.update(self._llm_params)
+ request.update(self.model_kwargs or self.extra_params)
+ request.update(kwargs)
+ if stop:
+ request["stop"] = stop
+
+ if self.transform_input_fn:
+ request = self.transform_input_fn(**request)
+
+ return self._client.post(request, transform_output_fn=self.transform_output_fn)
diff --git a/libs/community/langchain_community/llms/deepinfra.py b/libs/community/langchain_community/llms/deepinfra.py
new file mode 100644
index 00000000000..412de97cf7e
--- /dev/null
+++ b/libs/community/langchain_community/llms/deepinfra.py
@@ -0,0 +1,219 @@
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
+
+import aiohttp
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.utilities.requests import Requests
+
+DEFAULT_MODEL_ID = "google/flan-t5-xl"
+
+
+class DeepInfra(LLM):
+ """DeepInfra models.
+
+ To use, you should have the environment variable ``DEEPINFRA_API_TOKEN``
+ set with your API token, or pass it as a named parameter to the
+ constructor.
+
+ Only supports `text-generation` and `text2text-generation` for now.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import DeepInfra
+ di = DeepInfra(model_id="google/flan-t5-xl",
+ deepinfra_api_token="my-api-key")
+ """
+
+ model_id: str = DEFAULT_MODEL_ID
+ model_kwargs: Optional[Dict] = None
+
+ deepinfra_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ deepinfra_api_token = get_from_dict_or_env(
+ values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
+ )
+ values["deepinfra_api_token"] = deepinfra_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_id": self.model_id},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "deepinfra"
+
+ def _url(self) -> str:
+ return f"https://api.deepinfra.com/v1/inference/{self.model_id}"
+
+ def _headers(self) -> Dict:
+ return {
+ "Authorization": f"bearer {self.deepinfra_api_token}",
+ "Content-Type": "application/json",
+ }
+
+ def _body(self, prompt: str, kwargs: Any) -> Dict:
+ model_kwargs = self.model_kwargs or {}
+ model_kwargs = {**model_kwargs, **kwargs}
+
+ return {
+ "input": prompt,
+ **model_kwargs,
+ }
+
+ def _handle_status(self, code: int, text: Any) -> None:
+ if code >= 500:
+ raise Exception(f"DeepInfra Server: Error {code}")
+ elif code >= 400:
+ raise ValueError(f"DeepInfra received an invalid payload: {text}")
+ elif code != 200:
+ raise Exception(
+ f"DeepInfra returned an unexpected response with status "
+ f"{code}: {text}"
+ )
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to DeepInfra's inference API endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = di("Tell me a joke.")
+ """
+
+ request = Requests(headers=self._headers())
+ response = request.post(url=self._url(), data=self._body(prompt, kwargs))
+
+ self._handle_status(response.status_code, response.text)
+ data = response.json()
+
+ return data["results"][0]["generated_text"]
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ request = Requests(headers=self._headers())
+ async with request.apost(
+ url=self._url(), data=self._body(prompt, kwargs)
+ ) as response:
+ self._handle_status(response.status, response.text)
+ data = await response.json()
+ return data["results"][0]["generated_text"]
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ request = Requests(headers=self._headers())
+ response = request.post(
+ url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
+ )
+
+ self._handle_status(response.status_code, response.text)
+ for line in _parse_stream(response.iter_lines()):
+ chunk = _handle_sse_line(line)
+ if chunk:
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ request = Requests(headers=self._headers())
+ async with request.apost(
+ url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
+ ) as response:
+ self._handle_status(response.status, response.text)
+ async for line in _parse_stream_async(response.content):
+ chunk = _handle_sse_line(line)
+ if chunk:
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text)
+
+
+def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
+ for line in rbody:
+ _line = _parse_stream_helper(line)
+ if _line is not None:
+ yield _line
+
+
+async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
+ async for line in rbody:
+ _line = _parse_stream_helper(line)
+ if _line is not None:
+ yield _line
+
+
+def _parse_stream_helper(line: bytes) -> Optional[str]:
+ if line and line.startswith(b"data:"):
+ if line.startswith(b"data: "):
+            # SSE event may be valid when it contains whitespace
+ line = line[len(b"data: ") :]
+ else:
+ line = line[len(b"data:") :]
+ if line.strip() == b"[DONE]":
+ # return here will cause GeneratorExit exception in urllib3
+ # and it will close http connection with TCP Reset
+ return None
+ else:
+ return line.decode("utf-8")
+ return None
+
+
+def _handle_sse_line(line: str) -> Optional[GenerationChunk]:
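+    # Each streamed line is expected to be a JSON object of the form
+    # {"token": {"text": "..."}}; anything that fails to parse is ignored.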
+ try:
+ obj = json.loads(line)
+ return GenerationChunk(
+ text=obj.get("token", {}).get("text"),
+ )
+ except Exception:
+ return None
diff --git a/libs/community/langchain_community/llms/deepsparse.py b/libs/community/langchain_community/llms/deepsparse.py
new file mode 100644
index 00000000000..1d8166e687c
--- /dev/null
+++ b/libs/community/langchain_community/llms/deepsparse.py
@@ -0,0 +1,232 @@
+# flake8: noqa
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_community.llms.utils import enforce_stop_tokens
+from langchain_core.outputs import GenerationChunk
+
+
+class DeepSparse(LLM):
+ """Neural Magic DeepSparse LLM interface.
+ To use, you should have the ``deepsparse`` or ``deepsparse-nightly``
+ python package installed. See https://github.com/neuralmagic/deepsparse
+    This interface lets you deploy optimized LLMs straight from the
+ [SparseZoo](https://sparsezoo.neuralmagic.com/?useCase=text_generation)
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import DeepSparse
+ llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
+ """ # noqa: E501
+
+ pipeline: Any #: :meta private:
+
+ model: str
+ """The path to a model file or directory or the name of a SparseZoo model stub."""
+
+ model_config: Optional[Dict[str, Any]] = None
+ """Keyword arguments passed to the pipeline construction.
+ Common parameters are sequence_length, prompt_sequence_length"""
+
+ generation_config: Union[None, str, Dict] = None
+ """GenerationConfig dictionary consisting of parameters used to control
+ sequences generated for each prompt. Common parameters are:
+ max_length, max_new_tokens, num_return_sequences, output_scores,
+ top_p, top_k, repetition_penalty."""
+
+ streaming: bool = False
+ """Whether to stream the results, token by token."""
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model": self.model,
+ "model_config": self.model_config,
+ "generation_config": self.generation_config,
+ "streaming": self.streaming,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "deepsparse"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that ``deepsparse`` package is installed."""
+ try:
+ from deepsparse import Pipeline
+ except ImportError:
+ raise ImportError(
+ "Could not import `deepsparse` package. "
+ "Please install it with `pip install deepsparse`"
+ )
+
+ model_config = values["model_config"] or {}
+
+ values["pipeline"] = Pipeline.create(
+ task="text_generation",
+ model_path=values["model"],
+ **model_config,
+ )
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Generate text from a prompt.
+ Args:
+ prompt: The prompt to generate text from.
+ stop: A list of strings to stop generation when encountered.
+ Returns:
+ The generated text.
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import DeepSparse
+ llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
+ llm("Tell me a joke.")
+ """
+ if self.streaming:
+ combined_output = ""
+ for chunk in self._stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ combined_output += chunk.text
+ text = combined_output
+ else:
+ text = (
+ self.pipeline(
+ sequences=prompt, generation_config=self.generation_config
+ )
+ .generations[0]
+ .text
+ )
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Generate text from a prompt.
+ Args:
+ prompt: The prompt to generate text from.
+ stop: A list of strings to stop generation when encountered.
+ Returns:
+ The generated text.
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import DeepSparse
+ llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
+ llm("Tell me a joke.")
+ """
+ if self.streaming:
+ combined_output = ""
+ async for chunk in self._astream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ combined_output += chunk.text
+ text = combined_output
+ else:
+ text = (
+ self.pipeline(
+ sequences=prompt, generation_config=self.generation_config
+ )
+ .generations[0]
+ .text
+ )
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Yields results objects as they are generated in real time.
+ It also calls the callback manager's on_llm_new_token event with
+ similar parameters to the OpenAI LLM class method of the same name.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ A generator representing the stream of tokens being generated.
+ Yields:
+ A dictionary like object containing a string token.
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import DeepSparse
+ llm = DeepSparse(
+ model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none",
+ streaming=True
+ )
+ for chunk in llm.stream("Tell me a joke",
+ stop=["'","\n"]):
+ print(chunk, end='', flush=True)
+ """
+ inference = self.pipeline(
+ sequences=prompt, generation_config=self.generation_config, streaming=True
+ )
+ for token in inference:
+ chunk = GenerationChunk(text=token.generations[0].text)
+ yield chunk
+
+ if run_manager:
+ run_manager.on_llm_new_token(token=chunk.text)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ """Yields results objects as they are generated in real time.
+ It also calls the callback manager's on_llm_new_token event with
+ similar parameters to the OpenAI LLM class method of the same name.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ A generator representing the stream of tokens being generated.
+ Yields:
+ A dictionary like object containing a string token.
+ Example:
+ .. code-block:: python
+ from langchain_community.llms import DeepSparse
+ llm = DeepSparse(
+ model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none",
+ streaming=True
+ )
+ for chunk in llm.stream("Tell me a joke",
+ stop=["'","\n"]):
+ print(chunk, end='', flush=True)
+ """
+ inference = self.pipeline(
+ sequences=prompt, generation_config=self.generation_config, streaming=True
+ )
+ for token in inference:
+ chunk = GenerationChunk(text=token.generations[0].text)
+ yield chunk
+
+ if run_manager:
+ await run_manager.on_llm_new_token(token=chunk.text)
diff --git a/libs/community/langchain_community/llms/edenai.py b/libs/community/langchain_community/llms/edenai.py
new file mode 100644
index 00000000000..4a116235a62
--- /dev/null
+++ b/libs/community/langchain_community/llms/edenai.py
@@ -0,0 +1,265 @@
+"""Wrapper around EdenAI's Generation API."""
+import logging
+from typing import Any, Dict, List, Literal, Optional
+
+from aiohttp import ClientSession
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+from langchain_community.utilities.requests import Requests
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAI(LLM):
+ """Wrapper around edenai models.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ `feature` and `subfeature` are required, but any other model parameters can also be
+ passed in with the format params={model_param: value, ...}
+
+    For an API reference, see the EdenAI documentation: http://docs.edenai.co.
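+
+    Example (an illustrative sketch; assumes ``EDENAI_API_KEY`` is set in the
+    environment):
+        .. code-block:: python
+
+            from langchain_community.llms import EdenAI
+
+            llm = EdenAI(provider="openai", temperature=0.5, max_tokens=256)
+            llm("Tell me a joke.")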
+ """
+
+ base_url: str = "https://api.edenai.run/v2"
+
+ edenai_api_key: Optional[str] = None
+
+ feature: Literal["text", "image"] = "text"
+ """Which generative feature to use, use text by default"""
+
+ subfeature: Literal["generation"] = "generation"
+ """Subfeature of above feature, use generation by default"""
+
+ provider: str
+ """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
+
+ model: Optional[str] = None
+ """
+    Model name for the above provider (e.g. 'text-davinci-003' for openai).
+    Available models are shown on https://docs.edenai.co/ under 'available providers'.
+ """
+
+ # Optional parameters to add depending of chosen feature
+ # see api reference for more infos
+ temperature: Optional[float] = Field(default=None, ge=0, le=1) # for text
+ max_tokens: Optional[int] = Field(default=None, ge=0) # for text
+ resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None # for image
+
+ params: Dict[str, Any] = Field(default_factory=dict)
+ """
+    DEPRECATED: use temperature, max_tokens, and resolution directly instead.
+    Optional parameters to pass to the API.
+ """
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """extra parameters"""
+
+ stop_sequences: Optional[List[str]] = None
+ """Stop sequences to use."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["edenai_api_key"] = get_from_dict_or_env(
+ values, "edenai_api_key", "EDENAI_API_KEY"
+ )
+ return values
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of model."""
+ return "edenai"
+
+ def _format_output(self, output: dict) -> str:
+ if self.feature == "text":
+ return output[self.provider]["generated_text"]
+ else:
+ return output[self.provider]["items"][0]["image"]
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain/{__version__}"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to EdenAI's text generation endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ json formatted str response.
+ """
+ stops = None
+ if self.stop_sequences is not None and stop is not None:
+ raise ValueError(
+ "stop sequences found in both the input and default params."
+ )
+ elif self.stop_sequences is not None:
+ stops = self.stop_sequences
+ else:
+ stops = stop
+
+ url = f"{self.base_url}/{self.feature}/{self.subfeature}"
+ headers = {
+ "Authorization": f"Bearer {self.edenai_api_key}",
+ "User-Agent": self.get_user_agent(),
+ }
+ payload: Dict[str, Any] = {
+ "providers": self.provider,
+ "text": prompt,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "resolution": self.resolution,
+ **self.params,
+ **kwargs,
+ "num_images": 1, # always limit to 1 (ignored for text)
+ }
+
+ # filter None values to not pass them to the http payload
+ payload = {k: v for k, v in payload.items() if v is not None}
+
+ if self.model is not None:
+ payload["settings"] = {self.provider: self.model}
+
+ request = Requests(headers=headers)
+ response = request.post(url=url, data=payload)
+
+ if response.status_code >= 500:
+ raise Exception(f"EdenAI Server: Error {response.status_code}")
+ elif response.status_code >= 400:
+ raise ValueError(f"EdenAI received an invalid payload: {response.text}")
+ elif response.status_code != 200:
+ raise Exception(
+ f"EdenAI returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+
+ data = response.json()
+ provider_response = data[self.provider]
+ if provider_response.get("status") == "fail":
+ err_msg = provider_response.get("error", {}).get("message")
+ raise Exception(err_msg)
+
+ output = self._format_output(data)
+
+ if stops is not None:
+ output = enforce_stop_tokens(output, stops)
+
+ return output
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call EdenAi model to get predictions based on the prompt.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of stop words (optional).
+ run_manager: A callback manager for async interaction with LLMs.
+
+ Returns:
+ The string generated by the model.
+ """
+
+ stops = None
+ if self.stop_sequences is not None and stop is not None:
+ raise ValueError(
+ "stop sequences found in both the input and default params."
+ )
+ elif self.stop_sequences is not None:
+ stops = self.stop_sequences
+ else:
+ stops = stop
+
+ url = f"{self.base_url}/{self.feature}/{self.subfeature}"
+ headers = {
+ "Authorization": f"Bearer {self.edenai_api_key}",
+ "User-Agent": self.get_user_agent(),
+ }
+ payload: Dict[str, Any] = {
+ "providers": self.provider,
+ "text": prompt,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "resolution": self.resolution,
+ **self.params,
+ **kwargs,
+ "num_images": 1, # always limit to 1 (ignored for text)
+ }
+
+ # filter `None` values to not pass them to the http payload as null
+ payload = {k: v for k, v in payload.items() if v is not None}
+
+ if self.model is not None:
+ payload["settings"] = {self.provider: self.model}
+
+ async with ClientSession() as session:
+ async with session.post(url, json=payload, headers=headers) as response:
+ if response.status >= 500:
+ raise Exception(f"EdenAI Server: Error {response.status}")
+ elif response.status >= 400:
+ raise ValueError(
+ f"EdenAI received an invalid payload: {response.text}"
+ )
+ elif response.status != 200:
+ raise Exception(
+ f"EdenAI returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+
+ response_json = await response.json()
+ provider_response = response_json[self.provider]
+ if provider_response.get("status") == "fail":
+ err_msg = provider_response.get("error", {}).get("message")
+ raise Exception(err_msg)
+
+ output = self._format_output(response_json)
+ if stops is not None:
+ output = enforce_stop_tokens(output, stops)
+
+ return output
diff --git a/libs/community/langchain_community/llms/fake.py b/libs/community/langchain_community/llms/fake.py
new file mode 100644
index 00000000000..929fd19eb24
--- /dev/null
+++ b/libs/community/langchain_community/llms/fake.py
@@ -0,0 +1,90 @@
+import asyncio
+import time
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.language_models.llms import LLM
+from langchain_core.runnables import RunnableConfig
+
+
+class FakeListLLM(LLM):
+ """Fake LLM for testing purposes."""
+
+ responses: List[str]
+ sleep: Optional[float] = None
+ i: int = 0
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "fake-list"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Return next response"""
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ return response
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Return next response"""
+ response = self.responses[self.i]
+ if self.i < len(self.responses) - 1:
+ self.i += 1
+ else:
+ self.i = 0
+ return response
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return {"responses": self.responses}
+
+
+class FakeStreamingListLLM(FakeListLLM):
+ """Fake streaming list LLM for testing purposes."""
+
+ def stream(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Iterator[str]:
+ result = self.invoke(input, config)
+ for c in result:
+ if self.sleep is not None:
+ time.sleep(self.sleep)
+ yield c
+
+ async def astream(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[str]:
+ result = await self.ainvoke(input, config)
+ for c in result:
+ if self.sleep is not None:
+ await asyncio.sleep(self.sleep)
+ yield c
diff --git a/libs/community/langchain_community/llms/fireworks.py b/libs/community/langchain_community/llms/fireworks.py
new file mode 100644
index 00000000000..c7d79066e4e
--- /dev/null
+++ b/libs/community/langchain_community/llms/fireworks.py
@@ -0,0 +1,371 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str
+from langchain_core.utils.env import get_from_dict_or_env
+
+
+def _stream_response_to_generation_chunk(
+ stream_response: Any,
+) -> GenerationChunk:
+ """Convert a stream response to a generation chunk."""
+ return GenerationChunk(
+ text=stream_response.choices[0].text,
+ generation_info=dict(
+ finish_reason=stream_response.choices[0].finish_reason,
+ logprobs=stream_response.choices[0].logprobs,
+ ),
+ )
+
+
+class Fireworks(BaseLLM):
+ """Fireworks models."""
+
+ model: str = "accounts/fireworks/models/llama-v2-7b-chat"
+ model_kwargs: dict = Field(
+ default_factory=lambda: {
+ "temperature": 0.7,
+ "max_tokens": 512,
+ "top_p": 1,
+ }.copy()
+ )
+ fireworks_api_key: Optional[SecretStr] = None
+ max_retries: int = 20
+ batch_size: int = 20
+ use_retry: bool = True
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"fireworks_api_key": "FIREWORKS_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "fireworks"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key in environment."""
+ try:
+ import fireworks.client
+ except ImportError as e:
+ raise ImportError(
+ "Could not import fireworks-ai python package. "
+ "Please install it with `pip install fireworks-ai`."
+ ) from e
+ fireworks_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
+ )
+ fireworks.client.api_key = fireworks_api_key.get_secret_value()
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "fireworks"
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call out to Fireworks endpoint with k unique prompts.
+ Args:
+ prompts: The prompts to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The full LLM output.
+ """
+ params = {
+ "model": self.model,
+ **self.model_kwargs,
+ }
+ sub_prompts = self.get_batch_prompts(prompts)
+ choices = []
+ for _prompts in sub_prompts:
+ response = completion_with_retry_batching(
+ self,
+ self.use_retry,
+ prompt=_prompts,
+ run_manager=run_manager,
+ stop=stop,
+ **params,
+ )
+ choices.extend(response)
+
+ return self.create_llm_result(choices, prompts)
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call out to Fireworks endpoint async with k unique prompts."""
+ params = {
+ "model": self.model,
+ **self.model_kwargs,
+ }
+ sub_prompts = self.get_batch_prompts(prompts)
+ choices = []
+ for _prompts in sub_prompts:
+ response = await acompletion_with_retry_batching(
+ self,
+ self.use_retry,
+ prompt=_prompts,
+ run_manager=run_manager,
+ stop=stop,
+ **params,
+ )
+ choices.extend(response)
+
+ return self.create_llm_result(choices, prompts)
+
+ def get_batch_prompts(
+ self,
+ prompts: List[str],
+ ) -> List[List[str]]:
+ """Get the sub prompts for llm call."""
+ sub_prompts = [
+ prompts[i : i + self.batch_size]
+ for i in range(0, len(prompts), self.batch_size)
+ ]
+ return sub_prompts
+
+ def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
+ """Create the LLMResult from the choices and prompts."""
+ generations = []
+ for i, _ in enumerate(prompts):
+ sub_choices = choices[i : (i + 1)]
+ generations.append(
+ [
+ Generation(
+ text=choice.__dict__["choices"][0].text,
+ )
+ for choice in sub_choices
+ ]
+ )
+ llm_output = {"model": self.model}
+ return LLMResult(generations=generations, llm_output=llm_output)
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = {
+ "model": self.model,
+ "prompt": prompt,
+ "stream": True,
+ **self.model_kwargs,
+ }
+ for stream_resp in completion_with_retry(
+ self, self.use_retry, run_manager=run_manager, stop=stop, **params
+ ):
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ params = {
+ "model": self.model,
+ "prompt": prompt,
+ "stream": True,
+ **self.model_kwargs,
+ }
+ async for stream_resp in await acompletion_with_retry_streaming(
+ self, self.use_retry, run_manager=run_manager, stop=stop, **params
+ ):
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+
+def conditional_decorator(
+ condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
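+    """Apply ``decorator`` to the wrapped function only when ``condition`` is
+    True; otherwise return the function unchanged."""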
+ def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+ if condition:
+ return decorator(func)
+ return func
+
+ return actual_decorator
+
+
+def completion_with_retry(
+ llm: Fireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return fireworks.client.Completion.create(
+ **kwargs,
+ )
+
+ return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(
+ llm: Fireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ return await fireworks.client.Completion.acreate(
+ **kwargs,
+ )
+
+ return await _completion_with_retry(**kwargs)
+
+
+def completion_with_retry_batching(
+ llm: Fireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ import fireworks.client
+
+ prompt = kwargs["prompt"]
+ del kwargs["prompt"]
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ def _completion_with_retry(prompt: str) -> Any:
+ return fireworks.client.Completion.create(**kwargs, prompt=prompt)
+
+ def batch_sync_run() -> List:
+ with ThreadPoolExecutor() as executor:
+ results = list(executor.map(_completion_with_retry, prompt))
+ return results
+
+ return batch_sync_run()
+
+
+async def acompletion_with_retry_batching(
+ llm: Fireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ import fireworks.client
+
+ prompt = kwargs["prompt"]
+ del kwargs["prompt"]
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ async def _completion_with_retry(prompt: str) -> Any:
+ return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)
+
+ def run_coroutine_in_new_loop(
+ coroutine_func: Any, *args: Dict, **kwargs: Dict
+ ) -> Any:
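+        """Run ``coroutine_func`` to completion on a fresh event loop.
+
+        Executor threads have no running event loop, so a new one is
+        created (and closed) per call.
+        """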
+ new_loop = asyncio.new_event_loop()
+ try:
+ asyncio.set_event_loop(new_loop)
+ return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
+ finally:
+ new_loop.close()
+
+ async def batch_sync_run() -> List:
+ with ThreadPoolExecutor() as executor:
+ results = list(
+ executor.map(
+ run_coroutine_in_new_loop,
+ [_completion_with_retry] * len(prompt),
+ prompt,
+ )
+ )
+ return results
+
+ return await batch_sync_run()
+
+
+async def acompletion_with_retry_streaming(
+ llm: Fireworks,
+ use_retry: bool,
+ *,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call for streaming."""
+ import fireworks.client
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @conditional_decorator(use_retry, retry_decorator)
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ return fireworks.client.Completion.acreate(
+ **kwargs,
+ )
+
+ return await _completion_with_retry(**kwargs)
+
+
+def _create_retry_decorator(
+ llm: Fireworks,
+ *,
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ """Define retry mechanism."""
+ import fireworks.client
+
+ errors = [
+ fireworks.client.error.RateLimitError,
+ fireworks.client.error.InternalServerError,
+ fireworks.client.error.BadGatewayError,
+ fireworks.client.error.ServiceUnavailableError,
+ ]
+ return create_base_retry_decorator(
+ error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+ )
diff --git a/libs/community/langchain_community/llms/forefrontai.py b/libs/community/langchain_community/llms/forefrontai.py
new file mode 100644
index 00000000000..b4220ad9a92
--- /dev/null
+++ b/libs/community/langchain_community/llms/forefrontai.py
@@ -0,0 +1,119 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class ForefrontAI(LLM):
+ """ForefrontAI large language models.
+
+ To use, you should have the environment variable ``FOREFRONTAI_API_KEY``
+ set with your API key.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import ForefrontAI
+ forefrontai = ForefrontAI(endpoint_url="")
+ """
+
+ endpoint_url: str = ""
+ """Model name to use."""
+
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+
+ length: int = 256
+ """The maximum number of tokens to generate in the completion."""
+
+ top_p: float = 1.0
+ """Total probability mass of tokens to consider at each step."""
+
+ top_k: int = 40
+ """The number of highest probability vocabulary tokens to
+ keep for top-k-filtering."""
+
+ repetition_penalty: int = 1
+ """Penalizes repeated tokens according to frequency."""
+
+ forefrontai_api_key: SecretStr
+
+ base_url: Optional[str] = None
+ """Base url to use, if None decides based on model name."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["forefrontai_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "forefrontai_api_key", "FOREFRONTAI_API_KEY")
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling ForefrontAI API."""
+ return {
+ "temperature": self.temperature,
+ "length": self.length,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "repetition_penalty": self.repetition_penalty,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"endpoint_url": self.endpoint_url}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "forefrontai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to ForefrontAI's complete endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = ForefrontAI("Tell me a joke.")
+ """
+ auth_value = f"Bearer {self.forefrontai_api_key.get_secret_value()}"
+ response = requests.post(
+ url=self.endpoint_url,
+ headers={
+ "Authorization": auth_value,
+ "Content-Type": "application/json",
+ },
+ json={"text": prompt, **self._default_params, **kwargs},
+ )
+ response_json = response.json()
+ text = response_json["result"][0]["completion"]
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/gigachat.py b/libs/community/langchain_community/llms/gigachat.py
new file mode 100644
index 00000000000..61f0893980a
--- /dev/null
+++ b/libs/community/langchain_community/llms/gigachat.py
@@ -0,0 +1,259 @@
+from __future__ import annotations
+
+import logging
+from functools import cached_property
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.load.serializable import Serializable
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class _BaseGigaChat(Serializable):
+ base_url: Optional[str] = None
+ """ Base API URL """
+ auth_url: Optional[str] = None
+ """ Auth URL """
+ credentials: Optional[str] = None
+ """ Auth Token """
+ scope: Optional[str] = None
+ """ Permission scope for access token """
+
+ access_token: Optional[str] = None
+ """ Access token for GigaChat """
+
+ model: Optional[str] = None
+ """Model name to use."""
+ user: Optional[str] = None
+ """ Username for authenticate """
+ password: Optional[str] = None
+ """ Password for authenticate """
+
+ timeout: Optional[float] = None
+ """ Timeout for request """
+ verify_ssl_certs: Optional[bool] = None
+ """ Check certificates for all requests """
+
+ ca_bundle_file: Optional[str] = None
+ cert_file: Optional[str] = None
+ key_file: Optional[str] = None
+ key_file_password: Optional[str] = None
+ # Support for connection to GigaChat through SSL certificates
+
+ profanity: bool = True
+ """ Check for profanity """
+ streaming: bool = False
+ """ Whether to stream the results or not. """
+ temperature: Optional[float] = None
+ """What sampling temperature to use."""
+ max_tokens: Optional[int] = None
+ """ Maximum number of tokens to generate """
+
+ @property
+ def _llm_type(self) -> str:
+ return "giga-chat-model"
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {
+ "credentials": "GIGACHAT_CREDENTIALS",
+ "access_token": "GIGACHAT_ACCESS_TOKEN",
+ "password": "GIGACHAT_PASSWORD",
+ "key_file_password": "GIGACHAT_KEY_FILE_PASSWORD",
+ }
+
+ @property
+ def lc_serializable(self) -> bool:
+ return True
+
+ @cached_property
+ def _client(self) -> Any:
+ """Returns GigaChat API client"""
+ import gigachat
+
+ return gigachat.GigaChat(
+ base_url=self.base_url,
+ auth_url=self.auth_url,
+ credentials=self.credentials,
+ scope=self.scope,
+ access_token=self.access_token,
+ model=self.model,
+ user=self.user,
+ password=self.password,
+ timeout=self.timeout,
+ verify_ssl_certs=self.verify_ssl_certs,
+ ca_bundle_file=self.ca_bundle_file,
+ cert_file=self.cert_file,
+ key_file=self.key_file,
+ key_file_password=self.key_file_password,
+ )
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate authenticate data in environment and python package is installed."""
+ try:
+ import gigachat # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import gigachat python package. "
+ "Please install it with `pip install gigachat`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "temperature": self.temperature,
+ "model": self.model,
+ "profanity": self.profanity,
+ "streaming": self.streaming,
+ "max_tokens": self.max_tokens,
+ }
+
+
+class GigaChat(_BaseGigaChat, BaseLLM):
+ """`GigaChat` large language models API.
+
+ To use, you should pass login and password to access GigaChat API or use token.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import GigaChat
+ giga = GigaChat(credentials=..., verify_ssl_certs=False)
+ """
+
+ def _build_payload(self, messages: List[str]) -> Dict[str, Any]:
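+        """Build the GigaChat chat payload, one user message per prompt."""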
+ payload: Dict[str, Any] = {
+ "messages": [{"role": "user", "content": m} for m in messages],
+ "profanity_check": self.profanity,
+ }
+ if self.temperature is not None:
+ payload["temperature"] = self.temperature
+ if self.max_tokens is not None:
+ payload["max_tokens"] = self.max_tokens
+ if self.model:
+ payload["model"] = self.model
+
+ if self.verbose:
+ logger.info("Giga request: %s", payload)
+
+ return payload
+
+ def _create_llm_result(self, response: Any) -> LLMResult:
+ generations = []
+ for res in response.choices:
+ finish_reason = res.finish_reason
+ gen = Generation(
+ text=res.message.content,
+ generation_info={"finish_reason": finish_reason},
+ )
+ generations.append([gen])
+ if finish_reason != "stop":
+ logger.warning(
+ "Giga generation stopped with reason: %s",
+ finish_reason,
+ )
+ if self.verbose:
+ logger.info("Giga response: %s", res.message.content)
+ token_usage = response.usage
+ llm_output = {"token_usage": token_usage, "model_name": response.model}
+ return LLMResult(generations=generations, llm_output=llm_output)
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ generation: Optional[GenerationChunk] = None
+ stream_iter = self._stream(
+ prompts[0], stop=stop, run_manager=run_manager, **kwargs
+ )
+ for chunk in stream_iter:
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ return LLMResult(generations=[[generation]])
+
+ payload = self._build_payload(prompts)
+ response = self._client.chat(payload)
+
+ return self._create_llm_result(response)
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ generation: Optional[GenerationChunk] = None
+ stream_iter = self._astream(
+ prompts[0], stop=stop, run_manager=run_manager, **kwargs
+ )
+ async for chunk in stream_iter:
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ return LLMResult(generations=[[generation]])
+
+ payload = self._build_payload(prompts)
+ response = await self._client.achat(payload)
+
+ return self._create_llm_result(response)
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ payload = self._build_payload([prompt])
+
+ for chunk in self._client.stream(payload):
+ if chunk.choices:
+ content = chunk.choices[0].delta.content
+ yield GenerationChunk(text=content)
+ if run_manager:
+ run_manager.on_llm_new_token(content)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ payload = self._build_payload([prompt])
+
+ async for chunk in self._client.astream(payload):
+ if chunk.choices:
+ content = chunk.choices[0].delta.content
+ yield GenerationChunk(text=content)
+ if run_manager:
+ await run_manager.on_llm_new_token(content)
+
+ def get_num_tokens(self, text: str) -> int:
+ """Count approximate number of tokens"""
+ return round(len(text) / 4.6)
diff --git a/libs/community/langchain_community/llms/google_palm.py b/libs/community/langchain_community/llms/google_palm.py
new file mode 100644
index 00000000000..8b0b7b0a574
--- /dev/null
+++ b/libs/community/langchain_community/llms/google_palm.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms import BaseLLM
+from langchain_community.utilities.vertexai import create_retry_decorator
+
+
+def completion_with_retry(
+ llm: GooglePalm,
+ *args: Any,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = create_retry_decorator(
+ llm, max_retries=llm.max_retries, run_manager=run_manager
+ )
+
+ @retry_decorator
+ def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return llm.client.generate_text(*args, **kwargs)
+
+ return _completion_with_retry(*args, **kwargs)
+
+
+def _strip_erroneous_leading_spaces(text: str) -> str:
+ """Strip erroneous leading spaces from text.
+
+ The PaLM API will sometimes erroneously return a single leading space in all
+ lines > 1. This function strips that space.
+ """
+ has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
+ if has_leading_space:
+ return text.replace("\n ", "\n")
+ else:
+ return text
+
+
+class GooglePalm(BaseLLM, BaseModel):
+ """Google PaLM models."""
+
+ client: Any #: :meta private:
+ google_api_key: Optional[str]
+ model_name: str = "models/text-bison-001"
+ """Model name to use."""
+ temperature: float = 0.7
+ """Run inference with this temperature. Must by in the closed interval
+ [0.0, 1.0]."""
+ top_p: Optional[float] = None
+ """Decode using nucleus sampling: consider the smallest set of tokens whose
+ probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+ top_k: Optional[int] = None
+ """Decode using top-k sampling: consider the set of top_k most probable tokens.
+ Must be positive."""
+ max_output_tokens: Optional[int] = None
+ """Maximum number of tokens to include in a candidate. Must be greater than zero.
+ If unset, will default to 64."""
+ n: int = 1
+ """Number of chat completions to generate for each prompt. Note that the API may
+ not return the full n completions if duplicates are generated."""
+ max_retries: int = 6
+ """The maximum number of retries to make when generating."""
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"google_api_key": "GOOGLE_API_KEY"}
+
+ @classmethod
+    def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "google_palm"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate api key, python package exists."""
+ google_api_key = get_from_dict_or_env(
+ values, "google_api_key", "GOOGLE_API_KEY"
+ )
+ try:
+ import google.generativeai as genai
+
+ genai.configure(api_key=google_api_key)
+ except ImportError:
+ raise ImportError(
+ "Could not import google-generativeai python package. "
+ "Please install it with `pip install google-generativeai`."
+ )
+
+ values["client"] = genai
+
+ if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+ raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+ if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+ raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+ if values["top_k"] is not None and values["top_k"] <= 0:
+ raise ValueError("top_k must be positive")
+
+ if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
+ raise ValueError("max_output_tokens must be greater than zero")
+
+ return values
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ generations = []
+ for prompt in prompts:
+ completion = completion_with_retry(
+ self,
+ model=self.model_name,
+ prompt=prompt,
+ stop_sequences=stop,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ max_output_tokens=self.max_output_tokens,
+ candidate_count=self.n,
+ **kwargs,
+ )
+
+ prompt_generations = []
+ for candidate in completion.candidates:
+ raw_text = candidate["output"]
+ stripped_text = _strip_erroneous_leading_spaces(raw_text)
+ prompt_generations.append(Generation(text=stripped_text))
+ generations.append(prompt_generations)
+
+ return LLMResult(generations=generations)
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "google_palm"
+
+ def get_num_tokens(self, text: str) -> int:
+ """Get the number of tokens present in the text.
+
+ Useful for checking if an input will fit in a model's context window.
+
+ Args:
+ text: The string input to tokenize.
+
+ Returns:
+ The integer number of tokens in the text.
+ """
+ result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+ return result["token_count"]
diff --git a/libs/community/langchain_community/llms/gooseai.py b/libs/community/langchain_community/llms/gooseai.py
new file mode 100644
index 00000000000..27ff257ab63
--- /dev/null
+++ b/libs/community/langchain_community/llms/gooseai.py
@@ -0,0 +1,152 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class GooseAI(LLM):
+ """GooseAI large language models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``GOOSEAI_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import GooseAI
+ gooseai = GooseAI(model_name="gpt-neo-20b")
+
+ """
+
+ client: Any
+
+ model_name: str = "gpt-neo-20b"
+ """Model name to use"""
+
+ temperature: float = 0.7
+ """What sampling temperature to use"""
+
+ max_tokens: int = 256
+ """The maximum number of tokens to generate in the completion.
+ -1 returns as many tokens as possible given the prompt and
+    the model's maximal context size."""
+
+ top_p: float = 1
+ """Total probability mass of tokens to consider at each step."""
+
+ min_tokens: int = 1
+ """The minimum number of tokens to generate in the completion."""
+
+ frequency_penalty: float = 0
+ """Penalizes repeated tokens according to frequency."""
+
+ presence_penalty: float = 0
+ """Penalizes repeated tokens."""
+
+ n: int = 1
+ """How many completions to generate for each prompt."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+
+ logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
+ """Adjust the probability of specific tokens being generated."""
+
+ gooseai_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic config."""
+
+ extra = Extra.ignore
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ gooseai_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY")
+ )
+ values["gooseai_api_key"] = gooseai_api_key
+ try:
+ import openai
+
+ openai.api_key = gooseai_api_key.get_secret_value()
+ openai.api_base = "https://api.goose.ai/v1"
+ values["client"] = openai.Completion
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling GooseAI API."""
+ normal_params = {
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "top_p": self.top_p,
+ "min_tokens": self.min_tokens,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "n": self.n,
+ "logit_bias": self.logit_bias,
+ }
+ return {**normal_params, **self.model_kwargs}
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "gooseai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the GooseAI API."""
+ params = self._default_params
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+
+ params = {**params, **kwargs}
+
+ response = self.client.create(engine=self.model_name, prompt=prompt, **params)
+ text = response.choices[0].text
+ return text
diff --git a/libs/community/langchain_community/llms/gpt4all.py b/libs/community/langchain_community/llms/gpt4all.py
new file mode 100644
index 00000000000..83dace226bb
--- /dev/null
+++ b/libs/community/langchain_community/llms/gpt4all.py
@@ -0,0 +1,211 @@
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional, Set
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class GPT4All(LLM):
+ """GPT4All language models.
+
+ To use, you should have the ``gpt4all`` python package installed, the
+ pre-trained model file, and the model's config information.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import GPT4All
+ model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)
+
+ # Simplest invocation
+ response = model("Once upon a time, ")
+ """
+
+ model: str
+ """Path to the pre-trained GPT4All model file."""
+
+ backend: Optional[str] = Field(None, alias="backend")
+
+ max_tokens: int = Field(200, alias="max_tokens")
+ """Token context window."""
+
+ n_parts: int = Field(-1, alias="n_parts")
+ """Number of parts to split the model into.
+ If -1, the number of parts is automatically determined."""
+
+ seed: int = Field(0, alias="seed")
+ """Seed. If -1, a random seed is used."""
+
+ f16_kv: bool = Field(False, alias="f16_kv")
+ """Use half-precision for key/value cache."""
+
+ logits_all: bool = Field(False, alias="logits_all")
+ """Return logits for all tokens, not just the last token."""
+
+ vocab_only: bool = Field(False, alias="vocab_only")
+ """Only load the vocabulary, no weights."""
+
+ use_mlock: bool = Field(False, alias="use_mlock")
+ """Force system to keep model in RAM."""
+
+ embedding: bool = Field(False, alias="embedding")
+ """Use embedding mode only."""
+
+ n_threads: Optional[int] = Field(4, alias="n_threads")
+ """Number of threads to use."""
+
+ n_predict: Optional[int] = 256
+ """The maximum number of tokens to generate."""
+
+ temp: Optional[float] = 0.7
+ """The temperature to use for sampling."""
+
+ top_p: Optional[float] = 0.1
+ """The top-p value to use for sampling."""
+
+ top_k: Optional[int] = 40
+ """The top-k value to use for sampling."""
+
+ echo: Optional[bool] = False
+ """Whether to echo the prompt."""
+
+ stop: Optional[List[str]] = []
+ """A list of strings to stop generation when encountered."""
+
+ repeat_last_n: Optional[int] = 64
+ "Last n tokens to penalize"
+
+ repeat_penalty: Optional[float] = 1.18
+ """The penalty to apply to repeated tokens."""
+
+ n_batch: int = Field(8, alias="n_batch")
+ """Batch size for prompt processing."""
+
+ streaming: bool = False
+ """Whether to stream the results or not."""
+
+ allow_download: bool = False
+ """If model does not exist in ~/.cache/gpt4all/, download it."""
+
+ device: Optional[str] = Field("cpu", alias="device")
+ """Device name: cpu, gpu, nvidia, intel, amd or DeviceName."""
+
+ client: Any = None #: :meta private:
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @staticmethod
+ def _model_param_names() -> Set[str]:
+ return {
+ "max_tokens",
+ "n_predict",
+ "top_k",
+ "top_p",
+ "temp",
+ "n_batch",
+ "repeat_penalty",
+ "repeat_last_n",
+ }
+
+ def _default_params(self) -> Dict[str, Any]:
+ return {
+ "max_tokens": self.max_tokens,
+ "n_predict": self.n_predict,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "temp": self.temp,
+ "n_batch": self.n_batch,
+ "repeat_penalty": self.repeat_penalty,
+ "repeat_last_n": self.repeat_last_n,
+ }
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in the environment."""
+ try:
+ from gpt4all import GPT4All as GPT4AllModel
+ except ImportError:
+ raise ImportError(
+ "Could not import gpt4all python package. "
+ "Please install it with `pip install gpt4all`."
+ )
+
+ full_path = values["model"]
+ model_path, delimiter, model_name = full_path.rpartition("/")
+ model_path += delimiter
+
+ values["client"] = GPT4AllModel(
+ model_name,
+ model_path=model_path or None,
+ model_type=values["backend"],
+ allow_download=values["allow_download"],
+ device=values["device"],
+ )
+ if values["n_threads"] is not None:
+ # set n_threads
+ values["client"].model.set_thread_count(values["n_threads"])
+
+ try:
+ values["backend"] = values["client"].model_type
+ except AttributeError:
+ # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
+ values["backend"] = values["client"].model.model_type
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model": self.model,
+ **self._default_params(),
+ **{
+ k: v for k, v in self.__dict__.items() if k in self._model_param_names()
+ },
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return the type of llm."""
+ return "gpt4all"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ r"""Call out to GPT4All's generate method.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "Once upon a time, "
+ response = model(prompt, n_predict=55)
+ """
+ text_callback = None
+ if run_manager:
+ text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
+ text = ""
+ params = {**self._default_params(), **kwargs}
+ for token in self.client.generate(prompt, **params):
+ if text_callback:
+ text_callback(token)
+ text += token
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/gradient_ai.py b/libs/community/langchain_community/llms/gradient_ai.py
new file mode 100644
index 00000000000..23ffc1a193f
--- /dev/null
+++ b/libs/community/langchain_community/llms/gradient_ai.py
@@ -0,0 +1,402 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class TrainResult(TypedDict):
+ """Train result."""
+
+ loss: float
+
+
+class GradientLLM(BaseLLM):
+ """Gradient.ai LLM Endpoints.
+
+ GradientLLM is a class to interact with LLMs on gradient.ai
+
+ To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
+ API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
+ or alternatively provide them as keywords to the constructor of this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import GradientLLM
+ GradientLLM(
+ model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
+ model_kwargs={
+ "max_generated_token_count": 128,
+ "temperature": 0.75,
+ "top_p": 0.95,
+ "top_k": 20,
+ "stop": [],
+ },
+ gradient_workspace_id="12345614fc0_workspace",
+ gradient_access_token="gradientai-access_token",
+ )
+
+ """
+
+ model_id: str = Field(alias="model", min_length=2)
+ "Underlying gradient.ai model id (base or fine-tuned)."
+
+ gradient_workspace_id: Optional[str] = None
+ "Underlying gradient.ai workspace_id."
+
+ gradient_access_token: Optional[str] = None
+ """gradient.ai API Token, which can be generated by going to
+ https://auth.gradient.ai/select-workspace
+ and selecting "Access tokens" under the profile drop-down.
+ """
+
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ gradient_api_url: str = "https://api.gradient.ai/api"
+ """Endpoint URL to use."""
+
+ aiosession: Optional[aiohttp.ClientSession] = None #: :meta private:
+ """ClientSession, private, subject to change in upcoming releases."""
+
+ # LLM call kwargs
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+ extra = Extra.forbid
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+
+ values["gradient_access_token"] = get_from_dict_or_env(
+ values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
+ )
+ values["gradient_workspace_id"] = get_from_dict_or_env(
+ values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
+ )
+
+ if (
+ values["gradient_access_token"] is None
+ or len(values["gradient_access_token"]) < 10
+ ):
+ raise ValueError("env variable `GRADIENT_ACCESS_TOKEN` must be set")
+
+ if (
+ values["gradient_workspace_id"] is None
+ or len(values["gradient_access_token"]) < 3
+ ):
+ raise ValueError("env variable `GRADIENT_WORKSPACE_ID` must be set")
+
+ if values["model_kwargs"]:
+ kw = values["model_kwargs"]
+ if not 0 <= kw.get("temperature", 0.5) <= 1:
+ raise ValueError("`temperature` must be in the range [0.0, 1.0]")
+
+ if not 0 <= kw.get("top_p", 0.5) <= 1:
+ raise ValueError("`top_p` must be in the range [0.0, 1.0]")
+
+ if 0 >= kw.get("top_k", 0.5):
+ raise ValueError("`top_k` must be positive")
+
+ if 0 >= kw.get("max_generated_token_count", 1):
+ raise ValueError("`max_generated_token_count` must be positive")
+
+ values["gradient_api_url"] = get_from_dict_or_env(
+ values, "gradient_api_url", "GRADIENT_API_URL"
+ )
+
+ try:
+ import gradientai # noqa
+ except ImportError:
+ logging.warning(
+ "DeprecationWarning: `GradientLLM` will use "
+ "`pip install gradientai` in future releases of langchain."
+ )
+ except Exception:
+ pass
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"gradient_api_url": self.gradient_api_url},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "gradient"
+
+ def _kwargs_post_fine_tune_request(
+ self, inputs: Sequence[str], kwargs: Mapping[str, Any]
+ ) -> Mapping[str, Any]:
+ """Build the kwargs for the Post request, used by sync
+
+ Args:
+ prompt (str): prompt used in query
+ kwargs (dict): model kwargs in payload
+
+ Returns:
+ Dict[str, Union[str,dict]]: _description_
+ """
+ _model_kwargs = self.model_kwargs or {}
+ _params = {**_model_kwargs, **kwargs}
+
+ multipliers = _params.get("multipliers", None)
+
+ return dict(
+ url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
+ headers={
+ "authorization": f"Bearer {self.gradient_access_token}",
+ "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
+ "accept": "application/json",
+ "content-type": "application/json",
+ },
+ json=dict(
+ samples=tuple(
+ {
+ "inputs": input,
+ }
+ for input in inputs
+ )
+ if multipliers is None
+ else tuple(
+ {
+ "inputs": input,
+ "fineTuningParameters": {
+ "multiplier": multiplier,
+ },
+ }
+ for input, multiplier in zip(inputs, multipliers)
+ ),
+ ),
+ )
+
+ def _kwargs_post_request(
+ self, prompt: str, kwargs: Mapping[str, Any]
+ ) -> Mapping[str, Any]:
+ """Build the kwargs for the Post request, used by sync
+
+ Args:
+ prompt (str): prompt used in query
+ kwargs (dict): model kwargs in payload
+
+ Returns:
+ Dict[str, Union[str,dict]]: _description_
+ """
+ _model_kwargs = self.model_kwargs or {}
+ _params = {**_model_kwargs, **kwargs}
+
+ return dict(
+ url=f"{self.gradient_api_url}/models/{self.model_id}/complete",
+ headers={
+ "authorization": f"Bearer {self.gradient_access_token}",
+ "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
+ "accept": "application/json",
+ "content-type": "application/json",
+ },
+ json=dict(
+ query=prompt,
+ maxGeneratedTokenCount=_params.get("max_generated_token_count", None),
+ temperature=_params.get("temperature", None),
+ topK=_params.get("top_k", None),
+ topP=_params.get("top_p", None),
+ ),
+ )
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Gradients API `model/{id}/complete`.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+ """
+ try:
+ response = requests.post(**self._kwargs_post_request(prompt, kwargs))
+ if response.status_code != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+ except requests.exceptions.RequestException as e:
+ raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
+
+ text = response.json()["generatedOutput"]
+
+ if stop is not None:
+ # Apply stop tokens when making calls to Gradient
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Async Call to Gradients API `model/{id}/complete`.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+ """
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ **self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
+ ) as response:
+ if response.status != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+ text = (await response.json())["generatedOutput"]
+ else:
+ async with self.aiosession.post(
+ **self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
+ ) as response:
+ if response.status != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+ text = (await response.json())["generatedOutput"]
+
+ if stop is not None:
+ # Apply stop tokens when making calls to Gradient
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+
+        # run prompts concurrently with a thread pool when there is more than one
+ def _inner_generate(prompt: str) -> List[Generation]:
+ return [
+ Generation(
+ text=self._call(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ )
+ )
+ ]
+
+ if len(prompts) <= 1:
+ generations = list(map(_inner_generate, prompts))
+ else:
+ with ThreadPoolExecutor(min(8, len(prompts))) as p:
+ generations = list(p.map(_inner_generate, prompts))
+
+ return LLMResult(generations=generations)
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+        texts = await asyncio.gather(
+            *(
+                self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
+                for prompt in prompts
+            )
+        )
+        generations = [[Generation(text=text)] for text in texts]
+        return LLMResult(generations=generations)
+
+ def train_unsupervised(
+ self,
+ inputs: Sequence[str],
+ **kwargs: Any,
+ ) -> TrainResult:
+ try:
+ response = requests.post(
+ **self._kwargs_post_fine_tune_request(inputs, kwargs)
+ )
+ if response.status_code != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+ except requests.exceptions.RequestException as e:
+ raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
+
+ response_json = response.json()
+ loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+ return TrainResult(loss=loss)
+
+ async def atrain_unsupervised(
+ self,
+ inputs: Sequence[str],
+ **kwargs: Any,
+ ) -> TrainResult:
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ **self._kwargs_post_fine_tune_request(inputs, kwargs)
+ ) as response:
+ if response.status != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+ response_json = await response.json()
+ loss = (
+ response_json["sumLoss"]
+ / response_json["numberOfTrainableTokens"]
+ )
+ else:
+ async with self.aiosession.post(
+ **self._kwargs_post_fine_tune_request(inputs, kwargs)
+ ) as response:
+ if response.status != 200:
+ raise Exception(
+ f"Gradient returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+ response_json = await response.json()
+ loss = (
+ response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+ )
+
+ return TrainResult(loss=loss)
diff --git a/libs/community/langchain_community/llms/grammars/json.gbnf b/libs/community/langchain_community/llms/grammars/json.gbnf
new file mode 100644
index 00000000000..61bd2b2e65b
--- /dev/null
+++ b/libs/community/langchain_community/llms/grammars/json.gbnf
@@ -0,0 +1,29 @@
+# Grammar for subset of JSON - doesn't support full string or number syntax
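+# For example, it accepts: {"name": "value", "items": [1, 2, true, null]}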
+
+root ::= object
+value ::= object | array | string | number | boolean | "null"
+
+object ::=
+ "{" ws (
+ string ":" ws value
+ ("," ws string ":" ws value)*
+ )? "}"
+
+array ::=
+ "[" ws (
+ value
+ ("," ws value)*
+ )? "]"
+
+string ::=
+ "\"" (
+ [^"\\] |
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+ )* "\"" ws
+
+# Only plain integers currently
+number ::= "-"? [0-9]+ ws
+boolean ::= ("true" | "false") ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
\ No newline at end of file
diff --git a/libs/community/langchain_community/llms/grammars/list.gbnf b/libs/community/langchain_community/llms/grammars/list.gbnf
new file mode 100644
index 00000000000..30ea6e0c849
--- /dev/null
+++ b/libs/community/langchain_community/llms/grammars/list.gbnf
@@ -0,0 +1,14 @@
+root ::= "[" items "]" EOF
+
+items ::= item ("," ws* item)*
+
+item ::= string
+
+string ::=
+ "\"" word (ws+ word)* "\"" ws*
+
+word ::= [a-zA-Z]+
+
+ws ::= " "
+
+EOF ::= "\n"
\ No newline at end of file
diff --git a/libs/community/langchain_community/llms/huggingface_endpoint.py b/libs/community/langchain_community/llms/huggingface_endpoint.py
new file mode 100644
index 00000000000..d429e2fd935
--- /dev/null
+++ b/libs/community/langchain_community/llms/huggingface_endpoint.py
@@ -0,0 +1,156 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+
+
+class HuggingFaceEndpoint(LLM):
+ """HuggingFace Endpoint models.
+
+ To use, you should have the ``huggingface_hub`` python package installed, and the
+ environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import HuggingFaceEndpoint
+ endpoint_url = (
+ "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
+ )
+ hf = HuggingFaceEndpoint(
+ endpoint_url=endpoint_url,
+ huggingfacehub_api_token="my-api-key"
+ )
+ """
+
+ endpoint_url: str = ""
+ """Endpoint URL to use."""
+ task: Optional[str] = None
+ """Task to call the model with.
+ Should be a task that returns `generated_text` or `summary_text`."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ huggingfacehub_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ huggingfacehub_api_token = get_from_dict_or_env(
+ values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+ )
+ try:
+ from huggingface_hub.hf_api import HfApi
+
+ try:
+ HfApi(
+ endpoint="https://huggingface.co", # Can be a Private Hub endpoint.
+ token=huggingfacehub_api_token,
+ ).whoami()
+ except Exception as e:
+ raise ValueError(
+ "Could not authenticate with huggingface_hub. "
+ "Please check your API token."
+ ) from e
+
+ except ImportError:
+ raise ImportError(
+ "Could not import huggingface_hub python package. "
+ "Please install it with `pip install huggingface_hub`."
+ )
+ values["huggingfacehub_api_token"] = huggingfacehub_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"endpoint_url": self.endpoint_url, "task": self.task},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "huggingface_endpoint"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to HuggingFace Hub's inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = hf("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+
+        # build the inference payload from model kwargs and call kwargs
+ params = {**_model_kwargs, **kwargs}
+ parameter_payload = {"inputs": prompt, "parameters": params}
+
+ # HTTP headers for authorization
+ headers = {
+ "Authorization": f"Bearer {self.huggingfacehub_api_token}",
+ "Content-Type": "application/json",
+ }
+
+ # send request
+ try:
+ response = requests.post(
+ self.endpoint_url, headers=headers, json=parameter_payload
+ )
+ except requests.exceptions.RequestException as e: # This is the correct syntax
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+ generated_text = response.json()
+ if "error" in generated_text:
+ raise ValueError(
+ f"Error raised by inference API: {generated_text['error']}"
+ )
+ if self.task == "text-generation":
+ text = generated_text[0]["generated_text"]
+ # Remove prompt if included in generated text.
+ if text.startswith(prompt):
+ text = text[len(prompt) :]
+ elif self.task == "text2text-generation":
+ text = generated_text[0]["generated_text"]
+ elif self.task == "summarization":
+ text = generated_text[0]["summary_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {self.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ if stop is not None:
+            # Stop sequences are not applied server-side by the inference
+            # endpoint, so enforce them client-side by truncating the output.
+            text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/huggingface_hub.py b/libs/community/langchain_community/llms/huggingface_hub.py
new file mode 100644
index 00000000000..32facc244b0
--- /dev/null
+++ b/libs/community/langchain_community/llms/huggingface_hub.py
@@ -0,0 +1,130 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+DEFAULT_REPO_ID = "gpt2"
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+
+
+class HuggingFaceHub(LLM):
+ """HuggingFaceHub models.
+
+ To use, you should have the ``huggingface_hub`` python package installed, and the
+ environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+ Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import HuggingFaceHub
+ hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
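+
+    Example with an explicit task and model parameters (repo id and values are
+    illustrative):
+        .. code-block:: python
+
+            hf = HuggingFaceHub(
+                repo_id="google/flan-t5-xxl",
+                task="text2text-generation",
+                model_kwargs={"temperature": 0.5, "max_length": 64},
+                huggingfacehub_api_token="my-api-key",
+            )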
+ """
+
+ client: Any #: :meta private:
+ repo_id: str = DEFAULT_REPO_ID
+ """Model name to use."""
+ task: Optional[str] = None
+ """Task to call the model with.
+ Should be a task that returns `generated_text` or `summary_text`."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ huggingfacehub_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ huggingfacehub_api_token = get_from_dict_or_env(
+ values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+ )
+ try:
+ from huggingface_hub.inference_api import InferenceApi
+
+ repo_id = values["repo_id"]
+ client = InferenceApi(
+ repo_id=repo_id,
+ token=huggingfacehub_api_token,
+ task=values.get("task"),
+ )
+ if client.task not in VALID_TASKS:
+ raise ValueError(
+ f"Got invalid task {client.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ values["client"] = client
+ except ImportError:
+            raise ImportError(
+ "Could not import huggingface_hub python package. "
+ "Please install it with `pip install huggingface_hub`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"repo_id": self.repo_id, "task": self.task},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "huggingface_hub"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to HuggingFace Hub's inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = hf("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+ params = {**_model_kwargs, **kwargs}
+ response = self.client(inputs=prompt, params=params)
+ if "error" in response:
+ raise ValueError(f"Error raised by inference API: {response['error']}")
+ if self.client.task == "text-generation":
+            # The text-generation response includes the prompt; strip it off.
+ text = response[0]["generated_text"][len(prompt) :]
+ elif self.client.task == "text2text-generation":
+ text = response[0]["generated_text"]
+ elif self.client.task == "summarization":
+ text = response[0]["summary_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {self.client.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ if stop is not None:
+            # Stop sequences are not applied server-side by the Hub inference
+            # API, so enforce them client-side by truncating the output.
+            text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/huggingface_pipeline.py b/libs/community/langchain_community/llms/huggingface_pipeline.py
new file mode 100644
index 00000000000..9b2e94db326
--- /dev/null
+++ b/libs/community/langchain_community/llms/huggingface_pipeline.py
@@ -0,0 +1,247 @@
+from __future__ import annotations
+
+import importlib.util
+import logging
+from typing import Any, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+DEFAULT_MODEL_ID = "gpt2"
+DEFAULT_TASK = "text-generation"
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+DEFAULT_BATCH_SIZE = 4
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingFacePipeline(BaseLLM):
+ """HuggingFace Pipeline API.
+
+ To use, you should have the ``transformers`` python package installed.
+
+ Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+ Example using from_model_id:
+ .. code-block:: python
+
+ from langchain_community.llms import HuggingFacePipeline
+ hf = HuggingFacePipeline.from_model_id(
+ model_id="gpt2",
+ task="text-generation",
+ pipeline_kwargs={"max_new_tokens": 10},
+ )
+ Example passing pipeline in directly:
+ .. code-block:: python
+
+ from langchain_community.llms import HuggingFacePipeline
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ model_id = "gpt2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+ pipe = pipeline(
+ "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
+ )
+ hf = HuggingFacePipeline(pipeline=pipe)
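+    Example running the pipeline on a GPU (the device index is illustrative):
+        .. code-block:: python
+
+            from langchain_community.llms import HuggingFacePipeline
+            hf = HuggingFacePipeline.from_model_id(
+                model_id="gpt2",
+                task="text-generation",
+                device=0,
+                pipeline_kwargs={"max_new_tokens": 10},
+            )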
+ """
+
+ pipeline: Any #: :meta private:
+ model_id: str = DEFAULT_MODEL_ID
+ """Model name to use."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments passed to the model."""
+ pipeline_kwargs: Optional[dict] = None
+ """Keyword arguments passed to the pipeline."""
+ batch_size: int = DEFAULT_BATCH_SIZE
+ """Batch size to use when passing multiple documents to generate."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @classmethod
+ def from_model_id(
+ cls,
+ model_id: str,
+ task: str,
+ device: Optional[int] = -1,
+ device_map: Optional[str] = None,
+ model_kwargs: Optional[dict] = None,
+ pipeline_kwargs: Optional[dict] = None,
+ batch_size: int = DEFAULT_BATCH_SIZE,
+ **kwargs: Any,
+ ) -> HuggingFacePipeline:
+ """Construct the pipeline object from model_id and task."""
+ try:
+ from transformers import (
+ AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM,
+ AutoTokenizer,
+ )
+ from transformers import pipeline as hf_pipeline
+
+ except ImportError:
+            raise ImportError(
+ "Could not import transformers python package. "
+ "Please install it with `pip install transformers`."
+ )
+
+ _model_kwargs = model_kwargs or {}
+ tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+ try:
+ if task == "text-generation":
+ model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
+ elif task in ("text2text-generation", "summarization"):
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
+ else:
+ raise ValueError(
+ f"Got invalid task {task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ except ImportError as e:
+ raise ValueError(
+ f"Could not load the {task} model due to missing dependencies."
+ ) from e
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token_id = model.config.eos_token_id
+
+ if (
+ getattr(model, "is_loaded_in_4bit", False)
+ or getattr(model, "is_loaded_in_8bit", False)
+ ) and device is not None:
+ logger.warning(
+                f"Setting the `device` argument to None from {device}: the model "
+                "was already placed on a device by the Accelerate library when it "
+                "was loaded in 4-bit/8-bit mode, so it cannot be moved again."
+ )
+ device = None
+
+ if device is not None and importlib.util.find_spec("torch") is not None:
+ import torch
+
+ cuda_device_count = torch.cuda.device_count()
+ if device < -1 or (device >= cuda_device_count):
+ raise ValueError(
+ f"Got device=={device}, "
+ f"device is required to be within [-1, {cuda_device_count})"
+ )
+ if device_map is not None and device < 0:
+ device = None
+ if device is not None and device < 0 and cuda_device_count > 0:
+ logger.warning(
+ "Device has %d GPUs available. "
+                    "Provide device={deviceId} to `from_model_id` to use available "
+ "GPUs for execution. deviceId is -1 (default) for CPU and "
+ "can be a positive integer associated with CUDA device id.",
+ cuda_device_count,
+ )
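+        # `trust_remote_code` was already consumed when loading the tokenizer and
+        # model above, so drop it before forwarding kwargs to the pipeline call.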
+ if "trust_remote_code" in _model_kwargs:
+ _model_kwargs = {
+ k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+ }
+ _pipeline_kwargs = pipeline_kwargs or {}
+ pipeline = hf_pipeline(
+ task=task,
+ model=model,
+ tokenizer=tokenizer,
+ device=device,
+ device_map=device_map,
+ batch_size=batch_size,
+ model_kwargs=_model_kwargs,
+ **_pipeline_kwargs,
+ )
+ if pipeline.task not in VALID_TASKS:
+ raise ValueError(
+ f"Got invalid task {pipeline.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ return cls(
+ pipeline=pipeline,
+ model_id=model_id,
+ model_kwargs=_model_kwargs,
+ pipeline_kwargs=_pipeline_kwargs,
+ batch_size=batch_size,
+ **kwargs,
+ )
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_id": self.model_id,
+ "model_kwargs": self.model_kwargs,
+ "pipeline_kwargs": self.pipeline_kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "huggingface_pipeline"
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ # List to hold all results
+ text_generations: List[str] = []
+
+ for i in range(0, len(prompts), self.batch_size):
+ batch_prompts = prompts[i : i + self.batch_size]
+
+ # Process batch of prompts
+ responses = self.pipeline(batch_prompts)
+
+ # Process each response in the batch
+ for j, response in enumerate(responses):
+ if isinstance(response, list):
+ # if model returns multiple generations, pick the top one
+ response = response[0]
+
+ if self.pipeline.task == "text-generation":
+ try:
+ from transformers.pipelines.text_generation import ReturnType
+
+ remove_prompt = (
+ self.pipeline._postprocess_params.get("return_type")
+ != ReturnType.NEW_TEXT
+ )
+ except Exception as e:
+ logger.warning(
+ f"Unable to extract pipeline return_type. "
+ f"Received error:\n\n{e}"
+ )
+ remove_prompt = True
+ if remove_prompt:
+ text = response["generated_text"][len(batch_prompts[j]) :]
+ else:
+ text = response["generated_text"]
+ elif self.pipeline.task == "text2text-generation":
+ text = response["generated_text"]
+ elif self.pipeline.task == "summarization":
+ text = response["summary_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {self.pipeline.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ if stop:
+ # Enforce stop tokens
+ text = enforce_stop_tokens(text, stop)
+
+ # Append the processed text to results
+ text_generations.append(text)
+
+ return LLMResult(
+ generations=[[Generation(text=text)] for text in text_generations]
+ )
diff --git a/libs/community/langchain_community/llms/huggingface_text_gen_inference.py b/libs/community/langchain_community/llms/huggingface_text_gen_inference.py
new file mode 100644
index 00000000000..b23aaa6e9bd
--- /dev/null
+++ b/libs/community/langchain_community/llms/huggingface_text_gen_inference.py
@@ -0,0 +1,301 @@
+import logging
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_pydantic_field_names
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingFaceTextGenInference(LLM):
+ """
+ HuggingFace text generation API.
+
+ To use, you should have the `text-generation` python package installed and
+ a text-generation server running.
+
+ Example:
+ .. code-block:: python
+
+ # Basic Example (no streaming)
+ llm = HuggingFaceTextGenInference(
+ inference_server_url="http://localhost:8010/",
+ max_new_tokens=512,
+ top_k=10,
+ top_p=0.95,
+ typical_p=0.95,
+ temperature=0.01,
+ repetition_penalty=1.03,
+ )
+ print(llm("What is Deep Learning?"))
+
+ # Streaming response example
+ from langchain_community.callbacks import streaming_stdout
+
+ callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+ llm = HuggingFaceTextGenInference(
+ inference_server_url="http://localhost:8010/",
+ max_new_tokens=512,
+ top_k=10,
+ top_p=0.95,
+ typical_p=0.95,
+ temperature=0.01,
+ repetition_penalty=1.03,
+ callbacks=callbacks,
+ streaming=True
+ )
+ print(llm("What is Deep Learning?"))
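+
+            # server_kwargs are forwarded to the underlying text-generation
+            # client, e.g. for custom headers (values below are illustrative)
+            llm = HuggingFaceTextGenInference(
+                inference_server_url="http://localhost:8010/",
+                server_kwargs={"headers": {"Authorization": "Basic abc123"}},
+            )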
+
+ """
+
+ max_new_tokens: int = 512
+ """Maximum number of generated tokens"""
+ top_k: Optional[int] = None
+ """The number of highest probability vocabulary tokens to keep for
+ top-k-filtering."""
+ top_p: Optional[float] = 0.95
+ """If set to < 1, only the smallest set of most probable tokens with probabilities
+ that add up to `top_p` or higher are kept for generation."""
+ typical_p: Optional[float] = 0.95
+ """Typical Decoding mass. See [Typical Decoding for Natural Language
+ Generation](https://arxiv.org/abs/2202.00666) for more information."""
+ temperature: Optional[float] = 0.8
+    """The value used to modulate the logits distribution."""
+ repetition_penalty: Optional[float] = None
+ """The parameter for repetition penalty. 1.0 means no penalty.
+ See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
+ return_full_text: bool = False
+ """Whether to prepend the prompt to the generated text"""
+ truncate: Optional[int] = None
+ """Truncate inputs tokens to the given size"""
+ stop_sequences: List[str] = Field(default_factory=list)
+ """Stop generating tokens if a member of `stop_sequences` is generated"""
+ seed: Optional[int] = None
+ """Random sampling seed"""
+ inference_server_url: str = ""
+ """text-generation-inference instance base url"""
+ timeout: int = 120
+ """Timeout in seconds"""
+ streaming: bool = False
+    """Whether to stream the results, token by token"""
+ do_sample: bool = False
+ """Activate logits sampling"""
+ watermark: bool = False
+ """Watermarking with [A Watermark for Large Language Models]
+ (https://arxiv.org/abs/2301.10226)"""
+ server_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any text-generation-inference server parameters not explicitly specified"""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `call` not explicitly specified"""
+ client: Any
+ async_client: Any
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+                    f"""WARNING! {field_name} is not a default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in via the `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ import text_generation
+
+ values["client"] = text_generation.Client(
+ values["inference_server_url"],
+ timeout=values["timeout"],
+ **values["server_kwargs"],
+ )
+ values["async_client"] = text_generation.AsyncClient(
+ values["inference_server_url"],
+ timeout=values["timeout"],
+ **values["server_kwargs"],
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import text_generation python package. "
+ "Please install it with `pip install text_generation`."
+ )
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "huggingface_textgen_inference"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling text generation inference API."""
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "typical_p": self.typical_p,
+ "temperature": self.temperature,
+ "repetition_penalty": self.repetition_penalty,
+ "return_full_text": self.return_full_text,
+ "truncate": self.truncate,
+ "stop_sequences": self.stop_sequences,
+ "seed": self.seed,
+ "do_sample": self.do_sample,
+ "watermark": self.watermark,
+ **self.model_kwargs,
+ }
+
+ def _invocation_params(
+ self, runtime_stop: Optional[List[str]], **kwargs: Any
+ ) -> Dict[str, Any]:
+ params = {**self._default_params, **kwargs}
+ params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
+ return params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+
+ invocation_params = self._invocation_params(stop, **kwargs)
+ res = self.client.generate(prompt, **invocation_params)
+ # remove stop sequences from the end of the generated text
+ for stop_seq in invocation_params["stop_sequences"]:
+ if stop_seq in res.generated_text:
+ res.generated_text = res.generated_text[
+ : res.generated_text.index(stop_seq)
+ ]
+ return res.generated_text
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if self.streaming:
+ completion = ""
+ async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+
+ invocation_params = self._invocation_params(stop, **kwargs)
+ res = await self.async_client.generate(prompt, **invocation_params)
+ # remove stop sequences from the end of the generated text
+ for stop_seq in invocation_params["stop_sequences"]:
+ if stop_seq in res.generated_text:
+ res.generated_text = res.generated_text[
+ : res.generated_text.index(stop_seq)
+ ]
+ return res.generated_text
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ invocation_params = self._invocation_params(stop, **kwargs)
+
+ for res in self.client.generate_stream(prompt, **invocation_params):
+ # identify stop sequence in generated text, if any
+ stop_seq_found: Optional[str] = None
+ for stop_seq in invocation_params["stop_sequences"]:
+ if stop_seq in res.token.text:
+ stop_seq_found = stop_seq
+
+ # identify text to yield
+ text: Optional[str] = None
+ if res.token.special:
+ text = None
+ elif stop_seq_found:
+ text = res.token.text[: res.token.text.index(stop_seq_found)]
+ else:
+ text = res.token.text
+
+ # yield text, if any
+ if text:
+ chunk = GenerationChunk(text=text)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text)
+
+ # break if stop sequence found
+ if stop_seq_found:
+ break
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ invocation_params = self._invocation_params(stop, **kwargs)
+
+ async for res in self.async_client.generate_stream(prompt, **invocation_params):
+ # identify stop sequence in generated text, if any
+ stop_seq_found: Optional[str] = None
+ for stop_seq in invocation_params["stop_sequences"]:
+ if stop_seq in res.token.text:
+ stop_seq_found = stop_seq
+
+ # identify text to yield
+ text: Optional[str] = None
+ if res.token.special:
+ text = None
+ elif stop_seq_found:
+ text = res.token.text[: res.token.text.index(stop_seq_found)]
+ else:
+ text = res.token.text
+
+ # yield text, if any
+ if text:
+ chunk = GenerationChunk(text=text)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text)
+
+ # break if stop sequence found
+ if stop_seq_found:
+ break
diff --git a/libs/community/langchain_community/llms/human.py b/libs/community/langchain_community/llms/human.py
new file mode 100644
index 00000000000..8ee75db3c4a
--- /dev/null
+++ b/libs/community/langchain_community/llms/human.py
@@ -0,0 +1,85 @@
+from typing import Any, Callable, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+def _display_prompt(prompt: str) -> None:
+ """Displays the given prompt to the user."""
+ print(f"\n{prompt}")
+
+
+def _collect_user_input(
+ separator: Optional[str] = None, stop: Optional[List[str]] = None
+) -> str:
+ """Collects and returns user input as a single string."""
+ separator = separator or "\n"
+ lines = []
+
+ while True:
+ line = input()
+ if not line:
+ break
+ lines.append(line)
+
+ if stop and any(seq in line for seq in stop):
+ break
+ # Combine all lines into a single string
+ multi_line_input = separator.join(lines)
+ return multi_line_input
+
+
+class HumanInputLLM(LLM):
+ """
+    A "model" that prompts a human for input and returns that input as the
+    response. Useful for testing and debugging chains and agents.
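+
+    Example (a minimal sketch; the prompt display and input collection can be
+    customized via ``prompt_func`` and ``input_func``):
+        .. code-block:: python
+
+            from langchain_community.llms import HumanInputLLM
+
+            llm = HumanInputLLM()
+            llm("What is the capital of France?")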
+ """
+
+ input_func: Callable = Field(default_factory=lambda: _collect_user_input)
+ prompt_func: Callable[[str], None] = Field(default_factory=lambda: _display_prompt)
+ separator: str = "\n"
+ input_kwargs: Mapping[str, Any] = {}
+ prompt_kwargs: Mapping[str, Any] = {}
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """
+ Returns an empty dictionary as there are no identifying parameters.
+ """
+ return {}
+
+ @property
+ def _llm_type(self) -> str:
+ """Returns the type of LLM."""
+ return "human-input"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """
+ Displays the prompt to the user and returns their input as a response.
+
+ Args:
+ prompt (str): The prompt to be displayed to the user.
+ stop (Optional[List[str]]): A list of stop strings.
+ run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
+
+ Returns:
+ str: The user's input as a response.
+ """
+ self.prompt_func(prompt, **self.prompt_kwargs)
+ user_input = self.input_func(
+ separator=self.separator, stop=stop, **self.input_kwargs
+ )
+
+ if stop is not None:
+            # Enforce stop tokens client-side, since a human typing the
+            # response will not apply them.
+ user_input = enforce_stop_tokens(user_input, stop)
+ return user_input
diff --git a/libs/community/langchain_community/llms/javelin_ai_gateway.py b/libs/community/langchain_community/llms/javelin_ai_gateway.py
new file mode 100644
index 00000000000..d53e4a9bf72
--- /dev/null
+++ b/libs/community/langchain_community/llms/javelin_ai_gateway.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Parameters for the Javelin AI Gateway LLM."""
+
+ temperature: float = 0.0
+ stop: Optional[List[str]] = None
+ max_tokens: Optional[int] = None
+
+
+class JavelinAIGateway(LLM):
+ """Javelin AI Gateway LLMs.
+
+ To use, you should have the ``javelin_sdk`` python package installed.
+ For more information, see https://docs.getjavelin.io
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import JavelinAIGateway
+
+ completions = JavelinAIGateway(
+ gateway_uri="",
+ route="",
+ params={
+ "temperature": 0.1
+ }
+ )
+ """
+
+ route: str
+ """The route to use for the Javelin AI Gateway API."""
+
+ client: Optional[Any] = None
+ """The Javelin AI Gateway client."""
+
+ gateway_uri: Optional[str] = None
+ """The URI of the Javelin AI Gateway API."""
+
+ params: Optional[Params] = None
+ """Parameters for the Javelin AI Gateway API."""
+
+ javelin_api_key: Optional[str] = None
+ """The API key for the Javelin AI Gateway API."""
+
+ def __init__(self, **kwargs: Any):
+ try:
+ from javelin_sdk import (
+ JavelinClient,
+ UnauthorizedError,
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import javelin_sdk python package. "
+ "Please install it with `pip install javelin_sdk`."
+ )
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ try:
+ self.client = JavelinClient(
+ base_url=self.gateway_uri, api_key=self.javelin_api_key
+ )
+ except UnauthorizedError as e:
+ raise ValueError("Javelin: Incorrect API Key.") from e
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Javelin AI Gateway API."""
+ params: Dict[str, Any] = {
+ "gateway_uri": self.gateway_uri,
+ "route": self.route,
+ "javelin_api_key": self.javelin_api_key,
+ **(self.params.dict() if self.params else {}),
+ }
+ return params
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return self._default_params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the Javelin AI Gateway API."""
+ data: Dict[str, Any] = {
+ "prompt": prompt,
+ **(self.params.dict() if self.params else {}),
+ }
+ if s := (stop or (self.params.stop if self.params else None)):
+ data["stop"] = s
+
+ if self.client is not None:
+ resp = self.client.query_route(self.route, query_body=data)
+ else:
+ raise ValueError("Javelin client is not initialized.")
+
+ resp_dict = resp.dict()
+
+ try:
+ return resp_dict["llm_response"]["choices"][0]["text"]
+ except KeyError:
+ return ""
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+        """Asynchronously call the Javelin AI Gateway API."""
+ data: Dict[str, Any] = {
+ "prompt": prompt,
+ **(self.params.dict() if self.params else {}),
+ }
+ if s := (stop or (self.params.stop if self.params else None)):
+ data["stop"] = s
+
+ if self.client is not None:
+ resp = await self.client.aquery_route(self.route, query_body=data)
+ else:
+ raise ValueError("Javelin client is not initialized.")
+
+ resp_dict = resp.dict()
+
+ try:
+ return resp_dict["llm_response"]["choices"][0]["text"]
+ except KeyError:
+ return ""
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "javelin-ai-gateway"
diff --git a/libs/community/langchain_community/llms/koboldai.py b/libs/community/langchain_community/llms/koboldai.py
new file mode 100644
index 00000000000..ad121755386
--- /dev/null
+++ b/libs/community/langchain_community/llms/koboldai.py
@@ -0,0 +1,197 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+
+logger = logging.getLogger(__name__)
+
+
+def clean_url(url: str) -> str:
+ """Remove trailing slash and /api from url if present."""
+ if url.endswith("/api"):
+ return url[:-4]
+ elif url.endswith("/"):
+ return url[:-1]
+ else:
+ return url
+
+
+class KoboldApiLLM(LLM):
+ """Kobold API language model.
+
+ It includes several fields that can be used to control the text generation process.
+
+ To use this class, instantiate it with the required parameters and call it with a
+ prompt to generate text. For example:
+
+    .. code-block:: python
+
+        kobold = KoboldApiLLM(endpoint="http://localhost:5000")
+        result = kobold("Write a story about a dragon.")
+
+ This will send a POST request to the Kobold API with the provided prompt and
+ generate text.
+ """
+
+ endpoint: str
+ """The API endpoint to use for generating text."""
+
+ use_story: Optional[bool] = False
+    """Whether to use the story from the KoboldAI GUI when generating text."""
+
+ use_authors_note: Optional[bool] = False
+ """Whether to use the author's note from the KoboldAI GUI when generating text.
+
+ This has no effect unless use_story is also enabled.
+ """
+
+ use_world_info: Optional[bool] = False
+ """Whether to use the world info from the KoboldAI GUI when generating text."""
+
+ use_memory: Optional[bool] = False
+ """Whether to use the memory from the KoboldAI GUI when generating text."""
+
+ max_context_length: Optional[int] = 1600
+ """Maximum number of tokens to send to the model.
+
+ minimum: 1
+ """
+
+ max_length: Optional[int] = 80
+ """Number of tokens to generate.
+
+ maximum: 512
+ minimum: 1
+ """
+
+ rep_pen: Optional[float] = 1.12
+ """Base repetition penalty value.
+
+ minimum: 1
+ """
+
+ rep_pen_range: Optional[int] = 1024
+ """Repetition penalty range.
+
+ minimum: 0
+ """
+
+ rep_pen_slope: Optional[float] = 0.9
+ """Repetition penalty slope.
+
+ minimum: 0
+ """
+
+ temperature: Optional[float] = 0.6
+ """Temperature value.
+
+ exclusiveMinimum: 0
+ """
+
+ tfs: Optional[float] = 0.9
+ """Tail free sampling value.
+
+ maximum: 1
+ minimum: 0
+ """
+
+ top_a: Optional[float] = 0.9
+ """Top-a sampling value.
+
+ minimum: 0
+ """
+
+ top_p: Optional[float] = 0.95
+ """Top-p sampling value.
+
+ maximum: 1
+ minimum: 0
+ """
+
+ top_k: Optional[int] = 0
+ """Top-k sampling value.
+
+ minimum: 0
+ """
+
+ typical: Optional[float] = 0.5
+ """Typical sampling value.
+
+ maximum: 1
+ minimum: 0
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ return "koboldai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the API and return the output.
+
+ Args:
+ prompt: The prompt to use for generation.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The generated text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import KoboldApiLLM
+
+ llm = KoboldApiLLM(endpoint="http://localhost:5000")
+ llm("Write a story about dragons.")
+ """
+ data: Dict[str, Any] = {
+ "prompt": prompt,
+ "use_story": self.use_story,
+ "use_authors_note": self.use_authors_note,
+ "use_world_info": self.use_world_info,
+ "use_memory": self.use_memory,
+ "max_context_length": self.max_context_length,
+ "max_length": self.max_length,
+ "rep_pen": self.rep_pen,
+ "rep_pen_range": self.rep_pen_range,
+ "rep_pen_slope": self.rep_pen_slope,
+ "temperature": self.temperature,
+ "tfs": self.tfs,
+ "top_a": self.top_a,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "typical": self.typical,
+ }
+
+ if stop is not None:
+ data["stop_sequence"] = stop
+
+ response = requests.post(
+ f"{clean_url(self.endpoint)}/api/v1/generate", json=data
+ )
+
+ response.raise_for_status()
+ json_response = response.json()
+
+ if (
+ "results" in json_response
+ and len(json_response["results"]) > 0
+ and "text" in json_response["results"][0]
+ ):
+ text = json_response["results"][0]["text"].strip()
+
+ if stop is not None:
+ for sequence in stop:
+ if text.endswith(sequence):
+ text = text[: -len(sequence)].rstrip()
+
+ return text
+ else:
+ raise ValueError(
+ f"Unexpected response format from Kobold API: {json_response}"
+ )
diff --git a/libs/community/langchain_community/llms/llamacpp.py b/libs/community/langchain_community/llms/llamacpp.py
new file mode 100644
index 00000000000..c29ca74ef75
--- /dev/null
+++ b/libs/community/langchain_community/llms/llamacpp.py
@@ -0,0 +1,358 @@
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_pydantic_field_names
+from langchain_core.utils.utils import build_extra_kwargs
+
+if TYPE_CHECKING:
+ from llama_cpp import LlamaGrammar
+
+logger = logging.getLogger(__name__)
+
+
+class LlamaCpp(LLM):
+ """llama.cpp model.
+
+ To use, you should have the llama-cpp-python library installed, and provide the
+ path to the Llama model as a named parameter to the constructor.
+ Check out: https://github.com/abetlen/llama-cpp-python
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import LlamaCpp
+ llm = LlamaCpp(model_path="/path/to/llama/model")
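+
+    Example with a few commonly tuned parameters (values are illustrative):
+        .. code-block:: python
+
+            llm = LlamaCpp(
+                model_path="/path/to/llama/model",
+                n_ctx=2048,
+                n_gpu_layers=1,
+                temperature=0.7,
+            )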
+ """
+
+ client: Any #: :meta private:
+ model_path: str
+ """The path to the Llama model file."""
+
+ lora_base: Optional[str] = None
+ """The path to the Llama LoRA base model."""
+
+ lora_path: Optional[str] = None
+ """The path to the Llama LoRA. If None, no LoRa is loaded."""
+
+ n_ctx: int = Field(512, alias="n_ctx")
+ """Token context window."""
+
+ n_parts: int = Field(-1, alias="n_parts")
+ """Number of parts to split the model into.
+ If -1, the number of parts is automatically determined."""
+
+ seed: int = Field(-1, alias="seed")
+ """Seed. If -1, a random seed is used."""
+
+ f16_kv: bool = Field(True, alias="f16_kv")
+ """Use half-precision for key/value cache."""
+
+ logits_all: bool = Field(False, alias="logits_all")
+ """Return logits for all tokens, not just the last token."""
+
+ vocab_only: bool = Field(False, alias="vocab_only")
+ """Only load the vocabulary, no weights."""
+
+ use_mlock: bool = Field(False, alias="use_mlock")
+ """Force system to keep model in RAM."""
+
+ n_threads: Optional[int] = Field(None, alias="n_threads")
+ """Number of threads to use.
+ If None, the number of threads is automatically determined."""
+
+ n_batch: Optional[int] = Field(8, alias="n_batch")
+ """Number of tokens to process in parallel.
+ Should be a number between 1 and n_ctx."""
+
+ n_gpu_layers: Optional[int] = Field(None, alias="n_gpu_layers")
+ """Number of layers to be loaded into gpu memory. Default None."""
+
+ suffix: Optional[str] = Field(None)
+ """A suffix to append to the generated text. If None, no suffix is appended."""
+
+ max_tokens: Optional[int] = 256
+ """The maximum number of tokens to generate."""
+
+ temperature: Optional[float] = 0.8
+ """The temperature to use for sampling."""
+
+ top_p: Optional[float] = 0.95
+ """The top-p value to use for sampling."""
+
+ logprobs: Optional[int] = Field(None)
+ """The number of logprobs to return. If None, no logprobs are returned."""
+
+ echo: Optional[bool] = False
+ """Whether to echo the prompt."""
+
+ stop: Optional[List[str]] = []
+ """A list of strings to stop generation when encountered."""
+
+ repeat_penalty: Optional[float] = 1.1
+ """The penalty to apply to repeated tokens."""
+
+ top_k: Optional[int] = 40
+ """The top-k value to use for sampling."""
+
+ last_n_tokens_size: Optional[int] = 64
+ """The number of tokens to look back when applying the repeat_penalty."""
+
+ use_mmap: Optional[bool] = True
+    """Whether to use memory-mapping (mmap) when loading the model."""
+
+ rope_freq_scale: float = 1.0
+ """Scale factor for rope sampling."""
+
+ rope_freq_base: float = 10000.0
+ """Base frequency for rope sampling."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Any additional parameters to pass to llama_cpp.Llama."""
+
+ streaming: bool = True
+ """Whether to stream the results, token by token."""
+
+ grammar_path: Optional[Union[str, Path]] = None
+ """
+ grammar_path: Path to the .gbnf file that defines formal grammars
+ for constraining model outputs. For instance, the grammar can be used
+ to force the model to generate valid JSON or to speak exclusively in emojis. At most
+ one of grammar_path and grammar should be passed in.
+ """
+ grammar: Optional[Union[str, LlamaGrammar]] = None
+ """
+ grammar: formal grammar for constraining model outputs. For instance, the grammar
+ can be used to force the model to generate valid JSON or to speak exclusively in
+ emojis. At most one of grammar_path and grammar should be passed in.
+ """
+
+ verbose: bool = True
+ """Print verbose output to stderr."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that llama-cpp-python library is installed."""
+ try:
+ from llama_cpp import Llama, LlamaGrammar
+ except ImportError:
+ raise ImportError(
+ "Could not import llama-cpp-python library. "
+ "Please install the llama-cpp-python library to "
+                "use this model: pip install llama-cpp-python"
+ )
+
+ model_path = values["model_path"]
+ model_param_names = [
+ "rope_freq_scale",
+ "rope_freq_base",
+ "lora_path",
+ "lora_base",
+ "n_ctx",
+ "n_parts",
+ "seed",
+ "f16_kv",
+ "logits_all",
+ "vocab_only",
+ "use_mlock",
+ "n_threads",
+ "n_batch",
+ "use_mmap",
+ "last_n_tokens_size",
+ "verbose",
+ ]
+ model_params = {k: values[k] for k in model_param_names}
+ # For backwards compatibility, only include if non-null.
+ if values["n_gpu_layers"] is not None:
+ model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+ model_params.update(values["model_kwargs"])
+
+ try:
+ values["client"] = Llama(model_path, **model_params)
+ except Exception as e:
+ raise ValueError(
+ f"Could not load Llama model from path: {model_path}. "
+ f"Received error {e}"
+ )
+
+ if values["grammar"] and values["grammar_path"]:
+ grammar = values["grammar"]
+ grammar_path = values["grammar_path"]
+ raise ValueError(
+ "Can only pass in one of grammar and grammar_path. Received "
+ f"{grammar=} and {grammar_path=}."
+ )
+ elif isinstance(values["grammar"], str):
+ values["grammar"] = LlamaGrammar.from_string(values["grammar"])
+ elif values["grammar_path"]:
+ values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
+ else:
+ pass
+ return values
+
+ @root_validator(pre=True)
+ def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ values["model_kwargs"] = build_extra_kwargs(
+ extra, values, all_required_field_names
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling llama_cpp."""
+ params = {
+ "suffix": self.suffix,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "logprobs": self.logprobs,
+ "echo": self.echo,
+ "stop_sequences": self.stop, # key here is convention among LLM classes
+ "repeat_penalty": self.repeat_penalty,
+ "top_k": self.top_k,
+ }
+ if self.grammar:
+ params["grammar"] = self.grammar
+ return params
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_path": self.model_path}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "llamacpp"
+
+ def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+ """
+        Performs sanity checks and prepares the parameters in the format
+        required by llama_cpp.
+
+ Args:
+ stop (Optional[List[str]]): List of stop sequences for llama_cpp.
+
+ Returns:
+ Dictionary containing the combined parameters.
+ """
+
+ # Raise error if stop sequences are in both input and default params
+ if self.stop and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+
+ params = self._default_params
+
+ # llama_cpp expects the "stop" key not this, so we remove it:
+ params.pop("stop_sequences")
+
+        # then set it from the class config, the call argument, or an empty list:
+ params["stop"] = self.stop or stop or []
+
+ return params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the Llama model and return the output.
+
+ Args:
+ prompt: The prompt to use for generation.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The generated text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import LlamaCpp
+ llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
+ llm("This is a prompt.")
+ """
+ if self.streaming:
+ # If streaming is enabled, we use the stream
+ # method that yields as they are generated
+            # and return the combined strings from the first choice's text:
+ combined_text_output = ""
+ for chunk in self._stream(
+ prompt=prompt,
+ stop=stop,
+ run_manager=run_manager,
+ **kwargs,
+ ):
+ combined_text_output += chunk.text
+ return combined_text_output
+ else:
+ params = self._get_parameters(stop)
+ params = {**params, **kwargs}
+ result = self.client(prompt=prompt, **params)
+ return result["choices"][0]["text"]
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+        """Yields result objects as they are generated in real time.
+
+ It also calls the callback manager's on_llm_new_token event with
+ similar parameters to the OpenAI LLM class method of the same name.
+
+ Args:
+ prompt: The prompts to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ A generator representing the stream of tokens being generated.
+
+ Yields:
+            GenerationChunk objects containing the token text and generation
+            metadata. See llama-cpp-python docs and below for more.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import LlamaCpp
+ llm = LlamaCpp(
+ model_path="/path/to/local/model.bin",
+ temperature = 0.5
+ )
+ for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+ stop=["'","\n"]):
+ result = chunk["choices"][0]
+ print(result["text"], end='', flush=True)
+
+ """
+ params = {**self._get_parameters(stop), **kwargs}
+ result = self.client(prompt=prompt, stream=True, **params)
+ for part in result:
+ logprobs = part["choices"][0].get("logprobs", None)
+ chunk = GenerationChunk(
+ text=part["choices"][0]["text"],
+ generation_info={"logprobs": logprobs},
+ )
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ token=chunk.text, verbose=self.verbose, log_probs=logprobs
+ )
+
+ def get_num_tokens(self, text: str) -> int:
+ tokenized_text = self.client.tokenize(text.encode("utf-8"))
+ return len(tokenized_text)
diff --git a/libs/community/langchain_community/llms/loading.py b/libs/community/langchain_community/llms/loading.py
new file mode 100644
index 00000000000..c9621b784df
--- /dev/null
+++ b/libs/community/langchain_community/llms/loading.py
@@ -0,0 +1,44 @@
+"""Base interface for loading large language model APIs."""
+import json
+from pathlib import Path
+from typing import Union
+
+import yaml
+from langchain_core.language_models.llms import BaseLLM
+
+from langchain_community.llms import get_type_to_cls_dict
+
+
+def load_llm_from_config(config: dict) -> BaseLLM:
+ """Load LLM from Config Dict."""
+ if "_type" not in config:
+ raise ValueError("Must specify an LLM Type in config")
+ config_type = config.pop("_type")
+
+ type_to_cls_dict = get_type_to_cls_dict()
+
+ if config_type not in type_to_cls_dict:
+ raise ValueError(f"Loading {config_type} LLM not supported")
+
+ llm_cls = type_to_cls_dict[config_type]()
+ return llm_cls(**config)
+
+
+def load_llm(file: Union[str, Path]) -> BaseLLM:
+ """Load LLM from file."""
+ # Convert file to Path object.
+ if isinstance(file, str):
+ file_path = Path(file)
+ else:
+ file_path = file
+ # Load from either json or yaml.
+ if file_path.suffix == ".json":
+ with open(file_path) as f:
+ config = json.load(f)
+ elif file_path.suffix == ".yaml":
+ with open(file_path, "r") as f:
+ config = yaml.safe_load(f)
+ else:
+ raise ValueError("File type must be json or yaml")
+ # Load the LLM from the config now.
+ return load_llm_from_config(config)
diff --git a/libs/community/langchain_community/llms/manifest.py b/libs/community/langchain_community/llms/manifest.py
new file mode 100644
index 00000000000..2852ab1d7c7
--- /dev/null
+++ b/libs/community/langchain_community/llms/manifest.py
@@ -0,0 +1,63 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+
+
+class ManifestWrapper(LLM):
+    """LLM wrapper around HazyResearch's Manifest library."""
+
+ client: Any #: :meta private:
+ llm_kwargs: Optional[Dict] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+ try:
+ from manifest import Manifest
+
+ if not isinstance(values["client"], Manifest):
+                raise ValueError("`client` must be an instance of manifest.Manifest")
+ except ImportError:
+ raise ImportError(
+ "Could not import manifest python package. "
+ "Please install it with `pip install manifest-ml`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ kwargs = self.llm_kwargs or {}
+ return {
+ **self.client.client_pool.get_current_client().get_model_params(),
+ **kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "manifest"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to LLM through Manifest."""
+ if stop is not None and len(stop) != 1:
+ raise NotImplementedError(
+ f"Manifest currently only supports a single stop token, got {stop}"
+ )
+ params = self.llm_kwargs or {}
+ params = {**params, **kwargs}
+ if stop is not None:
+ params["stop_token"] = stop
+ return self.client.run(prompt, **params)
diff --git a/libs/community/langchain_community/llms/minimax.py b/libs/community/langchain_community/llms/minimax.py
new file mode 100644
index 00000000000..a2375b7445f
--- /dev/null
+++ b/libs/community/langchain_community/llms/minimax.py
@@ -0,0 +1,156 @@
+"""Wrapper around Minimax APIs."""
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ Dict,
+ List,
+ Optional,
+)
+
+import requests
+from langchain_core.callbacks import (
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class _MinimaxEndpointClient(BaseModel):
+ """An API client that talks to a Minimax llm endpoint."""
+
+ host: str
+ group_id: str
+ api_key: SecretStr
+ api_url: str
+
+ @root_validator(pre=True, allow_reuse=True)
+ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if "api_url" not in values:
+ host = values["host"]
+ group_id = values["group_id"]
+ api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
+ values["api_url"] = api_url
+ return values
+
+ def post(self, request: Any) -> Any:
+ headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
+ response = requests.post(self.api_url, headers=headers, json=request)
+ # TODO: error handling and automatic retries
+ if not response.ok:
+ raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+ if response.json()["base_resp"]["status_code"] > 0:
+ raise ValueError(
+ f"API {response.json()['base_resp']['status_code']}"
+ f" error: {response.json()['base_resp']['status_msg']}"
+ )
+ return response.json()["reply"]
+
+
+class MinimaxCommon(BaseModel):
+ """Common parameters for Minimax large language models."""
+
+ _client: _MinimaxEndpointClient
+ model: str = "abab5.5-chat"
+ """Model name to use."""
+ max_tokens: int = 256
+ """Denotes the number of tokens to predict per generation."""
+ temperature: float = 0.7
+ """A non-negative float that tunes the degree of randomness in generation."""
+ top_p: float = 0.95
+ """Total probability mass of tokens to consider at each step."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ minimax_api_host: Optional[str] = None
+ minimax_group_id: Optional[str] = None
+ minimax_api_key: Optional[SecretStr] = None
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["minimax_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
+ )
+ values["minimax_group_id"] = get_from_dict_or_env(
+ values, "minimax_group_id", "MINIMAX_GROUP_ID"
+ )
+ # Get custom api url from environment.
+ values["minimax_api_host"] = get_from_dict_or_env(
+ values,
+ "minimax_api_host",
+ "MINIMAX_API_HOST",
+ default="https://api.minimax.chat",
+ )
+ values["_client"] = _MinimaxEndpointClient(
+ host=values["minimax_api_host"],
+ api_key=values["minimax_api_key"],
+ group_id=values["minimax_group_id"],
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the Minimax API."""
+ return {
+ "model": self.model,
+ "tokens_to_generate": self.max_tokens,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ **self.model_kwargs,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "minimax"
+
+
+class Minimax(MinimaxCommon, LLM):
+    """Wrapper around Minimax large language models.
+
+    To use, you should have the environment variables
+    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key and
+    group id, or pass them as named parameters to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.llms.minimax import Minimax
+            minimax = Minimax(model="", minimax_api_key="my-api-key",
+                              minimax_group_id="my-group-id")
+ """
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+        r"""Call out to Minimax's completion endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = minimax("Tell me a joke.")
+ """
+ request = self._default_params
+ request["messages"] = [{"sender_type": "USER", "text": prompt}]
+ request.update(kwargs)
+ text = self._client.post(request)
+ if stop is not None:
+ # This is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+
+ return text
diff --git a/libs/community/langchain_community/llms/mlflow.py b/libs/community/langchain_community/llms/mlflow.py
new file mode 100644
index 00000000000..8c4c1118d9b
--- /dev/null
+++ b/libs/community/langchain_community/llms/mlflow.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional
+from urllib.parse import urlparse
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Parameters for MLflow"""
+
+ temperature: float = 0.0
+ n: int = 1
+ stop: Optional[List[str]] = None
+ max_tokens: Optional[int] = None
+
+
+class Mlflow(LLM):
+ """Wrapper around completions LLMs in MLflow.
+
+ To use, you should have the `mlflow[genai]` python package installed.
+ For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Mlflow
+
+ completions = Mlflow(
+ target_uri="http://localhost:5000",
+ endpoint="test",
+                temperature=0.1,
+ )
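+
+    When running inside Databricks, ``target_uri`` can simply be set to
+    ``databricks`` (the endpoint name below is illustrative):
+        .. code-block:: python
+
+            completions = Mlflow(
+                target_uri="databricks",
+                endpoint="completions",
+            )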
+ """
+
+ endpoint: str
+ """The endpoint to use."""
+ target_uri: str
+ """The target URI to use."""
+ temperature: float = 0.0
+ """The sampling temperature."""
+ n: int = 1
+ """The number of completion choices to generate."""
+ stop: Optional[List[str]] = None
+ """The stop sequence."""
+ max_tokens: Optional[int] = None
+ """The maximum number of tokens to generate."""
+ extra_params: Dict[str, Any] = Field(default_factory=dict)
+ """Any extra parameters to pass to the endpoint."""
+
+ _client: Any = PrivateAttr()
+
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
+ self._validate_uri()
+ try:
+ from mlflow.deployments import get_deploy_client
+
+ self._client = get_deploy_client(self.target_uri)
+ except ImportError as e:
+ raise ImportError(
+ "Failed to create the client. "
+ "Please run `pip install mlflow[genai]` to install "
+ "required dependencies."
+ ) from e
+
+ def _validate_uri(self) -> None:
+ if self.target_uri == "databricks":
+ return
+ allowed = ["http", "https", "databricks"]
+ if urlparse(self.target_uri).scheme not in allowed:
+ raise ValueError(
+ f"Invalid target URI: {self.target_uri}. "
+ f"The scheme must be one of {allowed}."
+ )
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ return {
+ "target_uri": self.target_uri,
+ "endpoint": self.endpoint,
+ "temperature": self.temperature,
+ "n": self.n,
+ "stop": self.stop,
+ "max_tokens": self.max_tokens,
+ "extra_params": self.extra_params,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return self._default_params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ data: Dict[str, Any] = {
+ "prompt": prompt,
+ "temperature": self.temperature,
+ "n": self.n,
+ **self.extra_params,
+ **kwargs,
+ }
+ if stop := self.stop or stop:
+ data["stop"] = stop
+ if self.max_tokens is not None:
+ data["max_tokens"] = self.max_tokens
+
+ resp = self._client.predict(endpoint=self.endpoint, inputs=data)
+ return resp["choices"][0]["text"]
+
+ @property
+ def _llm_type(self) -> str:
+ return "mlflow"
diff --git a/libs/community/langchain_community/llms/mlflow_ai_gateway.py b/libs/community/langchain_community/llms/mlflow_ai_gateway.py
new file mode 100644
index 00000000000..776307a6bd3
--- /dev/null
+++ b/libs/community/langchain_community/llms/mlflow_ai_gateway.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Parameters for the MLflow AI Gateway LLM."""
+
+ temperature: float = 0.0
+ candidate_count: int = 1
+ """The number of candidates to return."""
+ stop: Optional[List[str]] = None
+ max_tokens: Optional[int] = None
+
+
+class MlflowAIGateway(LLM):
+ """
+ Wrapper around completions LLMs in the MLflow AI Gateway.
+
+ To use, you should have the ``mlflow[gateway]`` python package installed.
+ For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import MlflowAIGateway
+
+ completions = MlflowAIGateway(
+ gateway_uri="",
+ route="",
+ params={
+ "temperature": 0.1
+ }
+ )
+ """
+
+ route: str
+ gateway_uri: Optional[str] = None
+ params: Optional[Params] = None
+
+ def __init__(self, **kwargs: Any):
+ warnings.warn(
+ "`MlflowAIGateway` is deprecated. Use `Mlflow` or `Databricks` instead.",
+ DeprecationWarning,
+ )
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ super().__init__(**kwargs)
+ if self.gateway_uri:
+ mlflow.gateway.set_gateway_uri(self.gateway_uri)
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ params: Dict[str, Any] = {
+ "gateway_uri": self.gateway_uri,
+ "route": self.route,
+ **(self.params.dict() if self.params else {}),
+ }
+ return params
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return self._default_params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ import mlflow.gateway
+ except ImportError as e:
+ raise ImportError(
+ "Could not import `mlflow.gateway` module. "
+ "Please install it with `pip install mlflow[gateway]`."
+ ) from e
+
+ data: Dict[str, Any] = {
+ "prompt": prompt,
+ **(self.params.dict() if self.params else {}),
+ }
+ if s := (stop or (self.params.stop if self.params else None)):
+ data["stop"] = s
+ resp = mlflow.gateway.query(self.route, data=data)
+ return resp["candidates"][0]["text"]
+
+ @property
+ def _llm_type(self) -> str:
+ return "mlflow-ai-gateway"
diff --git a/libs/community/langchain_community/llms/modal.py b/libs/community/langchain_community/llms/modal.py
new file mode 100644
index 00000000000..6ccac3c0d8e
--- /dev/null
+++ b/libs/community/langchain_community/llms/modal.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class Modal(LLM):
+ """Modal large language models.
+
+ To use, you should have the ``modal-client`` python package installed.
+
+ Any parameters that are valid to be passed to the call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Modal
+ modal = Modal(endpoint_url="")
+
+ """
+
+ endpoint_url: str = ""
+ """model endpoint to use"""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not
+ explicitly specified."""
+
+ class Config:
+ """Configuration for this pydantic config."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"endpoint_url": self.endpoint_url},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "modal"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Modal endpoint."""
+ params = self.model_kwargs or {}
+ params = {**params, **kwargs}
+ response = requests.post(
+ url=self.endpoint_url,
+ headers={
+ "Content-Type": "application/json",
+ },
+ json={"prompt": prompt, **params},
+ )
+        response_json = response.json()
+        if "prompt" not in response_json:
+            raise KeyError("LangChain requires 'prompt' key in response.")
+        text = response_json["prompt"]
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
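+
+
+if __name__ == "__main__":
+    # Illustrative sketch only; the endpoint URL is a placeholder. Constructor
+    # kwargs that are not declared fields (here `max_new_tokens`) are moved into
+    # `model_kwargs` by the `build_extra` validator (a warning is logged) and
+    # are forwarded to the web endpoint with every request.
+    llm = Modal(
+        endpoint_url="https://your-workspace--completion.modal.run",
+        max_new_tokens=64,
+    )
+    print(llm.model_kwargs)  # -> {'max_new_tokens': 64}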
diff --git a/libs/community/langchain_community/llms/mosaicml.py b/libs/community/langchain_community/llms/mosaicml.py
new file mode 100644
index 00000000000..b73e0d5e214
--- /dev/null
+++ b/libs/community/langchain_community/llms/mosaicml.py
@@ -0,0 +1,188 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+INSTRUCTION_KEY = "### Instruction:"
+RESPONSE_KEY = "### Response:"
+INTRO_BLURB = (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request."
+)
+PROMPT_FOR_GENERATION_FORMAT = """{intro}
+{instruction_key}
+{instruction}
+{response_key}
+""".format(
+ intro=INTRO_BLURB,
+ instruction_key=INSTRUCTION_KEY,
+ instruction="{instruction}",
+ response_key=RESPONSE_KEY,
+)
+
+
+class MosaicML(LLM):
+ """MosaicML LLM service.
+
+ To use, you should have the
+ environment variable ``MOSAICML_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import MosaicML
+ endpoint_url = (
+ "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
+ )
+ mosaic_llm = MosaicML(
+ endpoint_url=endpoint_url,
+ mosaicml_api_token="my-api-key"
+ )
+ """
+
+ endpoint_url: str = (
+ "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
+ )
+ """Endpoint URL to use."""
+ inject_instruction_format: bool = False
+ """Whether to inject the instruction format into the prompt."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+ retry_sleep: float = 1.0
+ """How long to try sleeping for if a rate limit is encountered"""
+
+ mosaicml_api_token: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ mosaicml_api_token = get_from_dict_or_env(
+ values, "mosaicml_api_token", "MOSAICML_API_TOKEN"
+ )
+ values["mosaicml_api_token"] = mosaicml_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"endpoint_url": self.endpoint_url},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "mosaic"
+
+ def _transform_prompt(self, prompt: str) -> str:
+ """Transform prompt."""
+ if self.inject_instruction_format:
+ prompt = PROMPT_FOR_GENERATION_FORMAT.format(
+ instruction=prompt,
+ )
+ return prompt
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ is_retry: bool = False,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to a MosaicML LLM inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = mosaic_llm("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+
+ prompt = self._transform_prompt(prompt)
+
+ payload = {"inputs": [prompt]}
+ payload.update(_model_kwargs)
+ payload.update(kwargs)
+
+ # HTTP headers for authorization
+ headers = {
+ "Authorization": f"{self.mosaicml_api_token}",
+ "Content-Type": "application/json",
+ }
+
+ # send request
+ try:
+ response = requests.post(self.endpoint_url, headers=headers, json=payload)
+ except requests.exceptions.RequestException as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ try:
+ if response.status_code == 429:
+ if not is_retry:
+ import time
+
+ time.sleep(self.retry_sleep)
+
+ return self._call(prompt, stop, run_manager, is_retry=True)
+
+ raise ValueError(
+ f"Error raised by inference API: rate limit exceeded.\nResponse: "
+ f"{response.text}"
+ )
+
+ parsed_response = response.json()
+
+ # The inference API has changed a couple of times, so we add some handling
+ # to be robust to multiple response formats.
+ if isinstance(parsed_response, dict):
+ output_keys = ["data", "output", "outputs"]
+ for key in output_keys:
+ if key in parsed_response:
+ output_item = parsed_response[key]
+ break
+ else:
+ raise ValueError(
+ f"No valid key ({', '.join(output_keys)}) in response:"
+ f" {parsed_response}"
+ )
+ if isinstance(output_item, list):
+ text = output_item[0]
+ else:
+ text = output_item
+ else:
+ raise ValueError(f"Unexpected response type: {parsed_response}")
+
+ # Older versions of the API include the input in the output response
+ if text.startswith(prompt):
+ text = text[len(prompt) :]
+
+ except requests.exceptions.JSONDecodeError as e:
+ raise ValueError(
+ f"Error raised by inference API: {e}.\nResponse: {response.text}"
+ )
+
+ # TODO: replace when MosaicML supports custom stop tokens natively
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
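+
+
+if __name__ == "__main__":
+    # Illustrative sketch only; the token is a placeholder. With
+    # `inject_instruction_format=True`, the raw prompt is wrapped in
+    # PROMPT_FOR_GENERATION_FORMAT before being sent to the endpoint.
+    llm = MosaicML(
+        mosaicml_api_token="my-api-key",
+        inject_instruction_format=True,
+    )
+    print(llm._transform_prompt("Tell me a joke."))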
diff --git a/libs/community/langchain_community/llms/nlpcloud.py b/libs/community/langchain_community/llms/nlpcloud.py
new file mode 100644
index 00000000000..bdff6404290
--- /dev/null
+++ b/libs/community/langchain_community/llms/nlpcloud.py
@@ -0,0 +1,145 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class NLPCloud(LLM):
+ """NLPCloud large language models.
+
+ To use, you should have the ``nlpcloud`` python package installed, and the
+ environment variable ``NLPCLOUD_API_KEY`` set with your API key.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import NLPCloud
+ nlpcloud = NLPCloud(model="finetuned-gpt-neox-20b")
+ """
+
+ client: Any #: :meta private:
+ model_name: str = "finetuned-gpt-neox-20b"
+ """Model name to use."""
+ gpu: bool = True
+ """Whether to use a GPU or not"""
+ lang: str = "en"
+ """Language to use (multilingual addon)"""
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+ max_length: int = 256
+ """The maximum number of tokens to generate in the completion."""
+ length_no_input: bool = True
+ """Whether min_length and max_length should include the length of the input."""
+ remove_input: bool = True
+ """Remove input text from API response"""
+ remove_end_sequence: bool = True
+ """Whether or not to remove the end sequence token."""
+ bad_words: List[str] = []
+ """List of tokens not allowed to be generated."""
+ top_p: int = 1
+ """Total probability mass of tokens to consider at each step."""
+ top_k: int = 50
+ """The number of highest probability tokens to keep for top-k filtering."""
+ repetition_penalty: float = 1.0
+ """Penalizes repeated tokens. 1.0 means no penalty."""
+ num_beams: int = 1
+ """Number of beams for beam search."""
+ num_return_sequences: int = 1
+ """How many completions to generate for each prompt."""
+
+ nlpcloud_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["nlpcloud_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "nlpcloud_api_key", "NLPCLOUD_API_KEY")
+ )
+ try:
+ import nlpcloud
+
+ values["client"] = nlpcloud.Client(
+ values["model_name"],
+ values["nlpcloud_api_key"].get_secret_value(),
+ gpu=values["gpu"],
+ lang=values["lang"],
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import nlpcloud python package. "
+ "Please install it with `pip install nlpcloud`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling NLPCloud API."""
+ return {
+ "temperature": self.temperature,
+ "max_length": self.max_length,
+ "length_no_input": self.length_no_input,
+ "remove_input": self.remove_input,
+ "remove_end_sequence": self.remove_end_sequence,
+ "bad_words": self.bad_words,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "repetition_penalty": self.repetition_penalty,
+ "num_beams": self.num_beams,
+ "num_return_sequences": self.num_return_sequences,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_name": self.model_name},
+ **{"gpu": self.gpu},
+ **{"lang": self.lang},
+ **self._default_params,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "nlpcloud"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to NLPCloud's create endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+            stop: A single stop sequence, passed as a list of length 1.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = nlpcloud("Tell me a joke.")
+ """
+ if stop and len(stop) > 1:
+ raise ValueError(
+ "NLPCloud only supports a single stop sequence per generation."
+ "Pass in a list of length 1."
+ )
+ elif stop and len(stop) == 1:
+ end_sequence = stop[0]
+ else:
+ end_sequence = None
+ params = {**self._default_params, **kwargs}
+ response = self.client.generation(prompt, end_sequence=end_sequence, **params)
+ return response["generated_text"]
diff --git a/libs/community/langchain_community/llms/octoai_endpoint.py b/libs/community/langchain_community/llms/octoai_endpoint.py
new file mode 100644
index 00000000000..a6002b8ae06
--- /dev/null
+++ b/libs/community/langchain_community/llms/octoai_endpoint.py
@@ -0,0 +1,151 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class OctoAIEndpoint(LLM):
+ """OctoAI LLM Endpoints.
+
+ OctoAIEndpoint is a class to interact with OctoAI
+ Compute Service large language model endpoints.
+
+ To use, you should have the ``octoai`` python package installed, and the
+ environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
+ OctoAIEndpoint(
+ octoai_api_token="octoai-api-key",
+ endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
+ model_kwargs={
+ "max_new_tokens": 200,
+ "temperature": 0.75,
+ "top_p": 0.95,
+ "repetition_penalty": 1,
+ "seed": None,
+ "stop": [],
+ },
+ )
+
+ from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
+ OctoAIEndpoint(
+ octoai_api_token="octoai-api-key",
+ endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
+ model_kwargs={
+ "model": "llama-2-7b-chat",
+ "messages": [
+ {
+ "role": "system",
+ "content": "Below is an instruction that describes a task.
+ Write a response that completes the request."
+ }
+ ],
+ "stream": False,
+ "max_tokens": 256
+ }
+ )
+
+ """
+
+ endpoint_url: Optional[str] = None
+ """Endpoint URL to use."""
+
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+
+ octoai_api_token: Optional[str] = None
+ """OCTOAI API Token"""
+
+ streaming: bool = False
+ """Whether to generate a stream of tokens asynchronously"""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ octoai_api_token = get_from_dict_or_env(
+ values, "octoai_api_token", "OCTOAI_API_TOKEN"
+ )
+ values["endpoint_url"] = get_from_dict_or_env(
+ values, "endpoint_url", "ENDPOINT_URL"
+ )
+
+ values["octoai_api_token"] = octoai_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"endpoint_url": self.endpoint_url},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "octoai_endpoint"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to OctoAI's inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ """
+ _model_kwargs = self.model_kwargs or {}
+
+ try:
+ # Initialize the OctoAI client
+ from octoai import client
+
+ octoai_client = client.Client(token=self.octoai_api_token)
+
+ if "model" in _model_kwargs:
+ parameter_payload = _model_kwargs
+ parameter_payload["messages"].append(
+ {"role": "user", "content": prompt}
+ )
+ # Send the request using the OctoAI client
+ output = octoai_client.infer(self.endpoint_url, parameter_payload)
+ text = output.get("choices")[0].get("message").get("content")
+ else:
+ # Prepare the payload JSON
+ parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+
+ # Send the request using the OctoAI client
+ resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
+ text = resp_json["generated_text"]
+
+ except Exception as e:
+ # Handle any errors raised by the inference endpoint
+ raise ValueError(f"Error raised by the inference endpoint: {e}") from e
+
+ if stop is not None:
+ # Apply stop tokens when making calls to OctoAI
+ text = enforce_stop_tokens(text, stop)
+
+ return text
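+
+
+# Note on the two payload shapes handled by `_call` above: when `model_kwargs`
+# contains a "model" key the endpoint is treated as an OpenAI-style chat
+# endpoint and the prompt is appended as a user message; otherwise the payload
+# is `{"inputs": prompt, "parameters": model_kwargs}` for a text-generation
+# endpoint. The two examples in the class docstring correspond to these cases.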
diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py
new file mode 100644
index 00000000000..3551ba446ef
--- /dev/null
+++ b/libs/community/langchain_community/llms/ollama.py
@@ -0,0 +1,273 @@
+import json
+from typing import Any, Dict, Iterator, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import Extra
+
+
+def _stream_response_to_generation_chunk(
+ stream_response: str,
+) -> GenerationChunk:
+ """Convert a stream response to a generation chunk."""
+ parsed_response = json.loads(stream_response)
+ generation_info = parsed_response if parsed_response.get("done") is True else None
+ return GenerationChunk(
+ text=parsed_response.get("response", ""), generation_info=generation_info
+ )
+
+
+class _OllamaCommon(BaseLanguageModel):
+ base_url: str = "http://localhost:11434"
+ """Base url the model is hosted under."""
+
+ model: str = "llama2"
+ """Model name to use."""
+
+ mirostat: Optional[int] = None
+ """Enable Mirostat sampling for controlling perplexity.
+ (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
+
+ mirostat_eta: Optional[float] = None
+ """Influences how quickly the algorithm responds to feedback
+ from the generated text. A lower learning rate will result in
+ slower adjustments, while a higher learning rate will make
+ the algorithm more responsive. (Default: 0.1)"""
+
+ mirostat_tau: Optional[float] = None
+ """Controls the balance between coherence and diversity
+ of the output. A lower value will result in more focused and
+ coherent text. (Default: 5.0)"""
+
+ num_ctx: Optional[int] = None
+ """Sets the size of the context window used to generate the
+ next token. (Default: 2048) """
+
+ num_gpu: Optional[int] = None
+ """The number of GPUs to use. On macOS it defaults to 1 to
+ enable metal support, 0 to disable."""
+
+ num_thread: Optional[int] = None
+ """Sets the number of threads to use during computation.
+ By default, Ollama will detect this for optimal performance.
+ It is recommended to set this value to the number of physical
+ CPU cores your system has (as opposed to the logical number of cores)."""
+
+ repeat_last_n: Optional[int] = None
+ """Sets how far back for the model to look back to prevent
+ repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
+
+ repeat_penalty: Optional[float] = None
+ """Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
+ will penalize repetitions more strongly, while a lower value (e.g., 0.9)
+ will be more lenient. (Default: 1.1)"""
+
+ temperature: Optional[float] = None
+ """The temperature of the model. Increasing the temperature will
+ make the model answer more creatively. (Default: 0.8)"""
+
+ stop: Optional[List[str]] = None
+ """Sets the stop tokens to use."""
+
+ tfs_z: Optional[float] = None
+ """Tail free sampling is used to reduce the impact of less probable
+ tokens from the output. A higher value (e.g., 2.0) will reduce the
+ impact more, while a value of 1.0 disables this setting. (default: 1)"""
+
+ top_k: Optional[int] = None
+ """Reduces the probability of generating nonsense. A higher value (e.g. 100)
+ will give more diverse answers, while a lower value (e.g. 10)
+ will be more conservative. (Default: 40)"""
+
+    top_p: Optional[float] = None
+ """Works together with top-k. A higher value (e.g., 0.95) will lead
+ to more diverse text, while a lower value (e.g., 0.5) will
+ generate more focused and conservative text. (Default: 0.9)"""
+
+ system: Optional[str] = None
+ """system prompt (overrides what is defined in the Modelfile)"""
+
+ template: Optional[str] = None
+ """full prompt or prompt template (overrides what is defined in the Modelfile)"""
+
+ format: Optional[str] = None
+ """Specify the format of the output (e.g., json)"""
+
+ timeout: Optional[int] = None
+ """Timeout for the request stream"""
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Ollama."""
+ return {
+ "model": self.model,
+ "format": self.format,
+ "options": {
+ "mirostat": self.mirostat,
+ "mirostat_eta": self.mirostat_eta,
+ "mirostat_tau": self.mirostat_tau,
+ "num_ctx": self.num_ctx,
+ "num_gpu": self.num_gpu,
+ "num_thread": self.num_thread,
+ "repeat_last_n": self.repeat_last_n,
+ "repeat_penalty": self.repeat_penalty,
+ "temperature": self.temperature,
+ "stop": self.stop,
+ "tfs_z": self.tfs_z,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ },
+ "system": self.system,
+ "template": self.template,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model, "format": self.format}, **self._default_params}
+
+ def _create_stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Iterator[str]:
+ if self.stop is not None and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop is not None:
+ stop = self.stop
+ elif stop is None:
+ stop = []
+
+ params = self._default_params
+
+ if "model" in kwargs:
+ params["model"] = kwargs["model"]
+
+ if "options" in kwargs:
+ params["options"] = kwargs["options"]
+ else:
+ params["options"] = {
+ **params["options"],
+ "stop": stop,
+ **kwargs,
+ }
+
+ response = requests.post(
+ url=f"{self.base_url}/api/generate/",
+ headers={"Content-Type": "application/json"},
+ json={"prompt": prompt, **params},
+ stream=True,
+ timeout=self.timeout,
+ )
+ response.encoding = "utf-8"
+ if response.status_code != 200:
+ optional_detail = response.json().get("error")
+ raise ValueError(
+ f"Ollama call failed with status code {response.status_code}."
+ f" Details: {optional_detail}"
+ )
+ return response.iter_lines(decode_unicode=True)
+
+ def _stream_with_aggregation(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ verbose: bool = False,
+ **kwargs: Any,
+ ) -> GenerationChunk:
+ final_chunk: Optional[GenerationChunk] = None
+ for stream_resp in self._create_stream(prompt, stop, **kwargs):
+ if stream_resp:
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ if final_chunk is None:
+ final_chunk = chunk
+ else:
+ final_chunk += chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ chunk.text,
+ verbose=verbose,
+ )
+ if final_chunk is None:
+ raise ValueError("No data received from Ollama stream.")
+
+ return final_chunk
+
+
+class Ollama(BaseLLM, _OllamaCommon):
+ """Ollama locally runs large language models.
+
+ To use, follow the instructions at https://ollama.ai/.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Ollama
+ ollama = Ollama(model="llama2")
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "ollama-llm"
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call out to Ollama's generate endpoint.
+
+ Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The LLMResult generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = ollama("Tell me a joke.")
+ """
+ # TODO: add caching here.
+ generations = []
+ for prompt in prompts:
+ final_chunk = super()._stream_with_aggregation(
+ prompt,
+ stop=stop,
+ run_manager=run_manager,
+ verbose=self.verbose,
+ **kwargs,
+ )
+ generations.append([final_chunk])
+ return LLMResult(generations=generations)
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ for stream_resp in self._create_stream(prompt, stop, **kwargs):
+ if stream_resp:
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ chunk.text,
+ verbose=self.verbose,
+ )
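+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (no running Ollama server needed): each line of
+    # the /api/generate stream is a JSON object carrying a "response" fragment;
+    # the final line has "done": true and its full payload (timings, eval
+    # counts, context) is surfaced as that chunk's generation_info.
+    sample_lines = [
+        '{"response": "Hello", "done": false}',
+        '{"response": "!", "done": true, "eval_count": 2}',
+    ]
+    chunks = [_stream_response_to_generation_chunk(line) for line in sample_lines]
+    print("".join(chunk.text for chunk in chunks))  # -> Hello!
+    print(chunks[-1].generation_info)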
diff --git a/libs/community/langchain_community/llms/opaqueprompts.py b/libs/community/langchain_community/llms/opaqueprompts.py
new file mode 100644
index 00000000000..34f14c515c1
--- /dev/null
+++ b/libs/community/langchain_community/llms/opaqueprompts.py
@@ -0,0 +1,116 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class OpaquePrompts(LLM):
+ """An LLM wrapper that uses OpaquePrompts to sanitize prompts.
+
+ Wraps another LLM and sanitizes prompts before passing it to the LLM, then
+ de-sanitizes the response.
+
+ To use, you should have the ``opaqueprompts`` python package installed,
+ and the environment variable ``OPAQUEPROMPTS_API_KEY`` set with
+ your API key, or pass it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import OpaquePrompts
+ from langchain_community.chat_models import ChatOpenAI
+
+ op_llm = OpaquePrompts(base_llm=ChatOpenAI())
+ """
+
+ base_llm: BaseLanguageModel
+ """The base LLM to use."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validates that the OpaquePrompts API key and the Python package exist."""
+ try:
+ import opaqueprompts as op
+ except ImportError:
+ raise ImportError(
+ "Could not import the `opaqueprompts` Python package, "
+ "please install it with `pip install opaqueprompts`."
+ )
+ if op.__package__ is None:
+ raise ValueError(
+ "Could not properly import `opaqueprompts`, "
+ "opaqueprompts.__package__ is None."
+ )
+
+ api_key = get_from_dict_or_env(
+ values, "opaqueprompts_api_key", "OPAQUEPROMPTS_API_KEY", default=""
+ )
+ if not api_key:
+ raise ValueError(
+ "Could not find OPAQUEPROMPTS_API_KEY in the environment. "
+ "Please set it to your OpaquePrompts API key."
+ "You can get it by creating an account on the OpaquePrompts website: "
+ "https://opaqueprompts.opaque.co/ ."
+ )
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call base LLM with sanitization before and de-sanitization after.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = op_llm("Tell me a joke.")
+ """
+ import opaqueprompts as op
+
+ _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
+
+ # sanitize the prompt by replacing the sensitive information with a placeholder
+ sanitize_response: op.SanitizeResponse = op.sanitize([prompt])
+ sanitized_prompt_value_str = sanitize_response.sanitized_texts[0]
+
+ # TODO: Add in callbacks once child runs for LLMs are supported by LangSmith.
+ # call the LLM with the sanitized prompt and get the response
+ llm_response = self.base_llm.predict(
+ sanitized_prompt_value_str,
+ stop=stop,
+ )
+
+ # desanitize the response by restoring the original sensitive information
+ desanitize_response: op.DesanitizeResponse = op.desanitize(
+ llm_response,
+ secure_context=sanitize_response.secure_context,
+ )
+ return desanitize_response.desanitized_text
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of LLM.
+
+ This is an override of the base class method.
+ """
+ return "opaqueprompts"
diff --git a/libs/community/langchain_community/llms/openai.py b/libs/community/langchain_community/llms/openai.py
new file mode 100644
index 00000000000..3e325bbb2c0
--- /dev/null
+++ b/libs/community/langchain_community/llms/openai.py
@@ -0,0 +1,1226 @@
+from __future__ import annotations
+
+import logging
+import os
+import sys
+import warnings
+from typing import (
+ AbstractSet,
+ Any,
+ AsyncIterator,
+ Callable,
+ Collection,
+ Dict,
+ Iterator,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
+from langchain_core.utils.utils import build_extra_kwargs
+
+from langchain_community.utils.openai import is_openai_v1
+
+logger = logging.getLogger(__name__)
+
+
+def update_token_usage(
+ keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
+) -> None:
+ """Update token usage."""
+ _keys_to_use = keys.intersection(response["usage"])
+ for _key in _keys_to_use:
+ if _key not in token_usage:
+ token_usage[_key] = response["usage"][_key]
+ else:
+ token_usage[_key] += response["usage"][_key]
+
+
+def _stream_response_to_generation_chunk(
+ stream_response: Dict[str, Any],
+) -> GenerationChunk:
+ """Convert a stream response to a generation chunk."""
+ if not stream_response["choices"]:
+ return GenerationChunk(text="")
+ return GenerationChunk(
+ text=stream_response["choices"][0]["text"],
+ generation_info=dict(
+ finish_reason=stream_response["choices"][0].get("finish_reason", None),
+ logprobs=stream_response["choices"][0].get("logprobs", None),
+ ),
+ )
+
+
+def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
+ """Update response from the stream response."""
+ response["choices"][0]["text"] += stream_response["choices"][0]["text"]
+ response["choices"][0]["finish_reason"] = stream_response["choices"][0].get(
+ "finish_reason", None
+ )
+ response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
+
+
+def _streaming_response_template() -> Dict[str, Any]:
+ return {
+ "choices": [
+ {
+ "text": "",
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ]
+ }
+
+
+def _create_retry_decorator(
+ llm: Union[BaseOpenAI, OpenAIChat],
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ import openai
+
+ errors = [
+ openai.error.Timeout,
+ openai.error.APIError,
+ openai.error.APIConnectionError,
+ openai.error.RateLimitError,
+ openai.error.ServiceUnavailableError,
+ ]
+ return create_base_retry_decorator(
+ error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+ )
+
+
+def completion_with_retry(
+ llm: Union[BaseOpenAI, OpenAIChat],
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ if is_openai_v1():
+ return llm.client.create(**kwargs)
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return llm.client.create(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(
+ llm: Union[BaseOpenAI, OpenAIChat],
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the async completion call."""
+ if is_openai_v1():
+ return await llm.async_client.create(**kwargs)
+
+ retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ async def _completion_with_retry(**kwargs: Any) -> Any:
+ # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+ return await llm.client.acreate(**kwargs)
+
+ return await _completion_with_retry(**kwargs)
+
+
+class BaseOpenAI(BaseLLM):
+ """Base OpenAI large language model class."""
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"openai_api_key": "OPENAI_API_KEY"}
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "openai"]
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ attributes: Dict[str, Any] = {}
+ if self.openai_api_base:
+ attributes["openai_api_base"] = self.openai_api_base
+
+ if self.openai_organization:
+ attributes["openai_organization"] = self.openai_organization
+
+ if self.openai_proxy:
+ attributes["openai_proxy"] = self.openai_proxy
+
+ return attributes
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return True
+
+ client: Any = Field(default=None, exclude=True) #: :meta private:
+ async_client: Any = Field(default=None, exclude=True) #: :meta private:
+ model_name: str = Field(default="text-davinci-003", alias="model")
+ """Model name to use."""
+ temperature: float = 0.7
+ """What sampling temperature to use."""
+ max_tokens: int = 256
+ """The maximum number of tokens to generate in the completion.
+ -1 returns as many tokens as possible given the prompt and
+ the models maximal context size."""
+ top_p: float = 1
+ """Total probability mass of tokens to consider at each step."""
+ frequency_penalty: float = 0
+ """Penalizes repeated tokens according to frequency."""
+ presence_penalty: float = 0
+ """Penalizes repeated tokens."""
+ n: int = 1
+ """How many completions to generate for each prompt."""
+ best_of: int = 1
+ """Generates best_of completions server-side and returns the "best"."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ # When updating this to use a SecretStr
+ # Check for classes that derive from this class (as some of them
+ # may assume openai_api_key is a str)
+ openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+ openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+ """Base URL path for API requests, leave blank if not using a proxy or service
+ emulator."""
+ openai_organization: Optional[str] = Field(default=None, alias="organization")
+ """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+ # to support explicit proxy for OpenAI
+ openai_proxy: Optional[str] = None
+ batch_size: int = 20
+ """Batch size to use when passing multiple documents to generate."""
+ request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
+ default=None, alias="timeout"
+ )
+ """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
+ None."""
+ logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
+ """Adjust the probability of specific tokens being generated."""
+ max_retries: int = 2
+ """Maximum number of retries to make when generating."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
+ """Set of special tokens that are allowedγ"""
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all"
+ """Set of special tokens that are not allowedγ"""
+ tiktoken_model_name: Optional[str] = None
+ """The model name to pass to tiktoken when using this class.
+ Tiktoken is used to count the number of tokens in documents to constrain
+ them to be under a certain limit. By default, when set to None, this will
+    be the same as the model name. However, there are some cases
+    where you may want to use this class with a model name not
+    supported by tiktoken. This can include when using Azure OpenAI or
+ when using one of the many model providers that expose an OpenAI-like
+ API but with different models. In those cases, in order to avoid erroring
+ when tiktoken is called, you can specify a model name to use here."""
+ default_headers: Union[Mapping[str, str], None] = None
+ default_query: Union[Mapping[str, object], None] = None
+ # Configure a custom httpx client. See the
+ # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
+ http_client: Union[Any, None] = None
+ """Optional httpx.Client."""
+
+ def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
+ """Initialize the OpenAI object."""
+ model_name = data.get("model_name", "")
+ if (
+ model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4")
+ ) and "-instruct" not in model_name:
+ warnings.warn(
+ "You are trying to use a chat model. This way of initializing it is "
+ "no longer supported. Instead, please use: "
+ "`from langchain_community.chat_models import ChatOpenAI`"
+ )
+ return OpenAIChat(**data)
+ return super().__new__(cls)
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ values["model_kwargs"] = build_extra_kwargs(
+ extra, values, all_required_field_names
+ )
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ if values["n"] < 1:
+ raise ValueError("n must be at least 1.")
+ if values["streaming"] and values["n"] > 1:
+ raise ValueError("Cannot stream results when n > 1.")
+ if values["streaming"] and values["best_of"] > 1:
+ raise ValueError("Cannot stream results when best_of > 1.")
+
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ try:
+ import openai
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+
+ if is_openai_v1():
+ client_params = {
+ "api_key": values["openai_api_key"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+ if not values.get("client"):
+ values["client"] = openai.OpenAI(**client_params).completions
+ if not values.get("async_client"):
+ values["async_client"] = openai.AsyncOpenAI(**client_params).completions
+ elif not values.get("client"):
+ values["client"] = openai.Completion
+ else:
+ pass
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ normal_params: Dict[str, Any] = {
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "n": self.n,
+ "logit_bias": self.logit_bias,
+ }
+
+ if self.max_tokens is not None:
+ normal_params["max_tokens"] = self.max_tokens
+ if self.request_timeout is not None and not is_openai_v1():
+ normal_params["request_timeout"] = self.request_timeout
+
+ # Azure gpt-35-turbo doesn't support best_of
+ # don't specify best_of if it is 1
+ if self.best_of > 1:
+ normal_params["best_of"] = self.best_of
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = {**self._invocation_params, **kwargs, "stream": True}
+ self.get_sub_prompts(params, [prompt], stop) # this mutates params
+ for stream_resp in completion_with_retry(
+ self, prompt=prompt, run_manager=run_manager, **params
+ ):
+ if not isinstance(stream_resp, dict):
+ stream_resp = stream_resp.dict()
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ chunk.text,
+ chunk=chunk,
+ verbose=self.verbose,
+ logprobs=chunk.generation_info["logprobs"]
+ if chunk.generation_info
+ else None,
+ )
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ params = {**self._invocation_params, **kwargs, "stream": True}
+ self.get_sub_prompts(params, [prompt], stop) # this mutates params
+ async for stream_resp in await acompletion_with_retry(
+ self, prompt=prompt, run_manager=run_manager, **params
+ ):
+ if not isinstance(stream_resp, dict):
+ stream_resp = stream_resp.dict()
+ chunk = _stream_response_to_generation_chunk(stream_resp)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(
+ chunk.text,
+ chunk=chunk,
+ verbose=self.verbose,
+ logprobs=chunk.generation_info["logprobs"]
+ if chunk.generation_info
+ else None,
+ )
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call out to OpenAI's endpoint with k unique prompts.
+
+ Args:
+ prompts: The prompts to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The full LLM output.
+
+ Example:
+ .. code-block:: python
+
+ response = openai.generate(["Tell me a joke."])
+ """
+ # TODO: write a unit test for this
+ params = self._invocation_params
+ params = {**params, **kwargs}
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
+ choices = []
+ token_usage: Dict[str, int] = {}
+ # Get the token usage from the response.
+ # Includes prompt, completion, and total tokens used.
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+ system_fingerprint: Optional[str] = None
+ for _prompts in sub_prompts:
+ if self.streaming:
+ if len(_prompts) > 1:
+ raise ValueError("Cannot stream results with multiple prompts.")
+
+ generation: Optional[GenerationChunk] = None
+ for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ choices.append(
+ {
+ "text": generation.text,
+ "finish_reason": generation.generation_info.get("finish_reason")
+ if generation.generation_info
+ else None,
+ "logprobs": generation.generation_info.get("logprobs")
+ if generation.generation_info
+ else None,
+ }
+ )
+ else:
+ response = completion_with_retry(
+ self, prompt=_prompts, run_manager=run_manager, **params
+ )
+ if not isinstance(response, dict):
+                    # V1 client returns the response in a Pydantic object instead of
+ # dict. For the transition period, we deep convert it to dict.
+ response = response.dict()
+
+ choices.extend(response["choices"])
+ update_token_usage(_keys, response, token_usage)
+ if not system_fingerprint:
+ system_fingerprint = response.get("system_fingerprint")
+ return self.create_llm_result(
+ choices,
+ prompts,
+ params,
+ token_usage,
+ system_fingerprint=system_fingerprint,
+ )
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call out to OpenAI's endpoint async with k unique prompts."""
+ params = self._invocation_params
+ params = {**params, **kwargs}
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
+ choices = []
+ token_usage: Dict[str, int] = {}
+ # Get the token usage from the response.
+ # Includes prompt, completion, and total tokens used.
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+ system_fingerprint: Optional[str] = None
+ for _prompts in sub_prompts:
+ if self.streaming:
+ if len(_prompts) > 1:
+ raise ValueError("Cannot stream results with multiple prompts.")
+
+ generation: Optional[GenerationChunk] = None
+ async for chunk in self._astream(
+ _prompts[0], stop, run_manager, **kwargs
+ ):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ choices.append(
+ {
+ "text": generation.text,
+ "finish_reason": generation.generation_info.get("finish_reason")
+ if generation.generation_info
+ else None,
+ "logprobs": generation.generation_info.get("logprobs")
+ if generation.generation_info
+ else None,
+ }
+ )
+ else:
+ response = await acompletion_with_retry(
+ self, prompt=_prompts, run_manager=run_manager, **params
+ )
+ if not isinstance(response, dict):
+ response = response.dict()
+                choices.extend(response["choices"])
+                update_token_usage(_keys, response, token_usage)
+                if not system_fingerprint:
+                    system_fingerprint = response.get("system_fingerprint")
+ return self.create_llm_result(
+ choices,
+ prompts,
+ params,
+ token_usage,
+ system_fingerprint=system_fingerprint,
+ )
+
+ def get_sub_prompts(
+ self,
+ params: Dict[str, Any],
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ ) -> List[List[str]]:
+ """Get the sub prompts for llm call."""
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ if params["max_tokens"] == -1:
+ if len(prompts) != 1:
+ raise ValueError(
+ "max_tokens set to -1 not supported for multiple inputs."
+ )
+ params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
+ sub_prompts = [
+ prompts[i : i + self.batch_size]
+ for i in range(0, len(prompts), self.batch_size)
+ ]
+ return sub_prompts
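+    # For example, with the default batch_size of 20, a call with 45 prompts is
+    # split into sub-batches of 20, 20 and 5, and each sub-batch is sent to the
+    # API in a single completion request.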
+
+ def create_llm_result(
+ self,
+ choices: Any,
+ prompts: List[str],
+ params: Dict[str, Any],
+ token_usage: Dict[str, int],
+ *,
+ system_fingerprint: Optional[str] = None,
+ ) -> LLMResult:
+ """Create the LLMResult from the choices and prompts."""
+ generations = []
+ n = params.get("n", self.n)
+ for i, _ in enumerate(prompts):
+ sub_choices = choices[i * n : (i + 1) * n]
+ generations.append(
+ [
+ Generation(
+ text=choice["text"],
+ generation_info=dict(
+ finish_reason=choice.get("finish_reason"),
+ logprobs=choice.get("logprobs"),
+ ),
+ )
+ for choice in sub_choices
+ ]
+ )
+ llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+ if system_fingerprint:
+ llm_output["system_fingerprint"] = system_fingerprint
+ return LLMResult(generations=generations, llm_output=llm_output)
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model."""
+ openai_creds: Dict[str, Any] = {}
+ if not is_openai_v1():
+ openai_creds.update(
+ {
+ "api_key": self.openai_api_key,
+ "api_base": self.openai_api_base,
+ "organization": self.openai_organization,
+ }
+ )
+ if self.openai_proxy:
+ import openai
+
+ openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
+ return {**openai_creds, **self._default_params}
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "openai"
+
+ def get_token_ids(self, text: str) -> List[int]:
+ """Get the token IDs using the tiktoken package."""
+ # tiktoken NOT supported for Python < 3.8
+ if sys.version_info[1] < 8:
+            return super().get_token_ids(text)
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to calculate get_num_tokens. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ model_name = self.tiktoken_model_name or self.model_name
+ try:
+ enc = tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ enc = tiktoken.get_encoding(model)
+
+ return enc.encode(
+ text,
+ allowed_special=self.allowed_special,
+ disallowed_special=self.disallowed_special,
+ )
+
+ @staticmethod
+ def modelname_to_contextsize(modelname: str) -> int:
+ """Calculate the maximum number of tokens possible to generate for a model.
+
+ Args:
+ modelname: The modelname we want to know the context size for.
+
+ Returns:
+ The maximum context size
+
+ Example:
+ .. code-block:: python
+
+ max_tokens = openai.modelname_to_contextsize("text-davinci-003")
+ """
+ model_token_mapping = {
+ "gpt-4": 8192,
+ "gpt-4-0314": 8192,
+ "gpt-4-0613": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-32k-0314": 32768,
+ "gpt-4-32k-0613": 32768,
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0301": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16385,
+ "gpt-3.5-turbo-16k-0613": 16385,
+ "gpt-3.5-turbo-instruct": 4096,
+ "text-ada-001": 2049,
+ "ada": 2049,
+ "text-babbage-001": 2040,
+ "babbage": 2049,
+ "text-curie-001": 2049,
+ "curie": 2049,
+ "davinci": 2049,
+ "text-davinci-003": 4097,
+ "text-davinci-002": 4097,
+ "code-davinci-002": 8001,
+ "code-davinci-001": 8001,
+ "code-cushman-002": 2048,
+ "code-cushman-001": 2048,
+ }
+
+ # handling finetuned models
+ if "ft-" in modelname:
+ modelname = modelname.split(":")[0]
+
+ context_size = model_token_mapping.get(modelname, None)
+
+ if context_size is None:
+ raise ValueError(
+ f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
+ "Known models are: " + ", ".join(model_token_mapping.keys())
+ )
+
+ return context_size
+
+ @property
+ def max_context_size(self) -> int:
+ """Get max context size for this model."""
+ return self.modelname_to_contextsize(self.model_name)
+
+ def max_tokens_for_prompt(self, prompt: str) -> int:
+ """Calculate the maximum number of tokens possible to generate for a prompt.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ The maximum number of tokens to generate for a prompt.
+
+ Example:
+ .. code-block:: python
+
+ max_tokens = openai.max_token_for_prompt("Tell me a joke.")
+ """
+ num_tokens = self.get_num_tokens(prompt)
+ return self.max_context_size - num_tokens
+
+
+class OpenAI(BaseOpenAI):
+ """OpenAI large language models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``OPENAI_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import OpenAI
+ openai = OpenAI(model_name="text-davinci-003")
+ """
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "openai"]
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ return {**{"model": self.model_name}, **super()._invocation_params}
+
+
+class AzureOpenAI(BaseOpenAI):
+ """Azure-specific OpenAI large language models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``OPENAI_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import AzureOpenAI
+ openai = AzureOpenAI(model_name="text-davinci-003")
+ """
+
+ azure_endpoint: Union[str, None] = None
+ """Your Azure endpoint, including the resource.
+
+ Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
+
+ Example: `https://example-resource.azure.openai.com/`
+ """
+ deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
+ """A model deployment.
+
+ If given sets the base client URL to include `/deployments/{azure_deployment}`.
+ Note: this means you won't be able to use non-deployment endpoints.
+ """
+ openai_api_version: str = Field(default="", alias="api_version")
+ """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
+ openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
+ azure_ad_token: Union[str, None] = None
+ """Your Azure Active Directory token.
+
+ Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
+
+ For more:
+ https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
+ """ # noqa: E501
+ azure_ad_token_provider: Union[Callable[[], str], None] = None
+ """A function that returns an Azure Active Directory token.
+
+ Will be invoked on every request.
+ """
+ openai_api_type: str = ""
+ """Legacy, for openai<1.0.0 support."""
+ validate_base_url: bool = True
+ """For backwards compatibility. If legacy val openai_api_base is passed in, try to
+ infer if it is a base_url or azure_endpoint and update accordingly.
+ """
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "openai"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ if values["n"] < 1:
+ raise ValueError("n must be at least 1.")
+ if values["streaming"] and values["n"] > 1:
+ raise ValueError("Cannot stream results when n > 1.")
+ if values["streaming"] and values["best_of"] > 1:
+ raise ValueError("Cannot stream results when best_of > 1.")
+
+ # Check OPENAI_KEY for backwards compatibility.
+ # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
+ # other forms of azure credentials.
+ values["openai_api_key"] = (
+ values["openai_api_key"]
+ or os.getenv("AZURE_OPENAI_API_KEY")
+ or os.getenv("OPENAI_API_KEY")
+ )
+
+ values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
+ "AZURE_OPENAI_ENDPOINT"
+ )
+ values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
+ "AZURE_OPENAI_AD_TOKEN"
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ )
+ values["openai_api_version"] = values["openai_api_version"] or os.getenv(
+ "OPENAI_API_VERSION"
+ )
+ values["openai_api_type"] = get_from_dict_or_env(
+ values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
+ )
+ try:
+ import openai
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ if is_openai_v1():
+ # For backwards compatibility. Before openai v1, no distinction was made
+ # between azure_endpoint and base_url (openai_api_base).
+ openai_api_base = values["openai_api_base"]
+ if openai_api_base and values["validate_base_url"]:
+ if "/openai" not in openai_api_base:
+ values["openai_api_base"] = (
+ values["openai_api_base"].rstrip("/") + "/openai"
+ )
+ warnings.warn(
+ "As of openai>=1.0.0, Azure endpoints should be specified via "
+ f"the `azure_endpoint` param not `openai_api_base` "
+ f"(or alias `base_url`). Updating `openai_api_base` from "
+ f"{openai_api_base} to {values['openai_api_base']}."
+ )
+ if values["deployment_name"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `deployment_name` (or alias "
+ "`azure_deployment`) is specified then "
+ "`openai_api_base` (or alias `base_url`) should not be. "
+ "Instead use `deployment_name` (or alias `azure_deployment`) "
+ "and `azure_endpoint`."
+ )
+ if values["deployment_name"] not in values["openai_api_base"]:
+ warnings.warn(
+ "As of openai>=1.0.0, if `openai_api_base` "
+ "(or alias `base_url`) is specified it is expected to be "
+ "of the form "
+ "https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
+ f"Updating {openai_api_base} to "
+ f"{values['openai_api_base']}."
+ )
+ values["openai_api_base"] += (
+ "/deployments/" + values["deployment_name"]
+ )
+ values["deployment_name"] = None
+ client_params = {
+ "api_version": values["openai_api_version"],
+ "azure_endpoint": values["azure_endpoint"],
+ "azure_deployment": values["deployment_name"],
+ "api_key": values["openai_api_key"],
+ "azure_ad_token": values["azure_ad_token"],
+ "azure_ad_token_provider": values["azure_ad_token_provider"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+ values["client"] = openai.AzureOpenAI(**client_params).completions
+ values["async_client"] = openai.AsyncAzureOpenAI(
+ **client_params
+ ).completions
+
+ else:
+ values["client"] = openai.Completion
+
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return {
+ **{"deployment_name": self.deployment_name},
+ **super()._identifying_params,
+ }
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ if is_openai_v1():
+ openai_params = {"model": self.deployment_name}
+ else:
+ openai_params = {
+ "engine": self.deployment_name,
+ "api_type": self.openai_api_type,
+ "api_version": self.openai_api_version,
+ }
+ return {**openai_params, **super()._invocation_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "azure"
+
+ @property
+ def lc_attributes(self) -> Dict[str, Any]:
+ return {
+ "openai_api_type": self.openai_api_type,
+ "openai_api_version": self.openai_api_version,
+ }
+
+
+class OpenAIChat(BaseLLM):
+ """OpenAI Chat large language models.
+
+ To use, you should have the ``openai`` python package installed, and the
+ environment variable ``OPENAI_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the openai.create call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import OpenAIChat
+ openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
+ """
+
+ client: Any = Field(default=None, exclude=True) #: :meta private:
+ async_client: Any = Field(default=None, exclude=True) #: :meta private:
+ model_name: str = "gpt-3.5-turbo"
+ """Model name to use."""
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not explicitly specified."""
+ # When updating this to use a SecretStr
+ # Check for classes that derive from this class (as some of them
+ # may assume openai_api_key is a str)
+ openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+ openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+ """Base URL path for API requests, leave blank if not using a proxy or service
+ emulator."""
+ # to support explicit proxy for OpenAI
+ openai_proxy: Optional[str] = None
+ max_retries: int = 6
+ """Maximum number of retries to make when generating."""
+ prefix_messages: List = Field(default_factory=list)
+ """Series of messages for Chat input."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
+    """Set of special tokens that are allowed."""
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all"
+    """Set of special tokens that are not allowed."""
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ openai_api_key = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ openai_api_base = get_from_dict_or_env(
+ values,
+ "openai_api_base",
+ "OPENAI_API_BASE",
+ default="",
+ )
+ openai_proxy = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+ openai_organization = get_from_dict_or_env(
+ values, "openai_organization", "OPENAI_ORGANIZATION", default=""
+ )
+ try:
+ import openai
+
+ openai.api_key = openai_api_key
+ if openai_api_base:
+ openai.api_base = openai_api_base
+ if openai_organization:
+ openai.organization = openai_organization
+ if openai_proxy:
+ openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+ try:
+ values["client"] = openai.ChatCompletion
+ except AttributeError:
+ raise ValueError(
+ "`openai` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the openai package. Try upgrading it "
+ "with `pip install --upgrade openai`."
+ )
+ warnings.warn(
+ "You are trying to use a chat model. This way of initializing it is "
+ "no longer supported. Instead, please use: "
+ "`from langchain_community.chat_models import ChatOpenAI`"
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ return self.model_kwargs
+
+ def _get_chat_params(
+ self, prompts: List[str], stop: Optional[List[str]] = None
+ ) -> Tuple:
+ if len(prompts) > 1:
+ raise ValueError(
+ f"OpenAIChat currently only supports single prompt, got {prompts}"
+ )
+ messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
+ params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
+ if stop is not None:
+ if "stop" in params:
+ raise ValueError("`stop` found in both the input and default params.")
+ params["stop"] = stop
+ if params.get("max_tokens") == -1:
+ # for ChatGPT api, omitting max_tokens is equivalent to having no limit
+ del params["max_tokens"]
+ return messages, params
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ messages, params = self._get_chat_params([prompt], stop)
+ params = {**params, **kwargs, "stream": True}
+ for stream_resp in completion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ ):
+ if not isinstance(stream_resp, dict):
+ stream_resp = stream_resp.dict()
+ token = stream_resp["choices"][0]["delta"].get("content", "")
+ chunk = GenerationChunk(text=token)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(token, chunk=chunk)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ messages, params = self._get_chat_params([prompt], stop)
+ params = {**params, **kwargs, "stream": True}
+ async for stream_resp in await acompletion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ ):
+ if not isinstance(stream_resp, dict):
+ stream_resp = stream_resp.dict()
+ token = stream_resp["choices"][0]["delta"].get("content", "")
+ chunk = GenerationChunk(text=token)
+ yield chunk
+ if run_manager:
+ await run_manager.on_llm_new_token(token, chunk=chunk)
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ if self.streaming:
+ generation: Optional[GenerationChunk] = None
+ for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ return LLMResult(generations=[[generation]])
+
+ messages, params = self._get_chat_params(prompts, stop)
+ params = {**params, **kwargs}
+ full_response = completion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ )
+ if not isinstance(full_response, dict):
+ full_response = full_response.dict()
+ llm_output = {
+ "token_usage": full_response["usage"],
+ "model_name": self.model_name,
+ }
+ return LLMResult(
+ generations=[
+ [Generation(text=full_response["choices"][0]["message"]["content"])]
+ ],
+ llm_output=llm_output,
+ )
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ if self.streaming:
+ generation: Optional[GenerationChunk] = None
+ async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ return LLMResult(generations=[[generation]])
+
+ messages, params = self._get_chat_params(prompts, stop)
+ params = {**params, **kwargs}
+ full_response = await acompletion_with_retry(
+ self, messages=messages, run_manager=run_manager, **params
+ )
+ if not isinstance(full_response, dict):
+ full_response = full_response.dict()
+ llm_output = {
+ "token_usage": full_response["usage"],
+ "model_name": self.model_name,
+ }
+ return LLMResult(
+ generations=[
+ [Generation(text=full_response["choices"][0]["message"]["content"])]
+ ],
+ llm_output=llm_output,
+ )
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "openai-chat"
+
+ def get_token_ids(self, text: str) -> List[int]:
+ """Get the token IDs using the tiktoken package."""
+ # tiktoken NOT supported for Python < 3.8
+ if sys.version_info[1] < 8:
+ return super().get_token_ids(text)
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to calculate get_num_tokens. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ enc = tiktoken.encoding_for_model(self.model_name)
+ return enc.encode(
+ text,
+ allowed_special=self.allowed_special,
+ disallowed_special=self.disallowed_special,
+ )
diff --git a/libs/community/langchain_community/llms/openllm.py b/libs/community/langchain_community/llms/openllm.py
new file mode 100644
index 00000000000..afb5a18f9ba
--- /dev/null
+++ b/libs/community/langchain_community/llms/openllm.py
@@ -0,0 +1,331 @@
+from __future__ import annotations
+
+import copy
+import json
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ TypedDict,
+ Union,
+ overload,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import PrivateAttr
+
+if TYPE_CHECKING:
+ import openllm
+
+
+ServerType = Literal["http", "grpc"]
+
+
+class IdentifyingParams(TypedDict):
+ """Parameters for identifying a model as a typed dict."""
+
+ model_name: str
+ model_id: Optional[str]
+ server_url: Optional[str]
+ server_type: Optional[ServerType]
+ embedded: bool
+ llm_kwargs: Dict[str, Any]
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+ """OpenLLM, supporting both in-process model
+ instance and remote OpenLLM servers.
+
+ To use, you should have the openllm library installed:
+
+ .. code-block:: bash
+
+ pip install openllm
+
+ Learn more at: https://github.com/bentoml/openllm
+
+ Example running an LLM model locally managed by OpenLLM:
+ .. code-block:: python
+
+ from langchain_community.llms import OpenLLM
+ llm = OpenLLM(
+ model_name='flan-t5',
+ model_id='google/flan-t5-large',
+ )
+ llm("What is the difference between a duck and a goose?")
+
+ For all available supported models, you can run 'openllm models'.
+
+    If you have an OpenLLM server running, you can also use it remotely:
+ .. code-block:: python
+
+ from langchain_community.llms import OpenLLM
+ llm = OpenLLM(server_url='http://localhost:3000')
+ llm("What is the difference between a duck and a goose?")
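+
+    Extra keyword arguments are collected into ``llm_kwargs`` and forwarded to
+    the underlying ``openllm`` runner. A sketch (the generation parameter below
+    is an assumption; valid names depend on the chosen model):
+    .. code-block:: python
+
+        llm = OpenLLM(
+            model_name='flan-t5',
+            model_id='google/flan-t5-large',
+            temperature=0.94,
+        )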
+ """
+
+ model_name: Optional[str] = None
+ """Model name to use. See 'openllm models' for all available models."""
+ model_id: Optional[str] = None
+    """Model ID to use. If not provided, the default model for the given model
+    name is used. See 'openllm models' for all available model variants."""
+ server_url: Optional[str] = None
+    """Optional server URL of a running LLMServer started with 'openllm start'."""
+ server_type: ServerType = "http"
+ """Optional server type. Either 'http' or 'grpc'."""
+ embedded: bool = True
+    """Initialize this LLM instance in the current process by default. Should
+    only be set to False when used in conjunction with a BentoML Service."""
+ llm_kwargs: Dict[str, Any]
+ """Keyword arguments to be passed to openllm.LLM"""
+
+ _runner: Optional[openllm.LLMRunner] = PrivateAttr(default=None)
+ _client: Union[
+ openllm.client.HTTPClient, openllm.client.GrpcClient, None
+ ] = PrivateAttr(default=None)
+
+ class Config:
+ extra = "forbid"
+
+ @overload
+ def __init__(
+ self,
+ model_name: Optional[str] = ...,
+ *,
+ model_id: Optional[str] = ...,
+ embedded: Literal[True, False] = ...,
+ **llm_kwargs: Any,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ server_url: str = ...,
+ server_type: Literal["grpc", "http"] = ...,
+ **llm_kwargs: Any,
+ ) -> None:
+ ...
+
+ def __init__(
+ self,
+ model_name: Optional[str] = None,
+ *,
+ model_id: Optional[str] = None,
+ server_url: Optional[str] = None,
+ server_type: Literal["grpc", "http"] = "http",
+ embedded: bool = True,
+ **llm_kwargs: Any,
+ ):
+ try:
+ import openllm
+ except ImportError as e:
+ raise ImportError(
+ "Could not import openllm. Make sure to install it with "
+ "'pip install openllm.'"
+ ) from e
+
+ llm_kwargs = llm_kwargs or {}
+
+ if server_url is not None:
+ logger.debug("'server_url' is provided, returning a openllm.Client")
+ assert (
+ model_id is None and model_name is None
+ ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
+ client_cls = (
+ openllm.client.HTTPClient
+ if server_type == "http"
+ else openllm.client.GrpcClient
+ )
+ client = client_cls(server_url)
+
+ super().__init__(
+ **{
+ "server_url": server_url,
+ "server_type": server_type,
+ "llm_kwargs": llm_kwargs,
+ }
+ )
+ self._runner = None # type: ignore
+ self._client = client
+ else:
+ assert model_name is not None, "Must provide 'model_name' or 'server_url'"
+ # since the LLM are relatively huge, we don't actually want to convert the
+ # Runner with embedded when running the server. Instead, we will only set
+ # the init_local here so that LangChain users can still use the LLM
+ # in-process. Wrt to BentoML users, setting embedded=False is the expected
+ # behaviour to invoke the runners remotely.
+ # We need to also enable ensure_available to download and setup the model.
+ runner = openllm.Runner(
+ model_name=model_name,
+ model_id=model_id,
+ init_local=embedded,
+ ensure_available=True,
+ **llm_kwargs,
+ )
+ super().__init__(
+ **{
+ "model_name": model_name,
+ "model_id": model_id,
+ "embedded": embedded,
+ "llm_kwargs": llm_kwargs,
+ }
+ )
+ self._client = None # type: ignore
+ self._runner = runner
+
+ @property
+ def runner(self) -> openllm.LLMRunner:
+ """
+ Get the underlying openllm.LLMRunner instance for integration with BentoML.
+
+ Example:
+ .. code-block:: python
+
+ llm = OpenLLM(
+ model_name='flan-t5',
+ model_id='google/flan-t5-large',
+ embedded=False,
+ )
+ tools = load_tools(["serpapi", "llm-math"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
+ )
+ svc = bentoml.Service("langchain-openllm", runners=[llm.runner])
+
+ @svc.api(input=Text(), output=Text())
+ def chat(input_text: str):
+ return agent.run(input_text)
+ """
+ if self._runner is None:
+ raise ValueError("OpenLLM must be initialized locally with 'model_name'")
+ return self._runner
+
+ @property
+ def _identifying_params(self) -> IdentifyingParams:
+ """Get the identifying parameters."""
+ if self._client is not None:
+ self.llm_kwargs.update(self._client._config())
+ model_name = self._client._metadata()["model_name"]
+ model_id = self._client._metadata()["model_id"]
+ else:
+ if self._runner is None:
+ raise ValueError("Runner must be initialized.")
+ model_name = self.model_name
+ model_id = self.model_id
+ try:
+ self.llm_kwargs.update(
+ json.loads(self._runner.identifying_params["configuration"])
+ )
+ except (TypeError, json.JSONDecodeError):
+ pass
+ return IdentifyingParams(
+ server_url=self.server_url,
+ server_type=self.server_type,
+ embedded=self.embedded,
+ llm_kwargs=self.llm_kwargs,
+ model_name=model_name,
+ model_id=model_id,
+ )
+
+ @property
+ def _llm_type(self) -> str:
+ return "openllm_client" if self._client else "openllm"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ import openllm
+ except ImportError as e:
+ raise ImportError(
+ "Could not import openllm. Make sure to install it with "
+ "'pip install openllm'."
+ ) from e
+
+ copied = copy.deepcopy(self.llm_kwargs)
+ copied.update(kwargs)
+ config = openllm.AutoConfig.for_model(
+ self._identifying_params["model_name"], **copied
+ )
+ if self._client:
+ res = self._client.generate(
+ prompt, **config.model_dump(flatten=True)
+ ).responses[0]
+ else:
+ assert self._runner is not None
+ res = self._runner(prompt, **config.model_dump(flatten=True))
+ if isinstance(res, dict) and "text" in res:
+ return res["text"]
+ elif isinstance(res, str):
+ return res
+ else:
+ raise ValueError(
+ "Expected result to be a dict with key 'text' or a string. "
+ f"Received {res}"
+ )
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ import openllm
+ except ImportError as e:
+ raise ImportError(
+ "Could not import openllm. Make sure to install it with "
+ "'pip install openllm'."
+ ) from e
+
+ copied = copy.deepcopy(self.llm_kwargs)
+ copied.update(kwargs)
+ config = openllm.AutoConfig.for_model(
+ self._identifying_params["model_name"], **copied
+ )
+ if self._client:
+ async_client = openllm.client.AsyncHTTPClient(self.server_url)
+ res = (
+ await async_client.generate(prompt, **config.model_dump(flatten=True))
+ ).responses[0]
+ else:
+ assert self._runner is not None
+ (
+ prompt,
+ generate_kwargs,
+ postprocess_kwargs,
+ ) = self._runner.llm.sanitize_parameters(prompt, **kwargs)
+ generated_result = await self._runner.generate.async_run(
+ prompt, **generate_kwargs
+ )
+ res = self._runner.llm.postprocess_generate(
+ prompt, generated_result, **postprocess_kwargs
+ )
+
+ if isinstance(res, dict) and "text" in res:
+ return res["text"]
+ elif isinstance(res, str):
+ return res
+ else:
+ raise ValueError(
+ "Expected result to be a dict with key 'text' or a string. "
+ f"Received {res}"
+ )
diff --git a/libs/community/langchain_community/llms/openlm.py b/libs/community/langchain_community/llms/openlm.py
new file mode 100644
index 00000000000..47b303012bf
--- /dev/null
+++ b/libs/community/langchain_community/llms/openlm.py
@@ -0,0 +1,32 @@
+from typing import Any, Dict
+
+from langchain_core.pydantic_v1 import root_validator
+
+from langchain_community.llms.openai import BaseOpenAI
+
+
+class OpenLM(BaseOpenAI):
+ """OpenLM models."""
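+
+    # A minimal, hypothetical usage sketch (the model name is an assumption;
+    # OpenLM accepts the parameters of BaseOpenAI and routes the completion
+    # call through the ``openlm`` package):
+    #
+    #     from langchain_community.llms import OpenLM
+    #     llm = OpenLM(model_name="text-davinci-003")
+    #     llm("Say hello to OpenLM")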
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ return {**{"model": self.model_name}, **super()._invocation_params}
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ try:
+ import openlm
+
+ values["client"] = openlm.Completion
+ except ImportError:
+ raise ImportError(
+ "Could not import openlm python package. "
+ "Please install it with `pip install openlm`."
+ )
+ if values["streaming"]:
+ raise ValueError("Streaming not supported with openlm")
+ return values
diff --git a/libs/community/langchain_community/llms/pai_eas_endpoint.py b/libs/community/langchain_community/llms/pai_eas_endpoint.py
new file mode 100644
index 00000000000..b74b9ca2a3d
--- /dev/null
+++ b/libs/community/langchain_community/llms/pai_eas_endpoint.py
@@ -0,0 +1,240 @@
+import json
+import logging
+from typing import Any, Dict, Iterator, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class PaiEasEndpoint(LLM):
+    """LangChain LLM wrapper for the Alibaba Cloud PAI-EAS LLM service.
+
+    To use this endpoint, you must have an LLM chat service deployed on PAI
+    (Alibaba Cloud) EAS. Set the environment variables ``EAS_SERVICE_URL`` and
+    ``EAS_SERVICE_TOKEN`` to your service URL and token, or pass them as
+    parameters.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms.pai_eas_endpoint import PaiEasEndpoint
+            eas_endpoint = PaiEasEndpoint(
+ eas_service_url="your_service_url",
+ eas_service_token="your_service_token"
+ )
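+
+    Streaming can be enabled through the ``streaming`` field defined on this
+    class (a sketch; the service URL and token are placeholders):
+        .. code-block:: python
+
+            eas_endpoint = PaiEasEndpoint(
+                eas_service_url="your_service_url",
+                eas_service_token="your_service_token",
+                streaming=True,
+            )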
+ """
+
+ """PAI-EAS Service URL"""
+ eas_service_url: str
+
+ """PAI-EAS Service TOKEN"""
+ eas_service_token: str
+
+ """PAI-EAS Service Infer Params"""
+ max_new_tokens: Optional[int] = 512
+ temperature: Optional[float] = 0.95
+ top_p: Optional[float] = 0.1
+ top_k: Optional[int] = 0
+ stop_sequences: Optional[List[str]] = None
+
+ """Enable stream chat mode."""
+ streaming: bool = False
+
+ """Key/value arguments to pass to the model. Reserved for future use"""
+ model_kwargs: Optional[dict] = None
+
+ version: Optional[str] = "2.0"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["eas_service_url"] = get_from_dict_or_env(
+ values, "eas_service_url", "EAS_SERVICE_URL"
+ )
+ values["eas_service_token"] = get_from_dict_or_env(
+ values, "eas_service_token", "EAS_SERVICE_TOKEN"
+ )
+
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "pai_eas_endpoint"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the PAI-EAS service."""
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "temperature": self.temperature,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "stop_sequences": [],
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ "eas_service_url": self.eas_service_url,
+ "eas_service_token": self.eas_service_token,
+ **_model_kwargs,
+ }
+
+ def _invocation_params(
+ self, stop_sequences: Optional[List[str]], **kwargs: Any
+ ) -> dict:
+ params = self._default_params
+ if self.stop_sequences is not None and stop_sequences is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop_sequences is not None:
+ params["stop"] = self.stop_sequences
+ else:
+ params["stop"] = stop_sequences
+ if self.model_kwargs:
+ params.update(self.model_kwargs)
+ return {**params, **kwargs}
+
+ @staticmethod
+ def _process_response(
+ response: Any, stop: Optional[List[str]], version: Optional[str]
+ ) -> str:
+ if version == "1.0":
+ text = response
+ else:
+ text = response["response"]
+
+ if stop:
+ text = enforce_stop_tokens(text, stop)
+ return "".join(text)
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ params = self._invocation_params(stop, **kwargs)
+ prompt = prompt.strip()
+ response = None
+ try:
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(prompt, stop, run_manager, **params):
+ completion += chunk.text
+ return completion
+ else:
+ response = self._call_eas(prompt, params)
+ _stop = params.get("stop")
+ return self._process_response(response, _stop, self.version)
+ except Exception as error:
+ raise ValueError(f"Error raised by the service: {error}")
+
+ def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
+ """Generate text from the eas service."""
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"{self.eas_service_token}",
+ }
+ if self.version == "1.0":
+ body = {
+ "input_ids": f"{prompt}",
+ }
+ else:
+ body = {
+ "prompt": f"{prompt}",
+ }
+
+ # add params to body
+ for key, value in params.items():
+ body[key] = value
+
+ # make request
+ response = requests.post(self.eas_service_url, headers=headers, json=body)
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed with status code {response.status_code}"
+ f" and message {response.text}"
+ )
+
+ try:
+ return json.loads(response.text)
+ except Exception as e:
+ if isinstance(e, json.decoder.JSONDecodeError):
+ return response.text
+ raise e
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ invocation_params = self._invocation_params(stop, **kwargs)
+
+ headers = {
+ "User-Agent": "Test Client",
+ "Authorization": f"{self.eas_service_token}",
+ }
+
+ if self.version == "1.0":
+ pload = {"input_ids": prompt, **invocation_params}
+ response = requests.post(
+ self.eas_service_url, headers=headers, json=pload, stream=True
+ )
+
+ res = GenerationChunk(text=response.text)
+
+ if run_manager:
+ run_manager.on_llm_new_token(res.text)
+
+ # yield text, if any
+ yield res
+ else:
+ pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}
+
+ response = requests.post(
+ self.eas_service_url, headers=headers, json=pload, stream=True
+ )
+
+ for chunk in response.iter_lines(
+ chunk_size=8192, decode_unicode=False, delimiter=b"\0"
+ ):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["response"]
+ # identify stop sequence in generated text, if any
+ stop_seq_found: Optional[str] = None
+ for stop_seq in invocation_params["stop"]:
+ if stop_seq in output:
+ stop_seq_found = stop_seq
+
+ # identify text to yield
+ text: Optional[str] = None
+ if stop_seq_found:
+ text = output[: output.index(stop_seq_found)]
+ else:
+ text = output
+
+ # yield text, if any
+ if text:
+ res = GenerationChunk(text=text)
+ yield res
+ if run_manager:
+ run_manager.on_llm_new_token(res.text)
+
+ # break if stop sequence found
+ if stop_seq_found:
+ break
diff --git a/libs/community/langchain_community/llms/petals.py b/libs/community/langchain_community/llms/petals.py
new file mode 100644
index 00000000000..1508d9d3036
--- /dev/null
+++ b/libs/community/langchain_community/llms/petals.py
@@ -0,0 +1,153 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class Petals(LLM):
+ """Petals Bloom models.
+
+ To use, you should have the ``petals`` python package installed, and the
+ environment variable ``HUGGINGFACE_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+            from langchain_community.llms import Petals
+ petals = Petals()
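+            # a hedged sketch using fields defined on this class (the values
+            # shown are the documented defaults):
+            petals = Petals(
+                model_name="bigscience/bloom-petals",
+                max_new_tokens=256,
+                temperature=0.7,
+            )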
+
+ """
+
+ client: Any
+ """The client to use for the API calls."""
+
+ tokenizer: Any
+ """The tokenizer to use for the API calls."""
+
+ model_name: str = "bigscience/bloom-petals"
+ """The model to use."""
+
+ temperature: float = 0.7
+ """What sampling temperature to use"""
+
+ max_new_tokens: int = 256
+ """The maximum number of new tokens to generate in the completion."""
+
+ top_p: float = 0.9
+ """The cumulative probability for top-p sampling."""
+
+ top_k: Optional[int] = None
+ """The number of highest probability vocabulary tokens
+ to keep for top-k-filtering."""
+
+ do_sample: bool = True
+ """Whether or not to use sampling; use greedy decoding otherwise."""
+
+ max_length: Optional[int] = None
+ """The maximum length of the sequence to be generated."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call
+ not explicitly specified."""
+
+ huggingface_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic config."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+                    f"""WARNING! {field_name} is not a default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ huggingface_api_key = get_from_dict_or_env(
+ values, "huggingface_api_key", "HUGGINGFACE_API_KEY"
+ )
+ try:
+ from petals import AutoDistributedModelForCausalLM
+ from transformers import AutoTokenizer
+
+ model_name = values["model_name"]
+ values["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
+ values["client"] = AutoDistributedModelForCausalLM.from_pretrained(
+ model_name
+ )
+ values["huggingface_api_key"] = huggingface_api_key
+
+ except ImportError:
+ raise ImportError(
+                "Could not import transformers or petals python package. "
+ "Please install with `pip install -U transformers petals`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Petals API."""
+ normal_params = {
+ "temperature": self.temperature,
+ "max_new_tokens": self.max_new_tokens,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "do_sample": self.do_sample,
+ "max_length": self.max_length,
+ }
+ return {**normal_params, **self.model_kwargs}
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "petals"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the Petals API."""
+ params = self._default_params
+ params = {**params, **kwargs}
+ inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
+ outputs = self.client.generate(inputs, **params)
+ text = self.tokenizer.decode(outputs[0])
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/pipelineai.py b/libs/community/langchain_community/llms/pipelineai.py
new file mode 100644
index 00000000000..91182d99760
--- /dev/null
+++ b/libs/community/langchain_community/llms/pipelineai.py
@@ -0,0 +1,115 @@
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineAI(LLM, BaseModel):
+ """PipelineAI large language models.
+
+ To use, you should have the ``pipeline-ai`` python package installed,
+ and the environment variable ``PIPELINE_API_KEY`` set with your API key.
+
+ Any parameters that are valid to be passed to the call can be passed
+ in, even if not explicitly saved on this class.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import PipelineAI
+ pipeline = PipelineAI(pipeline_key="")
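+
+    Additional keyword arguments are collected into ``pipeline_kwargs`` and sent
+    with each run (a sketch; the pipeline key and parameter below are
+    placeholders that depend on the target pipeline):
+        .. code-block:: python
+
+            pipeline = PipelineAI(
+                pipeline_key="your-pipeline-id-or-tag",
+                pipeline_kwargs={"max_length": 100},
+            )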
+ """
+
+ pipeline_key: str = ""
+ """The id or tag of the target pipeline"""
+
+ pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any pipeline parameters valid for `create` call not
+ explicitly specified."""
+
+ pipeline_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic config."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("pipeline_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to pipeline_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["pipeline_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ pipeline_api_key = get_from_dict_or_env(
+ values, "pipeline_api_key", "PIPELINE_API_KEY"
+ )
+ values["pipeline_api_key"] = pipeline_api_key
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"pipeline_key": self.pipeline_key},
+ **{"pipeline_kwargs": self.pipeline_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "pipeline_ai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to Pipeline Cloud endpoint."""
+ try:
+ from pipeline import PipelineCloud
+ except ImportError:
+ raise ImportError(
+ "Could not import pipeline-ai python package. "
+ "Please install it with `pip install pipeline-ai`."
+ )
+ client = PipelineCloud(token=self.pipeline_api_key)
+ params = self.pipeline_kwargs or {}
+ params = {**params, **kwargs}
+
+ run = client.run_pipeline(self.pipeline_key, [prompt, params])
+ try:
+ text = run.result_preview[0][0]
+ except AttributeError:
+ raise AttributeError(
+                f"A pipeline run should have a `result_preview` attribute. "
+ f"Run was: {run}"
+ )
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the pipeline parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py
new file mode 100644
index 00000000000..2aaafd9128f
--- /dev/null
+++ b/libs/community/langchain_community/llms/predibase.py
@@ -0,0 +1,50 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field
+
+
+class Predibase(LLM):
+    """Use your Predibase models with LangChain.
+
+ To use, you should have the ``predibase`` python package installed,
+ and have your Predibase API key.
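+
+    Example (a sketch; the model name is a placeholder for any model available
+    in your Predibase account):
+        .. code-block:: python
+
+            from langchain_community.llms import Predibase
+            llm = Predibase(model="llama-2-13b", predibase_api_key="...")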
+ """
+
+ model: str
+ predibase_api_key: str
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ @property
+ def _llm_type(self) -> str:
+ return "predibase"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ from predibase import PredibaseClient
+
+ pc = PredibaseClient(token=self.predibase_api_key)
+ except ImportError as e:
+ raise ImportError(
+ "Could not import Predibase Python package. "
+ "Please install it with `pip install predibase`."
+ ) from e
+ except ValueError as e:
+ raise ValueError("Your API key is not correct. Please try again") from e
+ # load model and version
+ results = pc.prompt(prompt, model_name=self.model)
+ return results[0].response
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_kwargs": self.model_kwargs},
+ }
diff --git a/libs/community/langchain_community/llms/predictionguard.py b/libs/community/langchain_community/llms/predictionguard.py
new file mode 100644
index 00000000000..51291500ca7
--- /dev/null
+++ b/libs/community/langchain_community/llms/predictionguard.py
@@ -0,0 +1,130 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class PredictionGuard(LLM):
+ """Prediction Guard large language models.
+
+ To use, you should have the ``predictionguard`` python package installed, and the
+ environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
+ it as a named parameter to the constructor. To use Prediction Guard's API along
+ with OpenAI models, set the environment variable ``OPENAI_API_KEY`` with your
+ OpenAI API key as well.
+
+ Example:
+ .. code-block:: python
+
+ pgllm = PredictionGuard(model="MPT-7B-Instruct",
+ token="my-access-token",
+ output={
+ "type": "boolean"
+ })
+ """
+
+ client: Any #: :meta private:
+ model: Optional[str] = "MPT-7B-Instruct"
+ """Model name to use."""
+
+ output: Optional[Dict[str, Any]] = None
+ """The output type or structure for controlling the LLM output."""
+
+ max_tokens: int = 256
+ """Denotes the number of tokens to predict per generation."""
+
+ temperature: float = 0.75
+ """A non-negative float that tunes the degree of randomness in generation."""
+
+ token: Optional[str] = None
+ """Your Prediction Guard access token."""
+
+ stop: Optional[List[str]] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the access token and python package exists in environment."""
+ token = get_from_dict_or_env(values, "token", "PREDICTIONGUARD_TOKEN")
+ try:
+ import predictionguard as pg
+
+ values["client"] = pg.Client(token=token)
+ except ImportError:
+ raise ImportError(
+ "Could not import predictionguard python package. "
+ "Please install it with `pip install predictionguard`."
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling the Prediction Guard API."""
+ return {
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model": self.model}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "predictionguard"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Prediction Guard's model API.
+ Args:
+ prompt: The prompt to pass into the model.
+ Returns:
+ The string generated by the model.
+ Example:
+ .. code-block:: python
+ response = pgllm("Tell me a joke.")
+ """
+ import predictionguard as pg
+
+ params = self._default_params
+ if self.stop is not None and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop is not None:
+ params["stop_sequences"] = self.stop
+ else:
+ params["stop_sequences"] = stop
+
+ response = pg.Completion.create(
+ model=self.model,
+ prompt=prompt,
+ output=self.output,
+ temperature=params["temperature"],
+ max_tokens=params["max_tokens"],
+ **kwargs,
+ )
+ text = response["choices"][0]["text"]
+
+ # If stop tokens are provided, Prediction Guard's endpoint returns them.
+ # In order to make this consistent with other endpoints, we strip them.
+ if stop is not None or self.stop is not None:
+ text = enforce_stop_tokens(text, params["stop_sequences"])
+
+ return text
diff --git a/libs/community/langchain_community/llms/promptlayer_openai.py b/libs/community/langchain_community/llms/promptlayer_openai.py
new file mode 100644
index 00000000000..f6944d1725d
--- /dev/null
+++ b/libs/community/langchain_community/llms/promptlayer_openai.py
@@ -0,0 +1,232 @@
+import datetime
+from typing import Any, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.outputs import LLMResult
+
+from langchain_community.llms.openai import OpenAI, OpenAIChat
+
+
+class PromptLayerOpenAI(OpenAI):
+ """PromptLayer OpenAI large language models.
+
+ To use, you should have the ``openai`` and ``promptlayer`` python
+ package installed, and the environment variable ``OPENAI_API_KEY``
+ and ``PROMPTLAYER_API_KEY`` set with your openAI API key and
+ promptlayer key respectively.
+
+ All parameters that can be passed to the OpenAI LLM can also
+    be passed here. The PromptLayerOpenAI LLM adds two optional
+    parameters:
+ ``pl_tags``: List of strings to tag the request with.
+ ``return_pl_id``: If True, the PromptLayer request ID will be
+ returned in the ``generation_info`` field of the
+ ``Generation`` object.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import PromptLayerOpenAI
+ openai = PromptLayerOpenAI(model_name="text-davinci-003")
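+
+            # a sketch using the two extra parameters documented above
+            # (the tag values are placeholders):
+            openai = PromptLayerOpenAI(
+                model_name="text-davinci-003",
+                pl_tags=["langchain", "experiments"],
+                return_pl_id=True,
+            )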
+ """
+
+ pl_tags: Optional[List[str]]
+ return_pl_id: Optional[bool] = False
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call OpenAI generate and then call PromptLayer API to log the request."""
+ from promptlayer.utils import get_api_key, promptlayer_api_request
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = super()._generate(prompts, stop, run_manager)
+ request_end_time = datetime.datetime.now().timestamp()
+ for i in range(len(prompts)):
+ prompt = prompts[i]
+ generation = generated_responses.generations[i][0]
+ resp = {
+ "text": generation.text,
+ "llm_output": generated_responses.llm_output,
+ }
+ params = {**self._identifying_params, **kwargs}
+ pl_request_id = promptlayer_api_request(
+ "langchain.PromptLayerOpenAI",
+ "langchain",
+ [prompt],
+ params,
+ self.pl_tags,
+ resp,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ from promptlayer.utils import get_api_key, promptlayer_api_request_async
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = await super()._agenerate(prompts, stop, run_manager)
+ request_end_time = datetime.datetime.now().timestamp()
+ for i in range(len(prompts)):
+ prompt = prompts[i]
+ generation = generated_responses.generations[i][0]
+ resp = {
+ "text": generation.text,
+ "llm_output": generated_responses.llm_output,
+ }
+ params = {**self._identifying_params, **kwargs}
+ pl_request_id = await promptlayer_api_request_async(
+ "langchain.PromptLayerOpenAI.async",
+ "langchain",
+ [prompt],
+ params,
+ self.pl_tags,
+ resp,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
+
+
+class PromptLayerOpenAIChat(OpenAIChat):
+    """PromptLayer OpenAI Chat large language models.
+
+ To use, you should have the ``openai`` and ``promptlayer`` python
+ package installed, and the environment variable ``OPENAI_API_KEY``
+ and ``PROMPTLAYER_API_KEY`` set with your openAI API key and
+ promptlayer key respectively.
+
+ All parameters that can be passed to the OpenAIChat LLM can also
+    be passed here. The PromptLayerOpenAIChat adds two optional
+    parameters:
+ ``pl_tags``: List of strings to tag the request with.
+ ``return_pl_id``: If True, the PromptLayer request ID will be
+ returned in the ``generation_info`` field of the
+ ``Generation`` object.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import PromptLayerOpenAIChat
+ openaichat = PromptLayerOpenAIChat(model_name="gpt-3.5-turbo")
+ """
+
+ pl_tags: Optional[List[str]]
+ return_pl_id: Optional[bool] = False
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call OpenAI generate and then call PromptLayer API to log the request."""
+ from promptlayer.utils import get_api_key, promptlayer_api_request
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = super()._generate(prompts, stop, run_manager)
+ request_end_time = datetime.datetime.now().timestamp()
+ for i in range(len(prompts)):
+ prompt = prompts[i]
+ generation = generated_responses.generations[i][0]
+ resp = {
+ "text": generation.text,
+ "llm_output": generated_responses.llm_output,
+ }
+ params = {**self._identifying_params, **kwargs}
+ pl_request_id = promptlayer_api_request(
+ "langchain.PromptLayerOpenAIChat",
+ "langchain",
+ [prompt],
+ params,
+ self.pl_tags,
+ resp,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ from promptlayer.utils import get_api_key, promptlayer_api_request_async
+
+ request_start_time = datetime.datetime.now().timestamp()
+ generated_responses = await super()._agenerate(prompts, stop, run_manager)
+ request_end_time = datetime.datetime.now().timestamp()
+ for i in range(len(prompts)):
+ prompt = prompts[i]
+ generation = generated_responses.generations[i][0]
+ resp = {
+ "text": generation.text,
+ "llm_output": generated_responses.llm_output,
+ }
+ params = {**self._identifying_params, **kwargs}
+ pl_request_id = await promptlayer_api_request_async(
+ "langchain.PromptLayerOpenAIChat.async",
+ "langchain",
+ [prompt],
+ params,
+ self.pl_tags,
+ resp,
+ request_start_time,
+ request_end_time,
+ get_api_key(),
+ return_pl_id=self.return_pl_id,
+ )
+ if self.return_pl_id:
+ if generation.generation_info is None or not isinstance(
+ generation.generation_info, dict
+ ):
+ generation.generation_info = {}
+ generation.generation_info["pl_request_id"] = pl_request_id
+ return generated_responses
diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py
new file mode 100644
index 00000000000..b086aed3489
--- /dev/null
+++ b/libs/community/langchain_community/llms/replicate.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+ from replicate.prediction import Prediction
+
+logger = logging.getLogger(__name__)
+
+
+class Replicate(LLM):
+ """Replicate models.
+
+ To use, you should have the ``replicate`` python package installed,
+ and the environment variable ``REPLICATE_API_TOKEN`` set with your API token.
+ You can find your token here: https://replicate.com/account
+
+ The model param is required, but any other model parameters can also
+ be passed in with the format model_kwargs={model_param: value, ...}
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Replicate
+
+ replicate = Replicate(
+ model=(
+ "stability-ai/stable-diffusion: "
+ "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
+ ),
+ model_kwargs={"image_dimensions": "512x512"}
+ )
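+
+    Token streaming and stop sequences are controlled by the ``streaming`` and
+    ``stop`` fields on this class (a sketch; the model identifier below is a
+    placeholder of the same ``owner/name:version`` form as above):
+        .. code-block:: python
+
+            replicate = Replicate(
+                model="owner/model-name:version",  # placeholder identifier
+                streaming=True,
+                stop=["\n\n"],
+            )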
+ """
+
+ model: str
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
+ replicate_api_token: Optional[str] = None
+ prompt_key: Optional[str] = None
+ version_obj: Any = Field(default=None, exclude=True)
+ """Optionally pass in the model version object during initialization to avoid
+ having to make an extra API call to retrieve it during streaming. NOTE: not
+ serializable, is excluded from serialization.
+ """
+
+ streaming: bool = False
+ """Whether to stream the results."""
+
+ stop: List[str] = Field(default_factory=list)
+ """Stop sequences to early-terminate generation."""
+
+ class Config:
+ """Configuration for this pydantic config."""
+
+ allow_population_by_field_name = True
+ extra = Extra.forbid
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"replicate_api_token": "REPLICATE_API_TOKEN"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "replicate"]
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ input = values.pop("input", {})
+ if input:
+ logger.warning(
+ "Init param `input` is deprecated, please use `model_kwargs` instead."
+ )
+ extra = {**values.pop("model_kwargs", {}), **input}
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ replicate_api_token = get_from_dict_or_env(
+ values, "replicate_api_token", "REPLICATE_API_TOKEN"
+ )
+ values["replicate_api_token"] = replicate_api_token
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model": self.model,
+ "model_kwargs": self.model_kwargs,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of model."""
+ return "replicate"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to replicate endpoint."""
+ if self.streaming:
+ completion: Optional[str] = None
+ for chunk in self._stream(
+ prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ if completion is None:
+ completion = chunk.text
+ else:
+ completion += chunk.text
+ else:
+ prediction = self._create_prediction(prompt, **kwargs)
+ prediction.wait()
+ if prediction.status == "failed":
+ raise RuntimeError(prediction.error)
+ if isinstance(prediction.output, str):
+ completion = prediction.output
+ else:
+ completion = "".join(prediction.output)
+ assert completion is not None
+ stop_conditions = stop or self.stop
+ for s in stop_conditions:
+ if s in completion:
+ completion = completion[: completion.find(s)]
+ return completion
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ prediction = self._create_prediction(prompt, **kwargs)
+ stop_conditions = stop or self.stop
+ stop_condition_reached = False
+ current_completion: str = ""
+ for output in prediction.output_iterator():
+ current_completion += output
+ # test for stop conditions, if specified
+ for s in stop_conditions:
+ if s in current_completion:
+ prediction.cancel()
+ stop_condition_reached = True
+ # Potentially some tokens that should still be yielded before ending
+ # stream.
+ stop_index = max(output.find(s), 0)
+ output = output[:stop_index]
+ if not output:
+ break
+ if output:
+ yield GenerationChunk(text=output)
+ if run_manager:
+ run_manager.on_llm_new_token(
+ output,
+ verbose=self.verbose,
+ )
+ if stop_condition_reached:
+ break
+
+ def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
+ try:
+ import replicate as replicate_python
+ except ImportError:
+ raise ImportError(
+ "Could not import replicate python package. "
+ "Please install it with `pip install replicate`."
+ )
+
+ # get the model and version
+ if self.version_obj is None:
+ model_str, version_str = self.model.split(":")
+ model = replicate_python.models.get(model_str)
+ self.version_obj = model.versions.get(version_str)
+
+ if self.prompt_key is None:
+ # sort through the openapi schema to get the name of the first input
+ input_properties = sorted(
+ self.version_obj.openapi_schema["components"]["schemas"]["Input"][
+ "properties"
+ ].items(),
+ key=lambda item: item[1].get("x-order", 0),
+ )
+
+ self.prompt_key = input_properties[0][0]
+
+ input_: Dict = {
+ self.prompt_key: prompt,
+ **self.model_kwargs,
+ **kwargs,
+ }
+ return replicate_python.predictions.create(
+ version=self.version_obj, input=input_
+ )
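A quick usage sketch for the class above. The model identifier and model_kwargs are illustrative, and ``REPLICATE_API_TOKEN`` is assumed to be set; pre-fetching the version object is optional and only avoids the extra lookup described in the ``version_obj`` field docstring.

    import replicate  # same package the class imports lazily

    from langchain_community.llms import Replicate

    # Hypothetical owner/name:version identifier; substitute a real Replicate model.
    model_id = "replicate-owner/some-model:0123456789abcdef"
    model_name, version_id = model_id.split(":")
    version_obj = replicate.models.get(model_name).versions.get(version_id)

    llm = Replicate(
        model=model_id,
        version_obj=version_obj,  # skips the extra API call during streaming
        model_kwargs={"temperature": 0.75},  # illustrative model parameters
        streaming=True,
    )
    for chunk in llm.stream("Tell me a short story about a robot."):
        print(chunk, end="", flush=True)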
diff --git a/libs/community/langchain_community/llms/rwkv.py b/libs/community/langchain_community/llms/rwkv.py
new file mode 100644
index 00000000000..470c2005372
--- /dev/null
+++ b/libs/community/langchain_community/llms/rwkv.py
@@ -0,0 +1,234 @@
+"""RWKV models.
+
+Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
+ https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
+"""
+from typing import Any, Dict, List, Mapping, Optional, Set
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class RWKV(LLM, BaseModel):
+ """RWKV language models.
+
+ To use, you should have the ``rwkv`` python package installed, the
+ pre-trained model file, and the model's config information.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import RWKV
+ model = RWKV(model="./models/rwkv-3b-fp16.bin", strategy="cpu fp32")
+
+ # Simplest invocation
+ response = model("Once upon a time, ")
+ """
+
+ model: str
+ """Path to the pre-trained RWKV model file."""
+
+ tokens_path: str
+ """Path to the RWKV tokens file."""
+
+ strategy: str = "cpu fp32"
+ """Token context window."""
+
+ rwkv_verbose: bool = True
+ """Print debug information."""
+
+ temperature: float = 1.0
+ """The temperature to use for sampling."""
+
+ top_p: float = 0.5
+ """The top-p value to use for sampling."""
+
+ penalty_alpha_frequency: float = 0.4
+ """Positive values penalize new tokens based on their existing frequency
+ in the text so far, decreasing the model's likelihood to repeat the same
+ line verbatim."""
+
+ penalty_alpha_presence: float = 0.4
+ """Positive values penalize new tokens based on whether they appear
+ in the text so far, increasing the model's likelihood to talk about
+ new topics."""
+
+ CHUNK_LEN: int = 256
+ """Batch size for prompt processing."""
+
+ max_tokens_per_generation: int = 256
+ """Maximum number of tokens to generate."""
+
+ client: Any = None #: :meta private:
+
+ tokenizer: Any = None #: :meta private:
+
+ pipeline: Any = None #: :meta private:
+
+ model_tokens: Any = None #: :meta private:
+
+ model_state: Any = None #: :meta private:
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "verbose": self.verbose,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "penalty_alpha_frequency": self.penalty_alpha_frequency,
+ "penalty_alpha_presence": self.penalty_alpha_presence,
+ "CHUNK_LEN": self.CHUNK_LEN,
+ "max_tokens_per_generation": self.max_tokens_per_generation,
+ }
+
+ @staticmethod
+ def _rwkv_param_names() -> Set[str]:
+ """Get the identifying parameters."""
+ return {
+ "verbose",
+ }
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in the environment."""
+ try:
+ import tokenizers
+ except ImportError:
+ raise ImportError(
+ "Could not import tokenizers python package. "
+ "Please install it with `pip install tokenizers`."
+ )
+ try:
+ from rwkv.model import RWKV as RWKVMODEL
+ from rwkv.utils import PIPELINE
+
+ values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])
+
+ rwkv_keys = cls._rwkv_param_names()
+ model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
+ model_kwargs["verbose"] = values["rwkv_verbose"]
+ values["client"] = RWKVMODEL(
+ values["model"], strategy=values["strategy"], **model_kwargs
+ )
+ values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])
+
+ except ImportError:
+ raise ImportError(
+ "Could not import rwkv python package. "
+ "Please install it with `pip install rwkv`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model": self.model,
+ **self._default_params,
+ **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return the type of llm."""
+ return "rwkv"
+
+ def run_rnn(self, _tokens: List[int], newline_adj: int = 0) -> Any:
+ AVOID_REPEAT_TOKENS = []
+ AVOID_REPEAT = "οΌοΌοΌοΌ"
+ for i in AVOID_REPEAT:
+ dd = self.pipeline.encode(i)
+ assert len(dd) == 1
+ AVOID_REPEAT_TOKENS += dd
+
+ tokens = [int(x) for x in _tokens]
+ self.model_tokens += tokens
+
+ out: Any = None
+
+ while len(tokens) > 0:
+ out, self.model_state = self.client.forward(
+ tokens[: self.CHUNK_LEN], self.model_state
+ )
+ tokens = tokens[self.CHUNK_LEN :]
+ END_OF_LINE = 187
+ out[END_OF_LINE] += newline_adj # adjust \n probability
+
+ if self.model_tokens[-1] in AVOID_REPEAT_TOKENS:
+ out[self.model_tokens[-1]] = -999999999
+ return out
+
+ def rwkv_generate(self, prompt: str) -> str:
+ self.model_state = None
+ self.model_tokens = []
+ logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
+ begin = len(self.model_tokens)
+ out_last = begin
+
+ occurrence: Dict = {}
+
+ decoded = ""
+ for i in range(self.max_tokens_per_generation):
+ for n in occurrence:
+ logits[n] -= (
+ self.penalty_alpha_presence
+ + occurrence[n] * self.penalty_alpha_frequency
+ )
+ token = self.pipeline.sample_logits(
+ logits, temperature=self.temperature, top_p=self.top_p
+ )
+
+ END_OF_TEXT = 0
+ if token == END_OF_TEXT:
+ break
+ if token not in occurrence:
+ occurrence[token] = 1
+ else:
+ occurrence[token] += 1
+
+ logits = self.run_rnn([token])
+ xxx = self.tokenizer.decode(self.model_tokens[out_last:])
+ if "\ufffd" not in xxx: # avoid utf-8 display issues
+ decoded += xxx
+ out_last = begin + i + 1
+ if i >= self.max_tokens_per_generation - 100:
+ break
+
+ return decoded
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ r"""RWKV generation
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "Once upon a time, "
+ response = model(prompt, n_predict=55)
+ """
+ text = self.rwkv_generate(prompt)
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
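A minimal usage sketch for the class above; the weight and tokenizer paths are placeholders for files downloaded separately, and the ``rwkv`` and ``tokenizers`` packages must be installed.

    from langchain_community.llms import RWKV

    model = RWKV(
        model="./models/rwkv-model.bin",            # placeholder path to RWKV weights
        tokens_path="./models/20B_tokenizer.json",  # placeholder path to tokenizer file
        strategy="cpu fp32",
        temperature=1.0,
        top_p=0.5,
        max_tokens_per_generation=128,
    )
    # Stop sequences are applied via enforce_stop_tokens after generation.
    print(model("Once upon a time, ", stop=["\n\n"]))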
diff --git a/libs/community/langchain_community/llms/sagemaker_endpoint.py b/libs/community/langchain_community/llms/sagemaker_endpoint.py
new file mode 100644
index 00000000000..7ca76fc411b
--- /dev/null
+++ b/libs/community/langchain_community/llms/sagemaker_endpoint.py
@@ -0,0 +1,371 @@
+"""Sagemaker InvokeEndpoint API."""
+import io
+import json
+from abc import abstractmethod
+from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
+
+
+class LineIterator:
+ """
+ A helper class for parsing the byte stream input.
+
+ The output of the model will be in the following format:
+
+ b'{"outputs": [" a"]}\n'
+ b'{"outputs": [" challenging"]}\n'
+ b'{"outputs": [" problem"]}\n'
+ ...
+
+ While usually each PayloadPart event from the event stream will
+ contain a byte array with a full json, this is not guaranteed
+ and some of the json objects may be split across PayloadPart events.
+
+ For example:
+
+ {'PayloadPart': {'Bytes': b'{"outputs": '}}
+ {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+
+
+ This class accounts for this by buffering the incoming bytes and only
+ returning complete lines (ending with a '\n' character) once they are
+ available in the buffer. It maintains the last read position to ensure
+ that previous bytes are not exposed again.
+
+ For more details see:
+ https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
+ """
+
+ def __init__(self, stream: Any) -> None:
+ self.byte_iterator = iter(stream)
+ self.buffer = io.BytesIO()
+ self.read_pos = 0
+
+ def __iter__(self) -> "LineIterator":
+ return self
+
+ def __next__(self) -> Any:
+ while True:
+ self.buffer.seek(self.read_pos)
+ line = self.buffer.readline()
+ if line and line[-1] == ord("\n"):
+ self.read_pos += len(line)
+ return line[:-1]
+ try:
+ chunk = next(self.byte_iterator)
+ except StopIteration:
+ if self.read_pos < self.buffer.getbuffer().nbytes:
+ continue
+ raise
+ if "PayloadPart" not in chunk:
+ # Unknown Event Type
+ continue
+ self.buffer.seek(0, io.SEEK_END)
+ self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
+class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
+ """A handler class to transform input from LLM to a
+ format that SageMaker endpoint expects.
+
+ Similarly, the class handles transforming output from the
+ SageMaker endpoint to a format that LLM class expects.
+ """
+
+ """
+ Example:
+ .. code-block:: python
+
+ class ContentHandler(ContentHandlerBase):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+ input_str = json.dumps({prompt: prompt, **model_kwargs})
+ return input_str.encode('utf-8')
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+ return response_json[0]["generated_text"]
+ """
+
+ content_type: Optional[str] = "text/plain"
+ """The MIME type of the input data passed to endpoint"""
+
+ accepts: Optional[str] = "text/plain"
+ """The MIME type of the response data returned from endpoint"""
+
+ @abstractmethod
+ def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
+ """Transforms the input to a format that model can accept
+ as the request Body. Should return bytes or seekable file
+ like object in the format specified in the content_type
+ request header.
+ """
+
+ @abstractmethod
+ def transform_output(self, output: bytes) -> OUTPUT_TYPE:
+ """Transforms the output from the model to string that
+ the LLM class expects.
+ """
+
+
+class LLMContentHandler(ContentHandlerBase[str, str]):
+ """Content handler for LLM class."""
+
+
+class SagemakerEndpoint(LLM):
+ """Sagemaker Inference Endpoint models.
+
+ To use, you must supply the endpoint name from your deployed
+ Sagemaker model & the region where it is deployed.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+ access the Sagemaker endpoint.
+ See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
+ """
+
+ """
+ Args:
+
+ region_name: The aws region e.g., `us-west-2`.
+ Falls back to the AWS_DEFAULT_REGION env variable
+ or region specified in ~/.aws/config.
+
+ credentials_profile_name: The name of the profile in the ~/.aws/credentials
+ or ~/.aws/config files, which has either access keys or role information
+ specified. If not specified, the default credential profile or, if on an
+ EC2 instance, credentials from IMDS will be used.
+
+ client: boto3 client for Sagemaker Endpoint
+
+ content_handler: Implementation for model specific LLMContentHandler
+
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import SagemakerEndpoint
+ endpoint_name = (
+ "my-endpoint-name"
+ )
+ region_name = (
+ "us-west-2"
+ )
+ credentials_profile_name = (
+ "default"
+ )
+ se = SagemakerEndpoint(
+ endpoint_name=endpoint_name,
+ region_name=region_name,
+ credentials_profile_name=credentials_profile_name
+ )
+
+ #Use with boto3 client
+ client = boto3.client(
+ "sagemaker-runtime",
+ region_name=region_name
+ )
+
+ se = SagemakerEndpoint(
+ endpoint_name=endpoint_name,
+ client=client
+ )
+
+ """
+ client: Any = None
+ """Boto3 client for sagemaker runtime"""
+
+ endpoint_name: str = ""
+ """The name of the endpoint from the deployed Sagemaker model.
+ Must be unique within an AWS Region."""
+
+ region_name: str = ""
+ """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
+
+ credentials_profile_name: Optional[str] = None
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+ has either access keys or role information specified.
+ If not specified, the default credential profile or, if on an EC2 instance,
+ credentials from IMDS will be used.
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ """
+
+ content_handler: LLMContentHandler
+ """The content handler class that provides an input and
+ output transform functions to handle formats between LLM
+ and the endpoint.
+ """
+
+ streaming: bool = False
+ """Whether to stream the results."""
+
+ """
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
+
+ class ContentHandler(LLMContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+ input_str = json.dumps({prompt: prompt, **model_kwargs})
+ return input_str.encode('utf-8')
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+ return response_json[0]["generated_text"]
+ """
+
+ model_kwargs: Optional[Dict] = None
+ """Keyword arguments to pass to the model."""
+
+ endpoint_kwargs: Optional[Dict] = None
+ """Optional attributes passed to the invoke_endpoint
+ function. See `boto3`_. docs for more info.
+ .. _boto3:
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Dont do anything if client provided externally"""
+ if values.get("client") is not None:
+ return values
+
+ """Validate that AWS credentials to and python package exists in environment."""
+ try:
+ import boto3
+
+ try:
+ if values["credentials_profile_name"] is not None:
+ session = boto3.Session(
+ profile_name=values["credentials_profile_name"]
+ )
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ values["client"] = session.client(
+ "sagemaker-runtime", region_name=values["region_name"]
+ )
+
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ except ImportError:
+ raise ImportError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ **{"endpoint_name": self.endpoint_name},
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "sagemaker_endpoint"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Sagemaker inference endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = se("Tell me a joke.")
+ """
+ _model_kwargs = self.model_kwargs or {}
+ _model_kwargs = {**_model_kwargs, **kwargs}
+ _endpoint_kwargs = self.endpoint_kwargs or {}
+
+ body = self.content_handler.transform_input(prompt, _model_kwargs)
+ content_type = self.content_handler.content_type
+ accepts = self.content_handler.accepts
+
+ if self.streaming and run_manager:
+ try:
+ resp = self.client.invoke_endpoint_with_response_stream(
+ EndpointName=self.endpoint_name,
+ Body=body,
+ ContentType=self.content_handler.content_type,
+ **_endpoint_kwargs,
+ )
+ iterator = LineIterator(resp["Body"])
+ current_completion: str = ""
+ for line in iterator:
+ resp = json.loads(line)
+ resp_output = resp.get("outputs")[0]
+ if stop is not None:
+ # Uses same approach as below
+ resp_output = enforce_stop_tokens(resp_output, stop)
+ current_completion += resp_output
+ run_manager.on_llm_new_token(resp_output)
+ return current_completion
+ except Exception as e:
+ raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+ else:
+ try:
+ response = self.client.invoke_endpoint(
+ EndpointName=self.endpoint_name,
+ Body=body,
+ ContentType=content_type,
+ Accept=accepts,
+ **_endpoint_kwargs,
+ )
+ except Exception as e:
+ raise ValueError(f"Error raised by inference endpoint: {e}")
+
+ text = self.content_handler.transform_output(response["Body"])
+ if stop is not None:
+ # This is a bit hacky, but I can't figure out a better way to enforce
+ # stop tokens when making calls to the sagemaker endpoint.
+ text = enforce_stop_tokens(text, stop)
+
+ return text
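A sketch of a complete content handler plus endpoint wrapper for the class above, assuming an already-deployed endpoint; the endpoint name, region, and JSON payload shape are illustrative and depend on the model container behind the endpoint.

    import json
    from typing import Dict

    from langchain_community.llms import SagemakerEndpoint
    from langchain_community.llms.sagemaker_endpoint import LLMContentHandler


    class ExampleContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
            # The request body shape depends on the deployed container.
            return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            # `output` is the response body returned by invoke_endpoint.
            response_json = json.loads(output.read().decode("utf-8"))
            return response_json[0]["generated_text"]


    llm = SagemakerEndpoint(
        endpoint_name="my-llm-endpoint",   # placeholder endpoint name
        region_name="us-west-2",
        content_handler=ExampleContentHandler(),
        model_kwargs={"temperature": 0.7},  # illustrative model parameters
    )
    print(llm("Tell me a joke."))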
diff --git a/libs/community/langchain_community/llms/self_hosted.py b/libs/community/langchain_community/llms/self_hosted.py
new file mode 100644
index 00000000000..043ffca136c
--- /dev/null
+++ b/libs/community/langchain_community/llms/self_hosted.py
@@ -0,0 +1,220 @@
+import importlib.util
+import logging
+import pickle
+from typing import Any, Callable, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+def _generate_text(
+ pipeline: Any,
+ prompt: str,
+ *args: Any,
+ stop: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> str:
+ """Inference function to send to the remote hardware.
+
+ Accepts a pipeline callable (or, more likely,
+ a key pointing to the model on the cluster's object store)
+ and returns text predictions for each document
+ in the batch.
+ """
+ text = pipeline(prompt, *args, **kwargs)
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+
+def _send_pipeline_to_device(pipeline: Any, device: int) -> Any:
+ """Send a pipeline to a device on the cluster."""
+ if isinstance(pipeline, str):
+ with open(pipeline, "rb") as f:
+ pipeline = pickle.load(f)
+
+ if importlib.util.find_spec("torch") is not None:
+ import torch
+
+ cuda_device_count = torch.cuda.device_count()
+ if device < -1 or (device >= cuda_device_count):
+ raise ValueError(
+ f"Got device=={device}, "
+ f"device is required to be within [-1, {cuda_device_count})"
+ )
+ if device < 0 and cuda_device_count > 0:
+ logger.warning(
+ "Device has %d GPUs available. "
+ "Provide device={deviceId} to `from_model_id` to use available"
+ "GPUs for execution. deviceId is -1 for CPU and "
+ "can be a positive integer associated with CUDA device id.",
+ cuda_device_count,
+ )
+
+ pipeline.device = torch.device(device)
+ pipeline.model = pipeline.model.to(pipeline.device)
+ return pipeline
+
+
+class SelfHostedPipeline(LLM):
+ """Model inference on self-hosted remote hardware.
+
+ Supported hardware includes auto-launched instances on AWS, GCP, Azure,
+ and Lambda, as well as servers specified
+ by IP address and SSH credentials (such as on-prem, or another
+ cloud like Paperspace, Coreweave, etc.).
+
+ To use, you should have the ``runhouse`` python package installed.
+
+ Example for custom pipeline and inference functions:
+ .. code-block:: python
+
+ from langchain_community.llms import SelfHostedPipeline
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import runhouse as rh
+
+ def load_pipeline():
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
+ return pipeline(
+ "text-generation", model=model, tokenizer=tokenizer,
+ max_new_tokens=10
+ )
+ def inference_fn(pipeline, prompt, stop = None):
+ return pipeline(prompt)[0]["generated_text"]
+
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ llm = SelfHostedPipeline(
+ model_load_fn=load_pipeline,
+ hardware=gpu,
+ model_reqs=model_reqs, inference_fn=inference_fn
+ )
+ Example for <2GB model (can be serialized and sent directly to the server):
+ .. code-block:: python
+
+ from langchain_community.llms import SelfHostedPipeline
+ import runhouse as rh
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ my_model = ...
+ llm = SelfHostedPipeline.from_pipeline(
+ pipeline=my_model,
+ hardware=gpu,
+ model_reqs=["./", "torch", "transformers"],
+ )
+ Example passing model path for larger models:
+ .. code-block:: python
+
+ from langchain_community.llms import SelfHostedPipeline
+ import runhouse as rh
+ import pickle
+ from transformers import pipeline
+
+ generator = pipeline(model="gpt2")
+ rh.blob(pickle.dumps(generator), path="models/pipeline.pkl"
+ ).save().to(gpu, path="models")
+ llm = SelfHostedPipeline.from_pipeline(
+ pipeline="models/pipeline.pkl",
+ hardware=gpu,
+ model_reqs=["./", "torch", "transformers"],
+ )
+ """
+
+ pipeline_ref: Any #: :meta private:
+ client: Any #: :meta private:
+ inference_fn: Callable = _generate_text #: :meta private:
+ """Inference function to send to the remote hardware."""
+ hardware: Any
+ """Remote hardware to send the inference function to."""
+ model_load_fn: Callable
+ """Function to load the model remotely on the server."""
+ load_fn_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model load function."""
+ model_reqs: List[str] = ["./", "torch"]
+ """Requirements to install on hardware to inference the model."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def __init__(self, **kwargs: Any):
+ """Init the pipeline with an auxiliary function.
+
+ The load function must be in global scope to be imported
+ and run on the server, i.e. in a module and not a REPL or closure.
+ Then, initialize the remote inference function.
+ """
+ super().__init__(**kwargs)
+ try:
+ import runhouse as rh
+
+ except ImportError:
+ raise ImportError(
+ "Could not import runhouse python package. "
+ "Please install it with `pip install runhouse`."
+ )
+
+ remote_load_fn = rh.function(fn=self.model_load_fn).to(
+ self.hardware, reqs=self.model_reqs
+ )
+ _load_fn_kwargs = self.load_fn_kwargs or {}
+ self.pipeline_ref = remote_load_fn.remote(**_load_fn_kwargs)
+
+ self.client = rh.function(fn=self.inference_fn).to(
+ self.hardware, reqs=self.model_reqs
+ )
+
+ @classmethod
+ def from_pipeline(
+ cls,
+ pipeline: Any,
+ hardware: Any,
+ model_reqs: Optional[List[str]] = None,
+ device: int = 0,
+ **kwargs: Any,
+ ) -> LLM:
+ """Init the SelfHostedPipeline from a pipeline object or string."""
+ if not isinstance(pipeline, str):
+ logger.warning(
+ "Serializing pipeline to send to remote hardware. "
+ "Note, it can be quite slow"
+ "to serialize and send large models with each execution. "
+ "Consider sending the pipeline"
+ "to the cluster and passing the path to the pipeline instead."
+ )
+
+ load_fn_kwargs = {"pipeline": pipeline, "device": device}
+ return cls(
+ load_fn_kwargs=load_fn_kwargs,
+ model_load_fn=_send_pipeline_to_device,
+ hardware=hardware,
+ model_reqs=["transformers", "torch"] + (model_reqs or []),
+ **kwargs,
+ )
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"hardware": self.hardware},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "self_hosted_llm"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ return self.client(
+ pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+ )
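A sketch of the `from_pipeline` path for a small model, assuming ``runhouse``, ``transformers`` and ``torch`` are installed and the named cluster is reachable; the cluster name and instance type are illustrative and mirror the docstring examples.

    import runhouse as rh
    from transformers import pipeline

    from langchain_community.llms import SelfHostedPipeline

    gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")  # illustrative cluster

    # Small (<2GB) pipelines can be pickled and shipped directly; from_pipeline
    # wires _send_pipeline_to_device up as the remote model load function.
    generator = pipeline("text-generation", model="gpt2", max_new_tokens=32)

    llm = SelfHostedPipeline.from_pipeline(
        pipeline=generator,
        hardware=gpu,
        model_reqs=["./", "torch", "transformers"],
    )
    print(llm("Once upon a time, "))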
diff --git a/libs/community/langchain_community/llms/self_hosted_hugging_face.py b/libs/community/langchain_community/llms/self_hosted_hugging_face.py
new file mode 100644
index 00000000000..465d74c6770
--- /dev/null
+++ b/libs/community/langchain_community/llms/self_hosted_hugging_face.py
@@ -0,0 +1,213 @@
+import importlib.util
+import logging
+from typing import Any, Callable, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.pydantic_v1 import Extra
+
+from langchain_community.llms.self_hosted import SelfHostedPipeline
+from langchain_community.llms.utils import enforce_stop_tokens
+
+DEFAULT_MODEL_ID = "gpt2"
+DEFAULT_TASK = "text-generation"
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+
+logger = logging.getLogger(__name__)
+
+
+def _generate_text(
+ pipeline: Any,
+ prompt: str,
+ *args: Any,
+ stop: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> str:
+ """Inference function to send to the remote hardware.
+
+ Accepts a Hugging Face pipeline (or more likely,
+ a key pointing to such a pipeline on the cluster's object store)
+ and returns generated text.
+ """
+ response = pipeline(prompt, *args, **kwargs)
+ if pipeline.task == "text-generation":
+ # Text generation return includes the starter text.
+ text = response[0]["generated_text"][len(prompt) :]
+ elif pipeline.task == "text2text-generation":
+ text = response[0]["generated_text"]
+ elif pipeline.task == "summarization":
+ text = response[0]["summary_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {pipeline.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+
+def _load_transformer(
+ model_id: str = DEFAULT_MODEL_ID,
+ task: str = DEFAULT_TASK,
+ device: int = 0,
+ model_kwargs: Optional[dict] = None,
+) -> Any:
+ """Inference function to send to the remote hardware.
+
+ Accepts a huggingface model_id and returns a pipeline for the task.
+ """
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers import pipeline as hf_pipeline
+
+ _model_kwargs = model_kwargs or {}
+ tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+ try:
+ if task == "text-generation":
+ model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
+ elif task in ("text2text-generation", "summarization"):
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
+ else:
+ raise ValueError(
+ f"Got invalid task {task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ except ImportError as e:
+ raise ValueError(
+ f"Could not load the {task} model due to missing dependencies."
+ ) from e
+
+ if importlib.util.find_spec("torch") is not None:
+ import torch
+
+ cuda_device_count = torch.cuda.device_count()
+ if device < -1 or (device >= cuda_device_count):
+ raise ValueError(
+ f"Got device=={device}, "
+ f"device is required to be within [-1, {cuda_device_count})"
+ )
+ if device < 0 and cuda_device_count > 0:
+ logger.warning(
+ "Device has %d GPUs available. "
+ "Provide device={deviceId} to `from_model_id` to use available"
+ "GPUs for execution. deviceId is -1 for CPU and "
+ "can be a positive integer associated with CUDA device id.",
+ cuda_device_count,
+ )
+
+ pipeline = hf_pipeline(
+ task=task,
+ model=model,
+ tokenizer=tokenizer,
+ device=device,
+ model_kwargs=_model_kwargs,
+ )
+ if pipeline.task not in VALID_TASKS:
+ raise ValueError(
+ f"Got invalid task {pipeline.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ return pipeline
+
+
+class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
+ """HuggingFace Pipeline API to run on self-hosted remote hardware.
+
+ Supported hardware includes auto-launched instances on AWS, GCP, Azure,
+ and Lambda, as well as servers specified
+ by IP address and SSH credentials (such as on-prem, or another cloud
+ like Paperspace, Coreweave, etc.).
+
+ To use, you should have the ``runhouse`` python package installed.
+
+ Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+ Example using from_model_id:
+ .. code-block:: python
+
+ from langchain_community.llms import SelfHostedHuggingFaceLLM
+ import runhouse as rh
+ gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
+ hf = SelfHostedHuggingFaceLLM(
+ model_id="google/flan-t5-large", task="text2text-generation",
+ hardware=gpu
+ )
+ Example passing a function that generates a pipeline (because the pipeline is not serializable):
+ .. code-block:: python
+
+ from langchain_community.llms import SelfHostedHuggingFaceLLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import runhouse as rh
+
+ def get_pipeline():
+ model_id = "gpt2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+ pipe = pipeline(
+ "text-generation", model=model, tokenizer=tokenizer
+ )
+ return pipe
+ hf = SelfHostedHuggingFaceLLM(
+ model_load_fn=get_pipeline, model_id="gpt2", hardware=gpu)
+ """
+
+ model_id: str = DEFAULT_MODEL_ID
+ """Hugging Face model_id to load the model."""
+ task: str = DEFAULT_TASK
+ """Hugging Face task ("text-generation", "text2text-generation" or
+ "summarization")."""
+ device: int = 0
+ """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
+ model_kwargs: Optional[dict] = None
+ """Keyword arguments to pass to the model."""
+ hardware: Any
+ """Remote hardware to send the inference function to."""
+ model_reqs: List[str] = ["./", "transformers", "torch"]
+ """Requirements to install on hardware to inference the model."""
+ model_load_fn: Callable = _load_transformer
+ """Function to load the model remotely on the server."""
+ inference_fn: Callable = _generate_text #: :meta private:
+ """Inference function to send to the remote hardware."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def __init__(self, **kwargs: Any):
+ """Construct the pipeline remotely using an auxiliary function.
+
+ The load function must be in global scope so it can be imported
+ and run on the server, i.e. in a module and not a REPL or closure.
+ Then, initialize the remote inference function.
+ """
+ load_fn_kwargs = {
+ "model_id": kwargs.get("model_id", DEFAULT_MODEL_ID),
+ "task": kwargs.get("task", DEFAULT_TASK),
+ "device": kwargs.get("device", 0),
+ "model_kwargs": kwargs.get("model_kwargs", None),
+ }
+ super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_id": self.model_id},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "selfhosted_huggingface_pipeline"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ return self.client(
+ pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+ )
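A sketch of running a seq2seq model through the class above, assuming ``runhouse`` is installed and the cluster is reachable; the cluster definition and model id follow the docstring example.

    import runhouse as rh

    from langchain_community.llms import SelfHostedHuggingFaceLLM

    gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")  # illustrative cluster

    llm = SelfHostedHuggingFaceLLM(
        model_id="google/flan-t5-large",
        task="text2text-generation",
        device=0,      # GPU 0 on the remote box; use -1 for CPU
        hardware=gpu,
    )
    print(llm("Translate English to German: How old are you?"))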
diff --git a/libs/community/langchain_community/llms/stochasticai.py b/libs/community/langchain_community/llms/stochasticai.py
new file mode 100644
index 00000000000..0b3637e9aff
--- /dev/null
+++ b/libs/community/langchain_community/llms/stochasticai.py
@@ -0,0 +1,137 @@
+import logging
+import time
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class StochasticAI(LLM):
+ """StochasticAI large language models.
+
+ To use, you should have the environment variable ``STOCHASTICAI_API_KEY``
+ set with your API key.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import StochasticAI
+ stochasticai = StochasticAI(api_url="")
+ """
+
+ api_url: str = ""
+ """Model name to use."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `create` call not
+ explicitly specified."""
+
+ stochasticai_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name not in all_required_field_names:
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ logger.warning(
+ f"""{field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ stochasticai_api_key = get_from_dict_or_env(
+ values, "stochasticai_api_key", "STOCHASTICAI_API_KEY"
+ )
+ values["stochasticai_api_key"] = stochasticai_api_key
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"endpoint_url": self.api_url},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "stochasticai"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to StochasticAI's complete endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = stochasticai("Tell me a joke.")
+ """
+ params = self.model_kwargs or {}
+ params = {**params, **kwargs}
+ response_post = requests.post(
+ url=self.api_url,
+ json={"prompt": prompt, "params": params},
+ headers={
+ "apiKey": f"{self.stochasticai_api_key}",
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ },
+ )
+ response_post.raise_for_status()
+ response_post_json = response_post.json()
+ completed = False
+ while not completed:
+ response_get = requests.get(
+ url=response_post_json["data"]["responseUrl"],
+ headers={
+ "apiKey": f"{self.stochasticai_api_key}",
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ },
+ )
+ response_get.raise_for_status()
+ response_get_json = response_get.json()["data"]
+ text = response_get_json.get("completion")
+ completed = text is not None
+ time.sleep(0.5)
+ text = text[0]
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
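A usage sketch for the class above; the deployment URL is a placeholder and ``STOCHASTICAI_API_KEY`` is assumed to be set. `_call` submits the prompt and then polls the returned ``responseUrl`` until a completion is available.

    from langchain_community.llms import StochasticAI

    llm = StochasticAI(
        api_url="https://example.invalid/your-deployed-model",  # placeholder URL
        model_kwargs={"temperature": 0.7, "max_tokens": 64},    # illustrative params
    )
    print(llm("Tell me a joke.", stop=["\n\n"]))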
diff --git a/libs/community/langchain_community/llms/symblai_nebula.py b/libs/community/langchain_community/llms/symblai_nebula.py
new file mode 100644
index 00000000000..afe6598f238
--- /dev/null
+++ b/libs/community/langchain_community/llms/symblai_nebula.py
@@ -0,0 +1,232 @@
+import json
+import logging
+from typing import Any, Callable, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from requests import ConnectTimeout, ReadTimeout, RequestException
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
+DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate"
+
+logger = logging.getLogger(__name__)
+
+
+class Nebula(LLM):
+ """Nebula Service models.
+
+ To use, you should have the environment variables ``NEBULA_SERVICE_URL``,
+ ``NEBULA_SERVICE_PATH`` and ``NEBULA_API_KEY`` set with your Nebula
+ Service details, or pass them as named parameters to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Nebula
+
+ nebula = Nebula(
+ nebula_service_url="NEBULA_SERVICE_URL",
+ nebula_service_path="NEBULA_SERVICE_PATH",
+ nebula_api_key="NEBULA_API_KEY",
+ )
+ """ # noqa: E501
+
+ """Key/value arguments to pass to the model. Reserved for future use"""
+ model_kwargs: Optional[dict] = None
+
+ """Optional"""
+
+ nebula_service_url: Optional[str] = None
+ nebula_service_path: Optional[str] = None
+ nebula_api_key: Optional[SecretStr] = None
+ model: Optional[str] = None
+ max_new_tokens: Optional[int] = 128
+ temperature: Optional[float] = 0.6
+ top_p: Optional[float] = 0.95
+ repetition_penalty: Optional[float] = 1.0
+ top_k: Optional[int] = 1
+ stop_sequences: Optional[List[str]] = None
+ max_retries: Optional[int] = 10
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ nebula_service_url = get_from_dict_or_env(
+ values,
+ "nebula_service_url",
+ "NEBULA_SERVICE_URL",
+ DEFAULT_NEBULA_SERVICE_URL,
+ )
+ nebula_service_path = get_from_dict_or_env(
+ values,
+ "nebula_service_path",
+ "NEBULA_SERVICE_PATH",
+ DEFAULT_NEBULA_SERVICE_PATH,
+ )
+ nebula_api_key = convert_to_secret_str(
+ get_from_dict_or_env(values, "nebula_api_key", "NEBULA_API_KEY", None)
+ )
+
+ if nebula_service_url.endswith("/"):
+ nebula_service_url = nebula_service_url[:-1]
+ if not nebula_service_path.startswith("/"):
+ nebula_service_path = "/" + nebula_service_path
+
+ values["nebula_service_url"] = nebula_service_url
+ values["nebula_service_path"] = nebula_service_path
+ values["nebula_api_key"] = nebula_api_key
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Cohere API."""
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "temperature": self.temperature,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "repetition_penalty": self.repetition_penalty,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ _model_kwargs = self.model_kwargs or {}
+ return {
+ "nebula_service_url": self.nebula_service_url,
+ "nebula_service_path": self.nebula_service_path,
+ **{"model_kwargs": _model_kwargs},
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "nebula"
+
+ def _invocation_params(
+ self, stop_sequences: Optional[List[str]], **kwargs: Any
+ ) -> dict:
+ params = self._default_params
+ if self.stop_sequences is not None and stop_sequences is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+ elif self.stop_sequences is not None:
+ params["stop_sequences"] = self.stop_sequences
+ else:
+ params["stop_sequences"] = stop_sequences
+ return {**params, **kwargs}
+
+ @staticmethod
+ def _process_response(response: Any, stop: Optional[List[str]]) -> str:
+ text = response["output"]["text"]
+ if stop:
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Nebula Service endpoint.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+ Example:
+ .. code-block:: python
+ response = nebula("Tell me a joke.")
+ """
+ params = self._invocation_params(stop, **kwargs)
+ prompt = prompt.strip()
+
+ response = completion_with_retry(
+ self,
+ prompt=prompt,
+ params=params,
+ url=f"{self.nebula_service_url}{self.nebula_service_path}",
+ )
+ _stop = params.get("stop_sequences")
+ return self._process_response(response, _stop)
+
+
+def make_request(
+ self: Nebula,
+ prompt: str,
+ url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}",
+ params: Optional[Dict] = None,
+) -> Any:
+ """Generate text from the model."""
+ params = params or {}
+ api_key = None
+ if self.nebula_api_key is not None:
+ api_key = self.nebula_api_key.get_secret_value()
+ headers = {
+ "Content-Type": "application/json",
+ "ApiKey": f"{api_key}",
+ }
+
+ body = {"prompt": prompt}
+
+ # add params to body
+ for key, value in params.items():
+ body[key] = value
+
+ # make request
+ response = requests.post(url, headers=headers, json=body)
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed with status code {response.status_code}"
+ f" and message {response.text}"
+ )
+
+ return json.loads(response.text)
+
+
+def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]:
+ min_seconds = 4
+ max_seconds = 10
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterward
+ max_retries = llm.max_retries if llm.max_retries is not None else 3
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type((RequestException, ConnectTimeout, ReadTimeout))
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ def _completion_with_retry(**_kwargs: Any) -> Any:
+ return make_request(llm, **_kwargs)
+
+ return _completion_with_retry(**kwargs)
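A usage sketch for the class above, assuming the default service URL and path and a valid API key (placeholder shown); HTTP calls are retried with exponential backoff via ``completion_with_retry``.

    from langchain_community.llms import Nebula

    llm = Nebula(
        nebula_api_key="YOUR_NEBULA_API_KEY",  # placeholder; or set NEBULA_API_KEY
        max_new_tokens=128,
        temperature=0.6,
        stop_sequences=["\n\n"],
    )
    print(llm("Summarize the main points of the conversation transcript above."))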
diff --git a/libs/community/langchain_community/llms/textgen.py b/libs/community/langchain_community/llms/textgen.py
new file mode 100644
index 00000000000..d9e569e8d0c
--- /dev/null
+++ b/libs/community/langchain_community/llms/textgen.py
@@ -0,0 +1,417 @@
+import json
+import logging
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Field
+
+logger = logging.getLogger(__name__)
+
+
+class TextGen(LLM):
+ """text-generation-webui models.
+
+ To use, you should have the text-generation-webui installed, a model loaded,
+ and --api added as a command-line option.
+
+ Suggested installation: use the one-click installer for your OS:
+ https://github.com/oobabooga/text-generation-webui#one-click-installers
+
+ Parameters below taken from text-generation-webui api example:
+ https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import TextGen
+ llm = TextGen(model_url="http://localhost:8500")
+ """
+
+ model_url: str
+ """The full URL to the textgen webui including http[s]://host:port """
+
+ preset: Optional[str] = None
+ """The preset to use in the textgen webui """
+
+ max_new_tokens: Optional[int] = 250
+ """The maximum number of tokens to generate."""
+
+ do_sample: bool = Field(True, alias="do_sample")
+ """Do sample"""
+
+ temperature: Optional[float] = 1.3
+ """Primary factor to control randomness of outputs. 0 = deterministic
+ (only the most likely token is used). Higher value = more randomness."""
+
+ top_p: Optional[float] = 0.1
+ """If not set to 1, select tokens with probabilities adding up to less than this
+ number. Higher value = higher range of possible random results."""
+
+ typical_p: Optional[float] = 1
+ """If not set to 1, select only tokens that are at least this much more likely to
+ appear than random tokens, given the prior text."""
+
+ epsilon_cutoff: Optional[float] = 0 # In units of 1e-4
+ """Epsilon cutoff"""
+
+ eta_cutoff: Optional[float] = 0 # In units of 1e-4
+ """ETA cutoff"""
+
+ repetition_penalty: Optional[float] = 1.18
+ """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
+ higher value = less repetition, lower value = more repetition."""
+
+ top_k: Optional[float] = 40
+ """Similar to top_p, but select instead only the top_k most likely tokens.
+ Higher value = higher range of possible random results."""
+
+ min_length: Optional[int] = 0
+ """Minimum generation length in tokens."""
+
+ no_repeat_ngram_size: Optional[int] = 0
+ """If not set to 0, specifies the length of token sets that are completely blocked
+ from repeating at all. Higher values = blocks larger phrases,
+ lower values = blocks words or letters from repeating.
+ Only 0 or high values are a good idea in most cases."""
+
+ num_beams: Optional[int] = 1
+ """Number of beams"""
+
+ penalty_alpha: Optional[float] = 0
+ """Penalty Alpha"""
+
+ length_penalty: Optional[float] = 1
+ """Length Penalty"""
+
+ early_stopping: bool = Field(False, alias="early_stopping")
+ """Early stopping"""
+
+ seed: int = Field(-1, alias="seed")
+ """Seed (-1 for random)"""
+
+ add_bos_token: bool = Field(True, alias="add_bos_token")
+ """Add the bos_token to the beginning of prompts.
+ Disabling this can make the replies more creative."""
+
+ truncation_length: Optional[int] = 2048
+ """Truncate the prompt up to this length. The leftmost tokens are removed if
+ the prompt exceeds this length. Most models require this to be at most 2048."""
+
+ ban_eos_token: bool = Field(False, alias="ban_eos_token")
+ """Ban the eos_token. Forces the model to never end the generation prematurely."""
+
+ skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
+ """Skip special tokens. Some specific models need this unset."""
+
+ stopping_strings: Optional[List[str]] = []
+ """A list of strings to stop generation when encountered."""
+
+ streaming: bool = False
+ """Whether to stream the results, token by token."""
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling textgen."""
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "do_sample": self.do_sample,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "typical_p": self.typical_p,
+ "epsilon_cutoff": self.epsilon_cutoff,
+ "eta_cutoff": self.eta_cutoff,
+ "repetition_penalty": self.repetition_penalty,
+ "top_k": self.top_k,
+ "min_length": self.min_length,
+ "no_repeat_ngram_size": self.no_repeat_ngram_size,
+ "num_beams": self.num_beams,
+ "penalty_alpha": self.penalty_alpha,
+ "length_penalty": self.length_penalty,
+ "early_stopping": self.early_stopping,
+ "seed": self.seed,
+ "add_bos_token": self.add_bos_token,
+ "truncation_length": self.truncation_length,
+ "ban_eos_token": self.ban_eos_token,
+ "skip_special_tokens": self.skip_special_tokens,
+ "stopping_strings": self.stopping_strings,
+ }
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_url": self.model_url}, **self._default_params}
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "textgen"
+
+ def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+ """
+ Performs a sanity check and prepares parameters in the format needed by textgen.
+
+ Args:
+ stop (Optional[List[str]]): List of stop sequences for textgen.
+
+ Returns:
+ Dictionary containing the combined parameters.
+ """
+
+ # Raise error if stop sequences are in both input and default params
+ # if self.stop and stop is not None:
+ if self.stopping_strings and stop is not None:
+ raise ValueError("`stop` found in both the input and default params.")
+
+ if self.preset is None:
+ params = self._default_params
+ else:
+ params = {"preset": self.preset}
+
+ # Use the configured stopping strings, then `stop`, defaulting to an empty list:
+ params["stopping_strings"] = self.stopping_strings or stop or []
+
+ return params
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the textgen web API and return the output.
+
+ Args:
+ prompt: The prompt to use for generation.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The generated text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import TextGen
+ llm = TextGen(model_url="http://localhost:5000")
+ llm("Write a story about llamas.")
+ """
+ if self.streaming:
+ combined_text_output = ""
+ for chunk in self._stream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ combined_text_output += chunk.text
+ result = combined_text_output
+
+ else:
+ url = f"{self.model_url}/api/v1/generate"
+ params = self._get_parameters(stop)
+ request = params.copy()
+ request["prompt"] = prompt
+ response = requests.post(url, json=request)
+
+ if response.status_code == 200:
+ result = response.json()["results"][0]["text"]
+ else:
+ print(f"ERROR: response: {response}")
+ result = ""
+
+ return result
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the textgen web API and return the output.
+
+ Args:
+ prompt: The prompt to use for generation.
+ stop: A list of strings to stop generation when encountered.
+
+ Returns:
+ The generated text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import TextGen
+ llm = TextGen(model_url="http://localhost:5000")
+ llm("Write a story about llamas.")
+ """
+ if self.streaming:
+ combined_text_output = ""
+ async for chunk in self._astream(
+ prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ combined_text_output += chunk.text
+ result = combined_text_output
+
+ else:
+ url = f"{self.model_url}/api/v1/generate"
+ params = self._get_parameters(stop)
+ request = params.copy()
+ request["prompt"] = prompt
+ response = requests.post(url, json=request)
+
+ if response.status_code == 200:
+ result = response.json()["results"][0]["text"]
+ else:
+ print(f"ERROR: response: {response}")
+ result = ""
+
+ return result
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Yields results objects as they are generated in real time.
+
+ It also calls the callback manager's on_llm_new_token event with
+ similar parameters to the OpenAI LLM class method of the same name.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ A generator representing the stream of tokens being generated.
+
+ Yields:
+ A dictionary-like object containing a string token and metadata.
+ See text-generation-webui docs and below for more.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import TextGen
+ llm = TextGen(
+ model_url = "ws://localhost:5005"
+ streaming=True
+ )
+ for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+ stop=["'","\n"]):
+ print(chunk, end='', flush=True)
+
+ """
+ try:
+ import websocket
+ except ImportError:
+ raise ImportError(
+ "The `websocket-client` package is required for streaming."
+ )
+
+ params = {**self._get_parameters(stop), **kwargs}
+
+ url = f"{self.model_url}/api/v1/stream"
+
+ request = params.copy()
+ request["prompt"] = prompt
+
+ websocket_client = websocket.WebSocket()
+
+ websocket_client.connect(url)
+
+ websocket_client.send(json.dumps(request))
+
+ while True:
+ result = websocket_client.recv()
+ result = json.loads(result)
+
+ if result["event"] == "text_stream":
+ chunk = GenerationChunk(
+ text=result["text"],
+ generation_info=None,
+ )
+ yield chunk
+ elif result["event"] == "stream_end":
+ websocket_client.close()
+ return
+
+ if run_manager:
+ run_manager.on_llm_new_token(token=chunk.text)
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ """Yields results objects as they are generated in real time.
+
+ It also calls the callback manager's on_llm_new_token event with
+ similar parameters to the OpenAI LLM class method of the same name.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ A generator representing the stream of tokens being generated.
+
+ Yields:
+ A dictionary-like object containing a string token and metadata.
+ See text-generation-webui docs and below for more.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import TextGen
+ llm = TextGen(
+ model_url = "ws://localhost:5005"
+ streaming=True
+ )
+ for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+ stop=["'","\n"]):
+ print(chunk, end='', flush=True)
+
+ """
+ try:
+ import websocket
+ except ImportError:
+ raise ImportError(
+ "The `websocket-client` package is required for streaming."
+ )
+
+ params = {**self._get_parameters(stop), **kwargs}
+
+ url = f"{self.model_url}/api/v1/stream"
+
+ request = params.copy()
+ request["prompt"] = prompt
+
+ websocket_client = websocket.WebSocket()
+
+ websocket_client.connect(url)
+
+ websocket_client.send(json.dumps(request))
+
+ while True:
+ result = websocket_client.recv()
+ result = json.loads(result)
+
+ if result["event"] == "text_stream":
+ chunk = GenerationChunk(
+ text=result["text"],
+ generation_info=None,
+ )
+ yield chunk
+ elif result["event"] == "stream_end":
+ websocket_client.close()
+ return
+
+ if run_manager:
+ await run_manager.on_llm_new_token(token=chunk.text)
diff --git a/libs/community/langchain_community/llms/titan_takeoff.py b/libs/community/langchain_community/llms/titan_takeoff.py
new file mode 100644
index 00000000000..103a81b59c6
--- /dev/null
+++ b/libs/community/langchain_community/llms/titan_takeoff.py
@@ -0,0 +1,161 @@
+from typing import Any, Iterator, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from requests.exceptions import ConnectionError
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class TitanTakeoff(LLM):
+ """Wrapper around Titan Takeoff APIs."""
+
+ base_url: str = "http://localhost:8000"
+ """Specifies the baseURL to use for the Titan Takeoff API.
+ Default = http://localhost:8000.
+ """
+
+ generate_max_length: int = 128
+ """Maximum generation length. Default = 128."""
+
+ sampling_topk: int = 1
+ """Sample predictions from the top K most probable candidates. Default = 1."""
+
+ sampling_topp: float = 1.0
+ """Sample from predictions whose cumulative probability exceeds this value.
+ Default = 1.0.
+ """
+
+ sampling_temperature: float = 1.0
+ """Sample with randomness. Bigger temperatures are associated with
+ more randomness and 'creativity'. Default = 1.0.
+ """
+
+ repetition_penalty: float = 1.0
+ """Penalise the generation of tokens that have been generated before.
+ Set to > 1 to penalize. Default = 1 (no penalty).
+ """
+
+ no_repeat_ngram_size: int = 0
+ """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""
+
+ streaming: bool = False
+ """Whether to stream the output. Default = False."""
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling Titan Takeoff Server."""
+ params = {
+ "generate_max_length": self.generate_max_length,
+ "sampling_topk": self.sampling_topk,
+ "sampling_topp": self.sampling_topp,
+ "sampling_temperature": self.sampling_temperature,
+ "repetition_penalty": self.repetition_penalty,
+ "no_repeat_ngram_size": self.no_repeat_ngram_size,
+ }
+ return params
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "titan_takeoff"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Titan Takeoff generate endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "What is the capital of the United Kingdom?"
+ response = model(prompt)
+
+ """
+ try:
+ if self.streaming:
+ text_output = ""
+ for chunk in self._stream(
+ prompt=prompt,
+ stop=stop,
+ run_manager=run_manager,
+ ):
+ text_output += chunk.text
+ return text_output
+
+ url = f"{self.base_url}/generate"
+ params = {"text": prompt, **self._default_params}
+
+ response = requests.post(url, json=params)
+ response.raise_for_status()
+ response.encoding = "utf-8"
+ text = ""
+
+ if "message" in response.json():
+ text = response.json()["message"]
+ else:
+ raise ValueError("Something went wrong.")
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
+ except ConnectionError:
+ raise ConnectionError(
+ "Could not connect to Titan Takeoff server. \
+ Please make sure that the server is running."
+ )
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Call out to Titan Takeoff stream endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            Dictionary-like objects, each containing a string token.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "What is the capital of the United Kingdom?"
+ response = model(prompt)
+
+ """
+ url = f"{self.base_url}/generate_stream"
+ params = {"text": prompt, **self._default_params}
+
+ response = requests.post(url, json=params, stream=True)
+ response.encoding = "utf-8"
+ for text in response.iter_content(chunk_size=1, decode_unicode=True):
+ if text:
+ chunk = GenerationChunk(text=text)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(token=chunk.text)
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"base_url": self.base_url, **{}, **self._default_params}
diff --git a/libs/community/langchain_community/llms/titan_takeoff_pro.py b/libs/community/langchain_community/llms/titan_takeoff_pro.py
new file mode 100644
index 00000000000..52679cc3e84
--- /dev/null
+++ b/libs/community/langchain_community/llms/titan_takeoff_pro.py
@@ -0,0 +1,215 @@
+from typing import Any, Iterator, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from requests.exceptions import ConnectionError
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class TitanTakeoffPro(LLM):
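+    """Wrapper around Titan Takeoff Pro APIs.
+
+    Example (illustrative; assumes a Titan Takeoff Pro server is running
+    locally on the default port):
+        .. code-block:: python
+
+            from langchain_community.llms import TitanTakeoffPro
+
+            llm = TitanTakeoffPro(base_url="http://localhost:3000")
+            print(llm("What is the capital of the United Kingdom?"))
+    """
+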
+ base_url: Optional[str] = "http://localhost:3000"
+ """Specifies the baseURL to use for the Titan Takeoff Pro API.
+ Default = http://localhost:3000.
+ """
+
+ max_new_tokens: Optional[int] = None
+ """Maximum tokens generated."""
+
+ min_new_tokens: Optional[int] = None
+ """Minimum tokens generated."""
+
+ sampling_topk: Optional[int] = None
+ """Sample predictions from the top K most probable candidates."""
+
+ sampling_topp: Optional[float] = None
+ """Sample from predictions whose cumulative probability exceeds this value.
+ """
+
+ sampling_temperature: Optional[float] = None
+ """Sample with randomness. Bigger temperatures are associated with
+ more randomness and 'creativity'.
+ """
+
+ repetition_penalty: Optional[float] = None
+ """Penalise the generation of tokens that have been generated before.
+ Set to > 1 to penalize.
+ """
+
+ regex_string: Optional[str] = None
+ """A regex string for constrained generation."""
+
+ no_repeat_ngram_size: Optional[int] = None
+ """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""
+
+ streaming: bool = False
+ """Whether to stream the output. Default = False."""
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling Titan Takeoff Server (Pro)."""
+ return {
+ **(
+ {"regex_string": self.regex_string}
+ if self.regex_string is not None
+ else {}
+ ),
+ **(
+ {"sampling_temperature": self.sampling_temperature}
+ if self.sampling_temperature is not None
+ else {}
+ ),
+ **(
+ {"sampling_topp": self.sampling_topp}
+ if self.sampling_topp is not None
+ else {}
+ ),
+ **(
+ {"repetition_penalty": self.repetition_penalty}
+ if self.repetition_penalty is not None
+ else {}
+ ),
+ **(
+ {"max_new_tokens": self.max_new_tokens}
+ if self.max_new_tokens is not None
+ else {}
+ ),
+ **(
+ {"min_new_tokens": self.min_new_tokens}
+ if self.min_new_tokens is not None
+ else {}
+ ),
+ **(
+ {"sampling_topk": self.sampling_topk}
+ if self.sampling_topk is not None
+ else {}
+ ),
+ **(
+ {"no_repeat_ngram_size": self.no_repeat_ngram_size}
+ if self.no_repeat_ngram_size is not None
+ else {}
+ ),
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "titan_takeoff_pro"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Titan Takeoff (Pro) generate endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "What is the capital of the United Kingdom?"
+ response = model(prompt)
+
+ """
+ try:
+ if self.streaming:
+ text_output = ""
+ for chunk in self._stream(
+ prompt=prompt,
+ stop=stop,
+ run_manager=run_manager,
+ ):
+ text_output += chunk.text
+ return text_output
+ url = f"{self.base_url}/generate"
+ params = {"text": prompt, **self._default_params}
+
+ response = requests.post(url, json=params)
+ response.raise_for_status()
+ response.encoding = "utf-8"
+
+ text = ""
+ if "text" in response.json():
+ text = response.json()["text"]
+                text = text.replace("</s>", "")  # strip the end-of-sequence token
+ else:
+ raise ValueError("Something went wrong.")
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
+ except ConnectionError:
+ raise ConnectionError(
+ "Could not connect to Titan Takeoff (Pro) server. \
+ Please make sure that the server is running."
+ )
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Call out to Titan Takeoff (Pro) stream endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            Dictionary-like objects, each containing a string token.
+
+ Example:
+ .. code-block:: python
+
+ prompt = "What is the capital of the United Kingdom?"
+ response = model(prompt)
+
+ """
+ url = f"{self.base_url}/generate_stream"
+ params = {"text": prompt, **self._default_params}
+
+ response = requests.post(url, json=params, stream=True)
+ response.encoding = "utf-8"
+ buffer = ""
+ for text in response.iter_content(chunk_size=1, decode_unicode=True):
+ buffer += text
+ if "data:" in buffer:
+ # Remove the first instance of "data:" from the buffer.
+ if buffer.startswith("data:"):
+ buffer = ""
+ if len(buffer.split("data:", 1)) == 2:
+ content, _ = buffer.split("data:", 1)
+ buffer = content.rstrip("\n")
+ # Trim the buffer to only have content after the "data:" part.
+ if buffer: # Ensure that there's content to process.
+ chunk = GenerationChunk(text=buffer)
+ buffer = "" # Reset buffer for the next set of data.
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(token=chunk.text)
+
+ # Yield any remaining content in the buffer.
+ if buffer:
+            chunk = GenerationChunk(text=buffer.replace("</s>", ""))
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(token=chunk.text)
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"base_url": self.base_url, **{}, **self._default_params}
diff --git a/libs/community/langchain_community/llms/together.py b/libs/community/langchain_community/llms/together.py
new file mode 100644
index 00000000000..08e7c79f35c
--- /dev/null
+++ b/libs/community/langchain_community/llms/together.py
@@ -0,0 +1,206 @@
+"""Wrapper around Together AI's Completion API."""
+import logging
+from typing import Any, Dict, List, Optional
+
+from aiohttp import ClientSession
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.utilities.requests import Requests
+
+logger = logging.getLogger(__name__)
+
+
+class Together(LLM):
+ """Wrapper around Together AI models.
+
+ To use, you'll need an API key which you can find here:
+ https://api.together.xyz/settings/api-keys. This can be passed in as init param
+ ``together_api_key`` or set as environment variable ``TOGETHER_API_KEY``.
+
+ Together AI API reference: https://docs.together.ai/reference/inference
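+
+    Example (illustrative; the model name is a placeholder, see
+    https://docs.together.ai/docs/inference-models for available models):
+        .. code-block:: python
+
+            from langchain_community.llms import Together
+
+            # The API key can also be set via the TOGETHER_API_KEY env var.
+            together = Together(
+                model="togethercomputer/RedPajama-INCITE-7B-Base",
+                together_api_key="***",
+                max_tokens=128,
+            )
+            print(together("What is the capital of France?"))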
+ """
+
+ base_url: str = "https://api.together.xyz/inference"
+ """Base inference API URL."""
+ together_api_key: SecretStr
+ """Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
+ model: str
+ """Model name. Available models listed here:
+ https://docs.together.ai/docs/inference-models
+ """
+ temperature: Optional[float] = None
+ """Model temperature."""
+ top_p: Optional[float] = None
+ """Used to dynamically adjust the number of choices for each predicted token based
+ on the cumulative probabilities. A value of 1 will always yield the same
+ output. A temperature less than 1 favors more correctness and is appropriate
+ for question answering or summarization. A value greater than 1 introduces more
+ randomness in the output.
+ """
+ top_k: Optional[int] = None
+ """Used to limit the number of choices for the next predicted word or token. It
+ specifies the maximum number of tokens to consider at each step, based on their
+ probability of occurrence. This technique helps to speed up the generation
+ process and can improve the quality of the generated text by focusing on the
+ most likely options.
+ """
+ max_tokens: Optional[int] = None
+ """The maximum number of tokens to generate."""
+ repetition_penalty: Optional[float] = None
+ """A number that controls the diversity of generated text by reducing the
+ likelihood of repeated sequences. Higher values decrease repetition.
+ """
+ logprobs: Optional[int] = None
+ """An integer that specifies how many top token log probabilities are included in
+ the response for each token generation step.
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["together_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
+ )
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of model."""
+ return "together"
+
+ def _format_output(self, output: dict) -> str:
+ return output["output"]["choices"][0]["text"]
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain/{__version__}"
+
+ @property
+ def default_params(self) -> Dict[str, Any]:
+ return {
+ "model": self.model,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "max_tokens": self.max_tokens,
+ "repetition_penalty": self.repetition_penalty,
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Together's text generation endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+            The string generated by the model.
+ """
+
+ headers = {
+ "Authorization": f"Bearer {self.together_api_key.get_secret_value()}",
+ "Content-Type": "application/json",
+ }
+ stop_to_use = stop[0] if stop and len(stop) == 1 else stop
+ payload: Dict[str, Any] = {
+ **self.default_params,
+ "prompt": prompt,
+ "stop": stop_to_use,
+ **kwargs,
+ }
+
+ # filter None values to not pass them to the http payload
+ payload = {k: v for k, v in payload.items() if v is not None}
+ request = Requests(headers=headers)
+ response = request.post(url=self.base_url, data=payload)
+
+ if response.status_code >= 500:
+ raise Exception(f"Together Server: Error {response.status_code}")
+ elif response.status_code >= 400:
+ raise ValueError(f"Together received an invalid payload: {response.text}")
+ elif response.status_code != 200:
+ raise Exception(
+ f"Together returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+
+ data = response.json()
+ if data.get("status") != "finished":
+ err_msg = data.get("error", "Undefined Error")
+ raise Exception(err_msg)
+
+ output = self._format_output(data)
+
+ return output
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call Together model to get predictions based on the prompt.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ The string generated by the model.
+ """
+ headers = {
+ "Authorization": f"Bearer {self.together_api_key.get_secret_value()}",
+ "Content-Type": "application/json",
+ }
+ stop_to_use = stop[0] if stop and len(stop) == 1 else stop
+ payload: Dict[str, Any] = {
+ **self.default_params,
+ "prompt": prompt,
+ "stop": stop_to_use,
+ **kwargs,
+ }
+
+ # filter None values to not pass them to the http payload
+ payload = {k: v for k, v in payload.items() if v is not None}
+ async with ClientSession() as session:
+ async with session.post(
+ self.base_url, json=payload, headers=headers
+ ) as response:
+ if response.status >= 500:
+ raise Exception(f"Together Server: Error {response.status}")
+ elif response.status >= 400:
+ raise ValueError(
+ f"Together received an invalid payload: {response.text}"
+ )
+ elif response.status != 200:
+ raise Exception(
+ f"Together returned an unexpected response with status "
+ f"{response.status}: {response.text}"
+ )
+
+ response_json = await response.json()
+
+ if response_json.get("status") != "finished":
+ err_msg = response_json.get("error", "Undefined Error")
+ raise Exception(err_msg)
+
+ output = self._format_output(response_json)
+ return output
diff --git a/libs/community/langchain_community/llms/tongyi.py b/libs/community/langchain_community/llms/tongyi.py
new file mode 100644
index 00000000000..586c6afa3f5
--- /dev/null
+++ b/libs/community/langchain_community/llms/tongyi.py
@@ -0,0 +1,277 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from requests.exceptions import HTTPError
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]:
+ min_seconds = 1
+ max_seconds = 4
+ # Wait 2^x * 1 second between each retry starting with
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(llm.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(retry_if_exception_type(HTTPError)),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+
+def generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ def _generate_with_retry(**_kwargs: Any) -> Any:
+ resp = llm.client.call(**_kwargs)
+ if resp.status_code == 200:
+ return resp
+ elif resp.status_code in [400, 401]:
+ raise ValueError(
+ f"status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}"
+ )
+ else:
+ raise HTTPError(
+ f"HTTP error occurred: status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}",
+ response=resp,
+ )
+
+ return _generate_with_retry(**kwargs)
+
+
+def stream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = _create_retry_decorator(llm)
+
+ @retry_decorator
+ def _stream_generate_with_retry(**_kwargs: Any) -> Any:
+ stream_resps = []
+ resps = llm.client.call(**_kwargs)
+ for resp in resps:
+ if resp.status_code == 200:
+ stream_resps.append(resp)
+ elif resp.status_code in [400, 401]:
+ raise ValueError(
+ f"status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}"
+ )
+ else:
+ raise HTTPError(
+ f"HTTP error occurred: status_code: {resp.status_code} \n "
+ f"code: {resp.code} \n message: {resp.message}",
+ response=resp,
+ )
+ return stream_resps
+
+ return _stream_generate_with_retry(**kwargs)
+
+
+class Tongyi(LLM):
+ """Tongyi Qwen large language models.
+
+ To use, you should have the ``dashscope`` python package installed, and the
+ environment variable ``DASHSCOPE_API_KEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Tongyi
+            tongyi = Tongyi()
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {"dashscope_api_key": "DASHSCOPE_API_KEY"}
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ client: Any #: :meta private:
+    model_name: str = "qwen-plus-v1"
+    """Model name to use."""
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ top_p: float = 0.8
+ """Total probability mass of tokens to consider at each step."""
+
+ dashscope_api_key: Optional[str] = None
+ """Dashscope api key provide by alicloud."""
+
+ n: int = 1
+ """How many completions to generate for each prompt."""
+
+ streaming: bool = False
+ """Whether to stream the results or not."""
+
+ max_retries: int = 10
+ """Maximum number of retries to make when generating."""
+
+ prefix_messages: List = Field(default_factory=list)
+ """Series of messages for Chat input."""
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "tongyi"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ get_from_dict_or_env(values, "dashscope_api_key", "DASHSCOPE_API_KEY")
+ try:
+ import dashscope
+ except ImportError:
+ raise ImportError(
+ "Could not import dashscope python package. "
+ "Please install it with `pip install dashscope`."
+ )
+ try:
+ values["client"] = dashscope.Generation
+ except AttributeError:
+ raise ValueError(
+ "`dashscope` has no `Generation` attribute, this is likely "
+ "due to an old version of the dashscope package. Try upgrading it "
+ "with `pip install --upgrade dashscope`."
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling OpenAI API."""
+ normal_params = {
+ "top_p": self.top_p,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Tongyi's generate endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = tongyi("Tell me a joke.")
+ """
+ params: Dict[str, Any] = {
+ **{"model": self.model_name},
+ **self._default_params,
+ **kwargs,
+ }
+
+ completion = generate_with_retry(
+ self,
+ prompt=prompt,
+ **params,
+ )
+ return completion["output"]["text"]
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ generations = []
+ params: Dict[str, Any] = {
+ **{"model": self.model_name},
+ **self._default_params,
+ **kwargs,
+ }
+ if self.streaming:
+ if len(prompts) > 1:
+ raise ValueError("Cannot stream results with multiple prompts.")
+ params["stream"] = True
+ temp = ""
+ for stream_resp in stream_generate_with_retry(
+ self, prompt=prompts[0], **params
+ ):
+ if run_manager:
+ stream_resp_text = stream_resp["output"]["text"]
+ stream_resp_text = stream_resp_text.replace(temp, "")
+                    # As of September 20, 2023, Alibaba Cloud's streaming API
+                    # returns the accumulated output of all previous rounds in
+                    # each response (future API updates may change this), so the
+                    # previously seen text is stripped before emitting the token.
+                    run_manager.on_llm_new_token(stream_resp_text)
+                    # Streaming relies primarily on the callback's
+                    # "on_llm_new_token" method.
+ temp = stream_resp["output"]["text"]
+
+ generations.append(
+ [
+ Generation(
+ text=stream_resp["output"]["text"],
+ generation_info=dict(
+ finish_reason=stream_resp["output"]["finish_reason"],
+ ),
+ )
+ ]
+ )
+ generations.reverse()
+            # The OpenAI implementation passes LLMResult a 1x1x1 "generations"
+            # list (even in non-streaming mode). Because Alibaba Cloud's
+            # streaming API (as of September 20, 2023; future API updates may
+            # vary) includes the output of previous rounds in each response,
+            # reversing this "generations" list suffices. This is the smallest
+            # change to the surrounding code while remaining easy to modify
+            # later, at the cost of slightly higher memory use.
+ else:
+ for prompt in prompts:
+ completion = generate_with_retry(
+ self,
+ prompt=prompt,
+ **params,
+ )
+ generations.append(
+ [
+ Generation(
+ text=completion["output"]["text"],
+ generation_info=dict(
+ finish_reason=completion["output"]["finish_reason"],
+ ),
+ )
+ ]
+ )
+ return LLMResult(generations=generations)
diff --git a/libs/community/langchain_community/llms/utils.py b/libs/community/langchain_community/llms/utils.py
new file mode 100644
index 00000000000..b69c759eeae
--- /dev/null
+++ b/libs/community/langchain_community/llms/utils.py
@@ -0,0 +1,8 @@
+"""Common utility functions for LLM APIs."""
+import re
+from typing import List
+
+
+def enforce_stop_tokens(text: str, stop: List[str]) -> str:
+ """Cut off the text as soon as any stop words occur."""
+ return re.split("|".join(stop), text, maxsplit=1)[0]
diff --git a/libs/community/langchain_community/llms/vertexai.py b/libs/community/langchain_community/llms/vertexai.py
new file mode 100644
index 00000000000..77b3c70f9c8
--- /dev/null
+++ b/libs/community/langchain_community/llms/vertexai.py
@@ -0,0 +1,491 @@
+from __future__ import annotations
+
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+
+from langchain_community.utilities.vertexai import (
+ create_retry_decorator,
+ get_client_info,
+ init_vertexai,
+ raise_vertex_import_error,
+)
+
+if TYPE_CHECKING:
+ from google.cloud.aiplatform.gapic import (
+ PredictionServiceAsyncClient,
+ PredictionServiceClient,
+ )
+ from google.cloud.aiplatform.models import Prediction
+ from google.protobuf.struct_pb2 import Value
+ from vertexai.language_models._language_models import (
+ TextGenerationResponse,
+ _LanguageModel,
+ )
+
+
+def _response_to_generation(
+ response: TextGenerationResponse,
+) -> GenerationChunk:
+ """Convert a stream response to a generation chunk."""
+ try:
+ generation_info = {
+ "is_blocked": response.is_blocked,
+ "safety_attributes": response.safety_attributes,
+ }
+ except Exception:
+ generation_info = None
+ return GenerationChunk(text=response.text, generation_info=generation_info)
+
+
+def is_codey_model(model_name: str) -> bool:
+ """Returns True if the model name is a Codey model.
+
+ Args:
+ model_name: The model name to check.
+
+ Returns: True if the model name is a Codey model.
+ """
+ return "code" in model_name
+
+
+def completion_with_retry(
+ llm: VertexAI,
+ *args: Any,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return llm.client.predict(*args, **kwargs)
+
+ return _completion_with_retry(*args, **kwargs)
+
+
+def stream_completion_with_retry(
+ llm: VertexAI,
+ *args: Any,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = create_retry_decorator(
+ llm, max_retries=llm.max_retries, run_manager=run_manager
+ )
+
+ @retry_decorator
+ def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return llm.client.predict_streaming(*args, **kwargs)
+
+ return _completion_with_retry(*args, **kwargs)
+
+
+async def acompletion_with_retry(
+ llm: VertexAI,
+ *args: Any,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
+
+ @retry_decorator
+ async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
+ return await llm.client.predict_async(*args, **kwargs)
+
+ return await _acompletion_with_retry(*args, **kwargs)
+
+
+class _VertexAIBase(BaseModel):
+ project: Optional[str] = None
+ "The default GCP project to use when making Vertex API calls."
+ location: str = "us-central1"
+ "The default location to use when making API calls."
+ request_parallelism: int = 5
+ "The amount of parallelism allowed for requests issued to VertexAI models. "
+ "Default is 5."
+ max_retries: int = 6
+ """The maximum number of retries to make when generating."""
+ task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
+ stop: Optional[List[str]] = None
+ "Optional list of stop words to use when generating."
+ model_name: Optional[str] = None
+ "Underlying model name."
+
+ @classmethod
+ def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+ if cls.task_executor is None:
+ cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+ return cls.task_executor
+
+
+class _VertexAICommon(_VertexAIBase):
+ client: "_LanguageModel" = None #: :meta private:
+ client_preview: "_LanguageModel" = None #: :meta private:
+ model_name: str
+ "Underlying model name."
+ temperature: float = 0.0
+ "Sampling temperature, it controls the degree of randomness in token selection."
+ max_output_tokens: int = 128
+ "Token limit determines the maximum amount of text output from one prompt."
+ top_p: float = 0.95
+ "Tokens are selected from most probable to least until the sum of their "
+ "probabilities equals the top-p value. Top-p is ignored for Codey models."
+ top_k: int = 40
+ "How the model selects tokens for output, the next token is selected from "
+ "among the top-k most probable tokens. Top-k is ignored for Codey models."
+ credentials: Any = Field(default=None, exclude=True)
+ "The default custom credentials (google.auth.credentials.Credentials) to use "
+ "when making API calls. If not provided, credentials will be ascertained from "
+ "the environment."
+ n: int = 1
+ """How many completions to generate for each prompt."""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+
+ @property
+ def _llm_type(self) -> str:
+ return "vertexai"
+
+ @property
+ def is_codey_model(self) -> bool:
+ return is_codey_model(self.model_name)
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get the identifying parameters."""
+ return {**{"model_name": self.model_name}, **self._default_params}
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ params = {
+ "temperature": self.temperature,
+ "max_output_tokens": self.max_output_tokens,
+ "candidate_count": self.n,
+ }
+ if not self.is_codey_model:
+ params.update(
+ {
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ }
+ )
+ return params
+
+ @classmethod
+ def _try_init_vertexai(cls, values: Dict) -> None:
+ allowed_params = ["project", "location", "credentials"]
+ params = {k: v for k, v in values.items() if k in allowed_params}
+ init_vertexai(**params)
+ return None
+
+ def _prepare_params(
+ self,
+ stop: Optional[List[str]] = None,
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> dict:
+ stop_sequences = stop or self.stop
+ params_mapping = {"n": "candidate_count"}
+ params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
+ params = {**self._default_params, "stop_sequences": stop_sequences, **params}
+ if stream or self.streaming:
+ params.pop("candidate_count")
+ return params
+
+
+class VertexAI(_VertexAICommon, BaseLLM):
+ """Google Vertex AI large language models."""
+
+ model_name: str = "text-bison"
+ "The name of the Vertex AI large language model."
+ tuned_model_name: Optional[str] = None
+ "The name of a tuned model. If provided, model_name is ignored."
+
+ @classmethod
+    def is_lc_serializable(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "llms", "vertexai"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ cls._try_init_vertexai(values)
+ tuned_model_name = values.get("tuned_model_name")
+ model_name = values["model_name"]
+ try:
+ from vertexai.language_models import (
+ CodeGenerationModel,
+ TextGenerationModel,
+ )
+ from vertexai.preview.language_models import (
+ CodeGenerationModel as PreviewCodeGenerationModel,
+ )
+ from vertexai.preview.language_models import (
+ TextGenerationModel as PreviewTextGenerationModel,
+ )
+
+ if is_codey_model(model_name):
+ model_cls = CodeGenerationModel
+ preview_model_cls = PreviewCodeGenerationModel
+ else:
+ model_cls = TextGenerationModel
+ preview_model_cls = PreviewTextGenerationModel
+
+ if tuned_model_name:
+ values["client"] = model_cls.get_tuned_model(tuned_model_name)
+ values["client_preview"] = preview_model_cls.get_tuned_model(
+ tuned_model_name
+ )
+ else:
+ values["client"] = model_cls.from_pretrained(model_name)
+ values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+
+ except ImportError:
+ raise_vertex_import_error()
+
+ if values["streaming"] and values["n"] > 1:
+ raise ValueError("Only one candidate can be generated with streaming!")
+ return values
+
+ def get_num_tokens(self, text: str) -> int:
+ """Get the number of tokens present in the text.
+
+ Useful for checking if an input will fit in a model's context window.
+
+ Args:
+ text: The string input to tokenize.
+
+ Returns:
+ The integer number of tokens in the text.
+ """
+ try:
+ result = self.client_preview.count_tokens([text])
+ except AttributeError:
+ raise_vertex_import_error()
+
+ return result.total_tokens
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ should_stream = stream if stream is not None else self.streaming
+ params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
+ generations = []
+ for prompt in prompts:
+ if should_stream:
+ generation = GenerationChunk(text="")
+ for chunk in self._stream(
+ prompt, stop=stop, run_manager=run_manager, **kwargs
+ ):
+ generation += chunk
+ generations.append([generation])
+ else:
+ res = completion_with_retry(
+ self, prompt, run_manager=run_manager, **params
+ )
+ generations.append([_response_to_generation(r) for r in res.candidates])
+ return LLMResult(generations=generations)
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ params = self._prepare_params(stop=stop, **kwargs)
+ generations = []
+ for prompt in prompts:
+ res = await acompletion_with_retry(
+ self, prompt, run_manager=run_manager, **params
+ )
+ generations.append([_response_to_generation(r) for r in res.candidates])
+ return LLMResult(generations=generations)
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = self._prepare_params(stop=stop, stream=True, **kwargs)
+ for stream_resp in stream_completion_with_retry(
+ self, prompt, run_manager=run_manager, **params
+ ):
+ chunk = _response_to_generation(stream_resp)
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ chunk.text,
+ chunk=chunk,
+ verbose=self.verbose,
+ )
+
+
+class VertexAIModelGarden(_VertexAIBase, BaseLLM):
+ """Large language models served from Vertex AI Model Garden."""
+
+ client: "PredictionServiceClient" = None #: :meta private:
+ async_client: "PredictionServiceAsyncClient" = None #: :meta private:
+ endpoint_id: str
+ "A name of an endpoint where the model has been deployed."
+ allowed_model_args: Optional[List[str]] = None
+ "Allowed optional args to be passed to the model."
+ prompt_arg: str = "prompt"
+ result_arg: Optional[str] = "generated_text"
+ "Set result_arg to None if output of the model is expected to be a string."
+ "Otherwise, if it's a dict, provided an argument that contains the result."
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud.aiplatform.gapic import (
+ PredictionServiceAsyncClient,
+ PredictionServiceClient,
+ )
+ except ImportError:
+ raise_vertex_import_error()
+
+ if not values["project"]:
+ raise ValueError(
+ "A GCP project should be provided to run inference on Model Garden!"
+ )
+
+ client_options = ClientOptions(
+ api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
+ )
+ client_info = get_client_info(module="vertex-ai-model-garden")
+ values["client"] = PredictionServiceClient(
+ client_options=client_options, client_info=client_info
+ )
+ values["async_client"] = PredictionServiceAsyncClient(
+ client_options=client_options, client_info=client_info
+ )
+ return values
+
+ @property
+ def endpoint_path(self) -> str:
+ return self.client.endpoint_path(
+ project=self.project,
+ location=self.location,
+ endpoint=self.endpoint_id,
+ )
+
+ @property
+ def _llm_type(self) -> str:
+ return "vertexai_model_garden"
+
+ def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
+ try:
+ from google.protobuf import json_format
+ from google.protobuf.struct_pb2 import Value
+ except ImportError:
+ raise ImportError(
+ "protobuf package not found, please install it with"
+ " `pip install protobuf`"
+ )
+ instances = []
+ for prompt in prompts:
+ if self.allowed_model_args:
+ instance = {
+ k: v for k, v in kwargs.items() if k in self.allowed_model_args
+ }
+ else:
+ instance = {}
+ instance[self.prompt_arg] = prompt
+ instances.append(instance)
+
+ predict_instances = [
+ json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+ ]
+ return predict_instances
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+ instances = self._prepare_request(prompts, **kwargs)
+ response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+ return self._parse_response(response)
+
+ def _parse_response(self, predictions: "Prediction") -> LLMResult:
+ generations: List[List[Generation]] = []
+ for result in predictions.predictions:
+ generations.append(
+ [
+ Generation(text=self._parse_prediction(prediction))
+ for prediction in result
+ ]
+ )
+ return LLMResult(generations=generations)
+
+ def _parse_prediction(self, prediction: Any) -> str:
+ if isinstance(prediction, str):
+ return prediction
+
+ if self.result_arg:
+ try:
+ return prediction[self.result_arg]
+ except KeyError:
+ if isinstance(prediction, str):
+ error_desc = (
+ "Provided non-None `result_arg` (result_arg="
+ f"{self.result_arg}). But got prediction of type "
+ f"{type(prediction)} instead of dict. Most probably, you"
+ "need to set `result_arg=None` during VertexAIModelGarden "
+ "initialization."
+ )
+ raise ValueError(error_desc)
+ else:
+ raise ValueError(f"{self.result_arg} key not found in prediction!")
+
+ return prediction
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+ instances = self._prepare_request(prompts, **kwargs)
+ response = await self.async_client.predict(
+ endpoint=self.endpoint_path, instances=instances
+ )
+ return self._parse_response(response)
diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py
new file mode 100644
index 00000000000..5710cf48723
--- /dev/null
+++ b/libs/community/langchain_community/llms/vllm.py
@@ -0,0 +1,176 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import Field, root_validator
+
+from langchain_community.llms.openai import BaseOpenAI
+from langchain_community.utils.openai import is_openai_v1
+
+
+class VLLM(BaseLLM):
+ """VLLM language model."""
+
+ model: str = ""
+ """The name or path of a HuggingFace Transformers model."""
+
+ tensor_parallel_size: Optional[int] = 1
+ """The number of GPUs to use for distributed execution with tensor parallelism."""
+
+ trust_remote_code: Optional[bool] = False
+ """Trust remote code (e.g., from HuggingFace) when downloading the model
+ and tokenizer."""
+
+ n: int = 1
+ """Number of output sequences to return for the given prompt."""
+
+ best_of: Optional[int] = None
+ """Number of output sequences that are generated from the prompt."""
+
+ presence_penalty: float = 0.0
+ """Float that penalizes new tokens based on whether they appear in the
+ generated text so far"""
+
+ frequency_penalty: float = 0.0
+ """Float that penalizes new tokens based on their frequency in the
+ generated text so far"""
+
+ temperature: float = 1.0
+ """Float that controls the randomness of the sampling."""
+
+ top_p: float = 1.0
+ """Float that controls the cumulative probability of the top tokens to consider."""
+
+ top_k: int = -1
+ """Integer that controls the number of top tokens to consider."""
+
+ use_beam_search: bool = False
+ """Whether to use beam search instead of sampling."""
+
+ stop: Optional[List[str]] = None
+ """List of strings that stop the generation when they are generated."""
+
+ ignore_eos: bool = False
+ """Whether to ignore the EOS token and continue generating tokens after
+ the EOS token is generated."""
+
+ max_new_tokens: int = 512
+ """Maximum number of tokens to generate per output sequence."""
+
+ logprobs: Optional[int] = None
+ """Number of log probabilities to return per output token."""
+
+ dtype: str = "auto"
+ """The data type for the model weights and activations."""
+
+ download_dir: Optional[str] = None
+ """Directory to download and load the weights. (Default to the default
+ cache dir of huggingface)"""
+
+ vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""
+
+ client: Any #: :meta private:
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ from vllm import LLM as VLLModel
+ except ImportError:
+ raise ImportError(
+ "Could not import vllm python package. "
+ "Please install it with `pip install vllm`."
+ )
+
+ values["client"] = VLLModel(
+ model=values["model"],
+ tensor_parallel_size=values["tensor_parallel_size"],
+ trust_remote_code=values["trust_remote_code"],
+ dtype=values["dtype"],
+ download_dir=values["download_dir"],
+ **values["vllm_kwargs"],
+ )
+
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling vllm."""
+ return {
+ "n": self.n,
+ "best_of": self.best_of,
+ "max_tokens": self.max_new_tokens,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "presence_penalty": self.presence_penalty,
+ "frequency_penalty": self.frequency_penalty,
+ "stop": self.stop,
+ "ignore_eos": self.ignore_eos,
+ "use_beam_search": self.use_beam_search,
+ "logprobs": self.logprobs,
+ }
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+
+ from vllm import SamplingParams
+
+ # build sampling parameters
+ params = {**self._default_params, **kwargs, "stop": stop}
+ sampling_params = SamplingParams(**params)
+ # call the model
+ outputs = self.client.generate(prompts, sampling_params)
+
+ generations = []
+ for output in outputs:
+ text = output.outputs[0].text
+ generations.append([Generation(text=text)])
+
+ return LLMResult(generations=generations)
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "vllm"
+
+
+class VLLMOpenAI(BaseOpenAI):
+ """vLLM OpenAI-compatible API client"""
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ @property
+ def _invocation_params(self) -> Dict[str, Any]:
+ """Get the parameters used to invoke the model."""
+
+ params: Dict[str, Any] = {
+ "model": self.model_name,
+ **self._default_params,
+ "logit_bias": None,
+ }
+ if not is_openai_v1():
+ params.update(
+ {
+ "api_key": self.openai_api_key,
+ "api_base": self.openai_api_base,
+ }
+ )
+
+ return params
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "vllm-openai"
diff --git a/libs/community/langchain_community/llms/volcengine_maas.py b/libs/community/langchain_community/llms/volcengine_maas.py
new file mode 100644
index 00000000000..9f32005f464
--- /dev/null
+++ b/libs/community/langchain_community/llms/volcengine_maas.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterator, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class VolcEngineMaasBase(BaseModel):
+ """Base class for VolcEngineMaas models."""
+
+ client: Any
+
+ volc_engine_maas_ak: Optional[str] = None
+ """access key for volc engine"""
+ volc_engine_maas_sk: Optional[str] = None
+ """secret key for volc engine"""
+
+ endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
+ """Endpoint of the VolcEngineMaas LLM."""
+
+ region: Optional[str] = "Region"
+ """Region of the VolcEngineMaas LLM."""
+
+ model: str = "skylark-lite-public"
+ """Model name. you could check this model details here
+ https://www.volcengine.com/docs/82379/1133187
+ and you could choose other models by change this field"""
+ model_version: Optional[str] = None
+ """Model version. Only used in moonshot large language model.
+ you could check details here https://www.volcengine.com/docs/82379/1158281"""
+
+ top_p: Optional[float] = 0.8
+ """Total probability mass of tokens to consider at each step."""
+
+ temperature: Optional[float] = 0.95
+ """A non-negative float that tunes the degree of randomness in generation."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """model special arguments, you could check detail on model page"""
+
+ streaming: bool = False
+ """Whether to stream the results."""
+
+ connect_timeout: Optional[int] = 60
+ """Timeout for connect to volc engine maas endpoint. Default is 60 seconds."""
+
+ read_timeout: Optional[int] = 60
+ """Timeout for read response from volc engine maas endpoint.
+ Default is 60 seconds."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
+ sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
+        endpoint = values["endpoint"]
+ try:
+ from volcengine.maas import MaasService
+
+ maas = MaasService(
+ endpoint,
+ values["region"],
+ connection_timeout=values["connect_timeout"],
+ socket_timeout=values["read_timeout"],
+ )
+ maas.set_ak(ak)
+ values["volc_engine_maas_ak"] = ak
+ values["volc_engine_maas_sk"] = sk
+ maas.set_sk(sk)
+ values["client"] = maas
+ except ImportError:
+ raise ImportError(
+ "volcengine package not found, please install it with "
+ "`pip install volcengine`"
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling VolcEngineMaas API."""
+ normal_params = {
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+
+class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
+ """volc engine maas hosts a plethora of models.
+ You can utilize these models through this class.
+
+ To use, you should have the ``volcengine`` python package installed.
+ and set access key and secret key by environment variable or direct pass those to
+ this class.
+ access key, secret key are required parameters which you could get help
+ https://www.volcengine.com/docs/6291/65568
+
+ In order to use them, it is necessary to install the 'volcengine' Python package.
+ The access key and secret key must be set either via environment variables or
+ passed directly to this class.
+ access key and secret key are mandatory parameters for which assistance can be
+ sought at https://www.volcengine.com/docs/6291/65568.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import VolcEngineMaasLLM
+ model = VolcEngineMaasLLM(model="skylark-lite-public",
+ volc_engine_maas_ak="your_ak",
+ volc_engine_maas_sk="your_sk")
+ """
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "volc-engine-maas-llm"
+
+ def _convert_prompt_msg_params(
+ self,
+ prompt: str,
+ **kwargs: Any,
+ ) -> dict:
+ model_req = {
+ "model": {
+ "name": self.model,
+ }
+ }
+ if self.model_version is not None:
+ model_req["model"]["version"] = self.model_version
+
+ return {
+ **model_req,
+ "messages": [{"role": "user", "content": prompt}],
+ "parameters": {**self._default_params, **kwargs},
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ response = self.client.chat(params)
+
+ return response.get("choice", {}).get("message", {}).get("content", "")
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ for res in self.client.stream_chat(params):
+ if res:
+ chunk = GenerationChunk(
+ text=res.get("choice", {}).get("message", {}).get("content", "")
+ )
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
diff --git a/libs/community/langchain_community/llms/watsonxllm.py b/libs/community/langchain_community/llms/watsonxllm.py
new file mode 100644
index 00000000000..3e927eb16b2
--- /dev/null
+++ b/libs/community/langchain_community/llms/watsonxllm.py
@@ -0,0 +1,353 @@
+import logging
+import os
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class WatsonxLLM(BaseLLM):
+ """
+ IBM watsonx.ai large language models.
+
+ To use, you should have ``ibm_watson_machine_learning`` python package installed,
+ and the environment variable ``WATSONX_APIKEY`` set with your API key, or pass
+ it as a named parameter to the constructor.
+
+
+ Example:
+ .. code-block:: python
+
+ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames
+ parameters = {
+ GenTextParamsMetaNames.DECODING_METHOD: "sample",
+ GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
+ GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
+ GenTextParamsMetaNames.TEMPERATURE: 0.5,
+ GenTextParamsMetaNames.TOP_K: 50,
+ GenTextParamsMetaNames.TOP_P: 1,
+ }
+
+ from langchain_community.llms import WatsonxLLM
+ llm = WatsonxLLM(
+ model_id="google/flan-ul2",
+ url="https://us-south.ml.cloud.ibm.com",
+ apikey="*****",
+ project_id="*****",
+ params=parameters,
+ )
+ """
+
+ model_id: str = ""
+ """Type of model to use."""
+
+ project_id: str = ""
+ """ID of the Watson Studio project."""
+
+ space_id: str = ""
+ """ID of the Watson Studio space."""
+
+ url: Optional[SecretStr] = None
+ """Url to Watson Machine Learning instance"""
+
+ apikey: Optional[SecretStr] = None
+ """Apikey to Watson Machine Learning instance"""
+
+ token: Optional[SecretStr] = None
+ """Token to Watson Machine Learning instance"""
+
+ password: Optional[SecretStr] = None
+ """Password to Watson Machine Learning instance"""
+
+ username: Optional[SecretStr] = None
+ """Username to Watson Machine Learning instance"""
+
+ instance_id: Optional[SecretStr] = None
+ """Instance_id of Watson Machine Learning instance"""
+
+ version: Optional[SecretStr] = None
+ """Version of Watson Machine Learning instance"""
+
+ params: Optional[dict] = None
+ """Model parameters to use during generate requests."""
+
+ verify: Union[str, bool] = ""
+ """User can pass as verify one of following:
+ the path to a CA_BUNDLE file
+ the path of directory with certificates of trusted CAs
+ True - default path to truststore will be taken
+ False - no verification will be made"""
+
+ streaming: bool = False
+ """ Whether to stream the results or not. """
+
+ watsonx_model: Any
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ return False
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {
+ "url": "WATSONX_URL",
+ "apikey": "WATSONX_APIKEY",
+ "token": "WATSONX_TOKEN",
+ "password": "WATSONX_PASSWORD",
+ "username": "WATSONX_USERNAME",
+ "instance_id": "WATSONX_INSTANCE_ID",
+ }
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that credentials and python package exists in environment."""
+ values["url"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "url", "WATSONX_URL")
+ )
+ if "cloud.ibm.com" in values.get("url", "").get_secret_value():
+ values["apikey"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "apikey", "WATSONX_APIKEY")
+ )
+ else:
+ if (
+ not values["token"]
+ and "WATSONX_TOKEN" not in os.environ
+ and not values["password"]
+ and "WATSONX_PASSWORD" not in os.environ
+ and not values["apikey"]
+ and "WATSONX_APIKEY" not in os.environ
+ ):
+ raise ValueError(
+ "Did not find 'token', 'password' or 'apikey',"
+ " please add an environment variable"
+ " `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
+ "which contains it,"
+ " or pass 'token', 'password' or 'apikey'"
+ " as a named parameter."
+ )
+ elif values["token"] or "WATSONX_TOKEN" in os.environ:
+ values["token"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "token", "WATSONX_TOKEN")
+ )
+ elif values["password"] or "WATSONX_PASSWORD" in os.environ:
+ values["password"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "password", "WATSONX_PASSWORD")
+ )
+ values["username"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "username", "WATSONX_USERNAME")
+ )
+ elif values["apikey"] or "WATSONX_APIKEY" in os.environ:
+ values["apikey"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "apikey", "WATSONX_APIKEY")
+ )
+ values["username"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "username", "WATSONX_USERNAME")
+ )
+ if not values["instance_id"] or "WATSONX_INSTANCE_ID" not in os.environ:
+ values["instance_id"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "instance_id", "WATSONX_INSTANCE_ID")
+ )
+
+ try:
+ from ibm_watson_machine_learning.foundation_models import Model
+
+ credentials = {
+ "url": values["url"].get_secret_value() if values["url"] else None,
+ "apikey": values["apikey"].get_secret_value()
+ if values["apikey"]
+ else None,
+ "token": values["token"].get_secret_value()
+ if values["token"]
+ else None,
+ "password": values["password"].get_secret_value()
+ if values["password"]
+ else None,
+ "username": values["username"].get_secret_value()
+ if values["username"]
+ else None,
+ "instance_id": values["instance_id"].get_secret_value()
+ if values["instance_id"]
+ else None,
+ "version": values["version"].get_secret_value()
+ if values["version"]
+ else None,
+ }
+ credentials_without_none_value = {
+ key: value for key, value in credentials.items() if value is not None
+ }
+
+ watsonx_model = Model(
+ model_id=values["model_id"],
+ credentials=credentials_without_none_value,
+ params=values["params"],
+ project_id=values["project_id"],
+ space_id=values["space_id"],
+ verify=values["verify"],
+ )
+ values["watsonx_model"] = watsonx_model
+
+ except ImportError:
+ raise ImportError(
+ "Could not import ibm_watson_machine_learning python package. "
+ "Please install it with `pip install ibm_watson_machine_learning`."
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_id": self.model_id,
+ "params": self.params,
+ "project_id": self.project_id,
+ "space_id": self.space_id,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "IBM watsonx.ai"
+
+ @staticmethod
+ def _extract_token_usage(
+ response: Optional[List[Dict[str, Any]]] = None,
+ ) -> Dict[str, Any]:
+ if response is None:
+ return {"generated_token_count": 0, "input_token_count": 0}
+
+ input_token_count = 0
+ generated_token_count = 0
+
+ def get_count_value(key: str, result: Dict[str, Any]) -> int:
+ return result.get(key, 0) or 0
+
+ for res in response:
+ results = res.get("results")
+ if results:
+ input_token_count += get_count_value("input_token_count", results[0])
+ generated_token_count += get_count_value(
+ "generated_token_count", results[0]
+ )
+
+ return {
+ "generated_token_count": generated_token_count,
+ "input_token_count": input_token_count,
+ }
+
+ def _create_llm_result(self, response: List[dict]) -> LLMResult:
+ """Create the LLMResult from the choices and prompts."""
+ generations = []
+ for res in response:
+ results = res.get("results")
+ if results:
+ finish_reason = results[0].get("stop_reason")
+ gen = Generation(
+ text=results[0].get("generated_text"),
+ generation_info={"finish_reason": finish_reason},
+ )
+ generations.append([gen])
+ final_token_usage = self._extract_token_usage(response)
+ llm_output = {"token_usage": final_token_usage, "model_id": self.model_id}
+ return LLMResult(generations=generations, llm_output=llm_output)
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the IBM watsonx.ai inference endpoint.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ run_manager: Optional callback manager.
+ Returns:
+ The string generated by the model.
+ Example:
+ .. code-block:: python
+
+ response = watsonxllm("What is a molecule")
+ """
+ result = self._generate(
+ prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+ )
+ return result.generations[0][0].text
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ stream: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Call the IBM watsonx.ai inference endpoint which then generate the response.
+ Args:
+ prompts: List of strings (prompts) to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ run_manager: Optional callback manager.
+ Returns:
+ The full LLMResult output.
+ Example:
+ .. code-block:: python
+
+ response = watsonxllm.generate(["What is a molecule"])
+ """
+ should_stream = stream if stream is not None else self.streaming
+ if should_stream:
+ if len(prompts) > 1:
+ raise ValueError(
+ f"WatsonxLLM currently only supports a single prompt, got {prompts}"
+ )
+ generation = GenerationChunk(text="")
+ stream_iter = self._stream(
+ prompts[0], stop=stop, run_manager=run_manager, **kwargs
+ )
+ for chunk in stream_iter:
+ if generation is None:
+ generation = chunk
+ else:
+ generation += chunk
+ assert generation is not None
+ return LLMResult(generations=[[generation]])
+ else:
+ response = self.watsonx_model.generate(prompt=prompts)
+ return self._create_llm_result(response)
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ """Call the IBM watsonx.ai inference endpoint which then streams the response.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ run_manager: Optional callback manager.
+ Returns:
+ The iterator which yields generation chunks.
+ Example:
+ .. code-block:: python
+
+ response = watsonxllm.stream("What is a molecule")
+ for chunk in response:
+ print(chunk, end='')
+ """
+ for chunk in self.watsonx_model.generate_text_stream(prompt=prompt):
+ if chunk:
+ yield GenerationChunk(text=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk)
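A minimal usage sketch for the streaming path above. The credentials, project ID, and prompt are placeholders; `stream()` is the public wrapper around `_stream()` provided by the base `LLM` class.

.. code-block:: python

    from langchain_community.llms import WatsonxLLM

    # Placeholder credentials and project ID.
    watsonxllm = WatsonxLLM(
        model_id="google/flan-ul2",
        url="https://us-south.ml.cloud.ibm.com",
        apikey="*****",
        project_id="*****",
    )

    # Tokens are yielded as GenerationChunk text by _stream() above.
    for chunk in watsonxllm.stream("What is a molecule"):
        print(chunk, end="")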
diff --git a/libs/community/langchain_community/llms/writer.py b/libs/community/langchain_community/llms/writer.py
new file mode 100644
index 00000000000..3b7bc6a06c4
--- /dev/null
+++ b/libs/community/langchain_community/llms/writer.py
@@ -0,0 +1,159 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class Writer(LLM):
+ """Writer large language models.
+
+ To use, you should have the environment variables ``WRITER_API_KEY`` and
+ ``WRITER_ORG_ID`` set with your API key and organization ID respectively.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import Writer
+ writer = Writer(model_id="palmyra-base")
+ """
+
+ writer_org_id: Optional[str] = None
+ """Writer organization ID."""
+
+ model_id: str = "palmyra-instruct"
+ """Model name to use."""
+
+ min_tokens: Optional[int] = None
+ """Minimum number of tokens to generate."""
+
+ max_tokens: Optional[int] = None
+ """Maximum number of tokens to generate."""
+
+ temperature: Optional[float] = None
+ """What sampling temperature to use."""
+
+ top_p: Optional[float] = None
+ """Total probability mass of tokens to consider at each step."""
+
+ stop: Optional[List[str]] = None
+ """Sequences when completion generation will stop."""
+
+ presence_penalty: Optional[float] = None
+ """Penalizes repeated tokens regardless of frequency."""
+
+ repetition_penalty: Optional[float] = None
+ """Penalizes repeated tokens according to frequency."""
+
+ best_of: Optional[int] = None
+ """Generates this many completions server-side and returns the "best"."""
+
+ logprobs: bool = False
+ """Whether to return log probabilities."""
+
+ n: Optional[int] = None
+ """How many completions to generate."""
+
+ writer_api_key: Optional[str] = None
+ """Writer API key."""
+
+ base_url: Optional[str] = None
+ """Base url to use, if None decides based on model name."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and organization id exist in environment."""
+
+ writer_api_key = get_from_dict_or_env(
+ values, "writer_api_key", "WRITER_API_KEY"
+ )
+ values["writer_api_key"] = writer_api_key
+
+ writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID")
+ values["writer_org_id"] = writer_org_id
+
+ return values
+
+ @property
+ def _default_params(self) -> Mapping[str, Any]:
+ """Get the default parameters for calling Writer API."""
+ return {
+ "minTokens": self.min_tokens,
+ "maxTokens": self.max_tokens,
+ "temperature": self.temperature,
+ "topP": self.top_p,
+ "stop": self.stop,
+ "presencePenalty": self.presence_penalty,
+ "repetitionPenalty": self.repetition_penalty,
+ "bestOf": self.best_of,
+ "logprobs": self.logprobs,
+ "n": self.n,
+ }
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"model_id": self.model_id, "writer_org_id": self.writer_org_id},
+ **self._default_params,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "writer"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to Writer's completions endpoint.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = writer("Tell me a joke.")
+ """
+ if self.base_url is not None:
+ base_url = self.base_url
+ else:
+ base_url = (
+ "https://enterprise-api.writer.com/llm"
+ f"/organization/{self.writer_org_id}"
+ f"/model/{self.model_id}/completions"
+ )
+ params = {**self._default_params, **kwargs}
+ response = requests.post(
+ url=base_url,
+ headers={
+ "Authorization": f"{self.writer_api_key}",
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ },
+ json={"prompt": prompt, **params},
+ )
+ text = response.text
+ if stop is not None:
+ # I believe this is required since the stop tokens
+ # are not enforced by the model parameters
+ text = enforce_stop_tokens(text, stop)
+ return text
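A hedged usage sketch for the `Writer` wrapper, assuming the API key and organization ID are supplied via the environment variables read by `validate_environment` (all values are placeholders).

.. code-block:: python

    import os

    from langchain_community.llms import Writer

    # Placeholder values; alternatively pass writer_api_key / writer_org_id directly.
    os.environ["WRITER_API_KEY"] = "..."
    os.environ["WRITER_ORG_ID"] = "..."

    writer = Writer(model_id="palmyra-instruct", max_tokens=100)
    print(writer("Tell me a joke."))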
diff --git a/libs/community/langchain_community/llms/xinference.py b/libs/community/langchain_community/llms/xinference.py
new file mode 100644
index 00000000000..0c44daa3880
--- /dev/null
+++ b/libs/community/langchain_community/llms/xinference.py
@@ -0,0 +1,206 @@
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+
+if TYPE_CHECKING:
+ from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
+ from xinference.model.llm.core import LlamaCppGenerateConfig
+
+
+class Xinference(LLM):
+ """Wrapper for accessing Xinference's large-scale model inference service.
+ To use, you should have the xinference library installed:
+
+ .. code-block:: bash
+
+ pip install "xinference[all]"
+
+ Check out: https://github.com/xorbitsai/inference
+ To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers.
+
+ Example:
+ To start a local instance of Xinference, run
+
+ .. code-block:: bash
+
+ $ xinference
+
+ You can also deploy Xinference in a distributed cluster. Here are the steps:
+
+ Starting the supervisor:
+
+ .. code-block:: bash
+
+ $ xinference-supervisor
+
+ Starting the worker:
+
+ .. code-block:: bash
+
+ $ xinference-worker
+
+ Then, launch a model using the command line interface (CLI).
+
+ Example:
+
+ .. code-block:: bash
+
+ $ xinference launch -n orca -s 3 -q q4_0
+
+ It will return a model UID. Then, you can use Xinference with LangChain.
+
+ Example:
+
+ .. code-block:: python
+
+ from langchain_community.llms import Xinference
+
+ llm = Xinference(
+ server_url="http://0.0.0.0:9997",
+ model_uid = {model_uid} # replace model_uid with the model UID returned from launching the model
+ )
+
+ llm(
+ prompt="Q: where can we visit in the capital of France? A:",
+ generate_config={"max_tokens": 1024, "stream": True},
+ )
+
+ To view all the supported builtin models, run:
+
+ .. code-block:: bash
+
+ $ xinference list --all
+
+ """ # noqa: E501
+
+ client: Any
+ server_url: Optional[str]
+ """URL of the xinference server"""
+ model_uid: Optional[str]
+ """UID of the launched model"""
+ model_kwargs: Dict[str, Any]
+ """Keyword arguments to be passed to xinference.LLM"""
+
+ def __init__(
+ self,
+ server_url: Optional[str] = None,
+ model_uid: Optional[str] = None,
+ **model_kwargs: Any,
+ ):
+ try:
+ from xinference.client import RESTfulClient
+ except ImportError as e:
+ raise ImportError(
+ "Could not import RESTfulClient from xinference. Please install it"
+ " with `pip install xinference`."
+ ) from e
+
+ model_kwargs = model_kwargs or {}
+
+ super().__init__(
+ **{
+ "server_url": server_url,
+ "model_uid": model_uid,
+ "model_kwargs": model_kwargs,
+ }
+ )
+
+ if self.server_url is None:
+ raise ValueError("Please provide server URL")
+
+ if self.model_uid is None:
+ raise ValueError("Please provide the model UID")
+
+ self.client = RESTfulClient(server_url)
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "xinference"
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ **{"server_url": self.server_url},
+ **{"model_uid": self.model_uid},
+ **{"model_kwargs": self.model_kwargs},
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the xinference model and return the output.
+
+ Args:
+ prompt: The prompt to use for generation.
+ stop: Optional list of stop words to use when generating.
+ generate_config: Optional dictionary for the configuration used for
+ generation.
+
+ Returns:
+ The string generated by the model.
+ """
+ model = self.client.get_model(self.model_uid)
+
+ generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+ generate_config = {**self.model_kwargs, **generate_config}
+
+ if stop:
+ generate_config["stop"] = stop
+
+ if generate_config and generate_config.get("stream"):
+ combined_text_output = ""
+ for token in self._stream_generate(
+ model=model,
+ prompt=prompt,
+ run_manager=run_manager,
+ generate_config=generate_config,
+ ):
+ combined_text_output += token
+ return combined_text_output
+
+ else:
+ completion = model.generate(prompt=prompt, generate_config=generate_config)
+ return completion["choices"][0]["text"]
+
+ def _stream_generate(
+ self,
+ model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
+ prompt: str,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ generate_config: Optional["LlamaCppGenerateConfig"] = None,
+ ) -> Generator[str, None, None]:
+ """
+ Args:
+ prompt: The prompt to use for generation.
+ model: The model used for generation.
+ stop: Optional list of stop words to use when generating.
+ generate_config: Optional dictionary for the configuration used for
+ generation.
+
+ Yields:
+ A string token.
+ """
+ streaming_response = model.generate(
+ prompt=prompt, generate_config=generate_config
+ )
+ for chunk in streaming_response:
+ if isinstance(chunk, dict):
+ choices = chunk.get("choices", [])
+ if choices:
+ choice = choices[0]
+ if isinstance(choice, dict):
+ token = choice.get("text", "")
+ log_probs = choice.get("logprobs")
+ if run_manager:
+ run_manager.on_llm_new_token(
+ token=token, verbose=self.verbose, log_probs=log_probs
+ )
+ yield token
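A minimal sketch of the streaming branch in `_call` above, assuming a locally running Xinference endpoint and a launched model UID (both placeholders).

.. code-block:: python

    from langchain_community.llms import Xinference

    llm = Xinference(
        server_url="http://0.0.0.0:9997",
        model_uid="<model-uid-from-launch>",  # placeholder
    )

    # With "stream": True the tokens are accumulated via _stream_generate().
    answer = llm(
        prompt="Q: where can we visit in the capital of France? A:",
        generate_config={"max_tokens": 512, "stream": True},
    )
    print(answer)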
diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py
new file mode 100644
index 00000000000..d82daeba55c
--- /dev/null
+++ b/libs/community/langchain_community/llms/yandex.py
@@ -0,0 +1,206 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.load.serializable import Serializable
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+
+class _BaseYandexGPT(Serializable):
+ iam_token: str = ""
+ """Yandex Cloud IAM token for service account
+ with the `ai.languageModels.user` role"""
+ api_key: str = ""
+ """Yandex Cloud Api Key for service account
+ with the `ai.languageModels.user` role"""
+ model_name: str = "general"
+ """Model name to use."""
+ temperature: float = 0.6
+ """What sampling temperature to use.
+ Should be a double number between 0 (inclusive) and 1 (inclusive)."""
+ max_tokens: int = 7400
+ """Sets the maximum limit on the total number of tokens
+ used for both the input prompt and the generated response.
+ Must be greater than zero and not exceed 7400 tokens."""
+ stop: Optional[List[str]] = None
+ """Sequences when completion generation will stop."""
+ url: str = "llm.api.cloud.yandex.net:443"
+ """The url of the API."""
+
+ @property
+ def _llm_type(self) -> str:
+ return "yandex_gpt"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that iam token exists in environment."""
+
+ iam_token = get_from_dict_or_env(values, "iam_token", "YC_IAM_TOKEN", "")
+ values["iam_token"] = iam_token
+ api_key = get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
+ values["api_key"] = api_key
+ if api_key == "" and iam_token == "":
+ raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
+ return values
+
+
+class YandexGPT(_BaseYandexGPT, LLM):
+ """Yandex large language models.
+
+ To use, you should have the ``yandexcloud`` python package installed.
+
+ There are two authentication options for the service account
+ with the ``ai.languageModels.user`` role:
+ - You can specify the token in a constructor parameter `iam_token`
+ or in an environment variable `YC_IAM_TOKEN`.
+ - You can specify the key in a constructor parameter `api_key`
+ or in an environment variable `YC_API_KEY`.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import YandexGPT
+ yandex_gpt = YandexGPT(iam_token="t1.9eu...")
+ """
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "stop": self.stop,
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call the Yandex GPT model and return the output.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+
+ response = yandex_gpt("Tell me a joke.")
+ """
+ try:
+ import grpc
+ from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+ from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import InstructRequest
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
+ TextGenerationServiceStub,
+ )
+ except ImportError as e:
+ raise ImportError(
+ "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+ ) from e
+ channel_credentials = grpc.ssl_channel_credentials()
+ channel = grpc.secure_channel(self.url, channel_credentials)
+ request = InstructRequest(
+ model=self.model_name,
+ request_text=prompt,
+ generation_options=GenerationOptions(
+ temperature=DoubleValue(value=self.temperature),
+ max_tokens=Int64Value(value=self.max_tokens),
+ ),
+ )
+ stub = TextGenerationServiceStub(channel)
+ if self.iam_token:
+ metadata = (("authorization", f"Bearer {self.iam_token}"),)
+ else:
+ metadata = (("authorization", f"Api-Key {self.api_key}"),)
+ res = stub.Instruct(request, metadata=metadata)
+ text = list(res)[0].alternatives[0].text
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Async call the Yandex GPT model and return the output.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+
+ Returns:
+ The string generated by the model.
+ """
+ try:
+ import asyncio
+
+ import grpc
+ from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+ from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
+ InstructRequest,
+ InstructResponse,
+ )
+ from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
+ TextGenerationAsyncServiceStub,
+ )
+ from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
+ from yandex.cloud.operation.operation_service_pb2_grpc import (
+ OperationServiceStub,
+ )
+ except ImportError as e:
+ raise ImportError(
+ "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+ ) from e
+ operation_api_url = "operation.api.cloud.yandex.net:443"
+ channel_credentials = grpc.ssl_channel_credentials()
+ async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
+ request = InstructRequest(
+ model=self.model_name,
+ request_text=prompt,
+ generation_options=GenerationOptions(
+ temperature=DoubleValue(value=self.temperature),
+ max_tokens=Int64Value(value=self.max_tokens),
+ ),
+ )
+ stub = TextGenerationAsyncServiceStub(channel)
+ if self.iam_token:
+ metadata = (("authorization", f"Bearer {self.iam_token}"),)
+ else:
+ metadata = (("authorization", f"Api-Key {self.api_key}"),)
+ operation = await stub.Instruct(request, metadata=metadata)
+ async with grpc.aio.secure_channel(
+ operation_api_url, channel_credentials
+ ) as operation_channel:
+ operation_stub = OperationServiceStub(operation_channel)
+ while not operation.done:
+ await asyncio.sleep(1)
+ operation_request = GetOperationRequest(operation_id=operation.id)
+ operation = await operation_stub.Get(
+ operation_request, metadata=metadata
+ )
+
+ instruct_response = InstructResponse()
+ operation.response.Unpack(instruct_response)
+ text = instruct_response.alternatives[0].text
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+ return text
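A short usage sketch, assuming authentication through the `YC_API_KEY` environment variable checked by the validator (the IAM-token route works the same way; the key value is a placeholder).

.. code-block:: python

    import os

    from langchain_community.llms import YandexGPT

    os.environ["YC_API_KEY"] = "..."  # placeholder

    yandex_gpt = YandexGPT(model_name="general", temperature=0.3)
    print(yandex_gpt("Tell me a joke."))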
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_language.py b/libs/community/langchain_community/py.typed
similarity index 100%
rename from libs/langchain/tests/integration_tests/document_loaders/test_language.py
rename to libs/community/langchain_community/py.typed
diff --git a/libs/community/langchain_community/retrievers/__init__.py b/libs/community/langchain_community/retrievers/__init__.py
new file mode 100644
index 00000000000..3eaed0a31ec
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/__init__.py
@@ -0,0 +1,107 @@
+"""**Retriever** class returns Documents given a text **query**.
+
+It is more general than a vector store. A retriever does not need to be able to
+store documents, only to return (or retrieve) them. Vector stores can be used as
+the backbone of a retriever, but there are other types of retrievers as well.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseRetriever --> Retriever # Examples: ArxivRetriever, MergerRetriever
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, Serializable, Callbacks,
+ CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
+"""
+
+from langchain_community.retrievers.arcee import ArceeRetriever
+from langchain_community.retrievers.arxiv import ArxivRetriever
+from langchain_community.retrievers.azure_cognitive_search import (
+ AzureCognitiveSearchRetriever,
+)
+from langchain_community.retrievers.bedrock import AmazonKnowledgeBasesRetriever
+from langchain_community.retrievers.bm25 import BM25Retriever
+from langchain_community.retrievers.chaindesk import ChaindeskRetriever
+from langchain_community.retrievers.chatgpt_plugin_retriever import (
+ ChatGPTPluginRetriever,
+)
+from langchain_community.retrievers.cohere_rag_retriever import CohereRagRetriever
+from langchain_community.retrievers.docarray import DocArrayRetriever
+from langchain_community.retrievers.elastic_search_bm25 import (
+ ElasticSearchBM25Retriever,
+)
+from langchain_community.retrievers.embedchain import EmbedchainRetriever
+from langchain_community.retrievers.google_cloud_documentai_warehouse import (
+ GoogleDocumentAIWarehouseRetriever,
+)
+from langchain_community.retrievers.google_vertex_ai_search import (
+ GoogleCloudEnterpriseSearchRetriever,
+ GoogleVertexAIMultiTurnSearchRetriever,
+ GoogleVertexAISearchRetriever,
+)
+from langchain_community.retrievers.kay import KayAiRetriever
+from langchain_community.retrievers.kendra import AmazonKendraRetriever
+from langchain_community.retrievers.knn import KNNRetriever
+from langchain_community.retrievers.llama_index import (
+ LlamaIndexGraphRetriever,
+ LlamaIndexRetriever,
+)
+from langchain_community.retrievers.metal import MetalRetriever
+from langchain_community.retrievers.milvus import MilvusRetriever
+from langchain_community.retrievers.outline import OutlineRetriever
+from langchain_community.retrievers.pinecone_hybrid_search import (
+ PineconeHybridSearchRetriever,
+)
+from langchain_community.retrievers.pubmed import PubMedRetriever
+from langchain_community.retrievers.remote_retriever import RemoteLangChainRetriever
+from langchain_community.retrievers.svm import SVMRetriever
+from langchain_community.retrievers.tavily_search_api import TavilySearchAPIRetriever
+from langchain_community.retrievers.tfidf import TFIDFRetriever
+from langchain_community.retrievers.vespa_retriever import VespaRetriever
+from langchain_community.retrievers.weaviate_hybrid_search import (
+ WeaviateHybridSearchRetriever,
+)
+from langchain_community.retrievers.wikipedia import WikipediaRetriever
+from langchain_community.retrievers.zep import ZepRetriever
+from langchain_community.retrievers.zilliz import ZillizRetriever
+
+__all__ = [
+ "AmazonKendraRetriever",
+ "AmazonKnowledgeBasesRetriever",
+ "ArceeRetriever",
+ "ArxivRetriever",
+ "AzureCognitiveSearchRetriever",
+ "ChatGPTPluginRetriever",
+ "ChaindeskRetriever",
+ "CohereRagRetriever",
+ "ElasticSearchBM25Retriever",
+ "EmbedchainRetriever",
+ "GoogleDocumentAIWarehouseRetriever",
+ "GoogleCloudEnterpriseSearchRetriever",
+ "GoogleVertexAIMultiTurnSearchRetriever",
+ "GoogleVertexAISearchRetriever",
+ "KayAiRetriever",
+ "KNNRetriever",
+ "LlamaIndexGraphRetriever",
+ "LlamaIndexRetriever",
+ "MetalRetriever",
+ "MilvusRetriever",
+ "OutlineRetriever",
+ "PineconeHybridSearchRetriever",
+ "PubMedRetriever",
+ "RemoteLangChainRetriever",
+ "SVMRetriever",
+ "TavilySearchAPIRetriever",
+ "TFIDFRetriever",
+ "BM25Retriever",
+ "VespaRetriever",
+ "WeaviateHybridSearchRetriever",
+ "WikipediaRetriever",
+ "ZepRetriever",
+ "ZillizRetriever",
+ "DocArrayRetriever",
+]
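With this module in place, retrievers resolve from the community namespace. A minimal sketch using two of the exported names (the texts and query are illustrative; BM25 requires the optional `rank_bm25` package).

.. code-block:: python

    # All retrievers listed in __all__ now import from the community namespace.
    from langchain_community.retrievers import BM25Retriever, WikipediaRetriever

    retriever = BM25Retriever.from_texts(["hello world", "foo bar"])  # needs rank_bm25
    docs = retriever.get_relevant_documents("hello")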
diff --git a/libs/community/langchain_community/retrievers/arcee.py b/libs/community/langchain_community/retrievers/arcee.py
new file mode 100644
index 00000000000..4e2116e58bc
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/arcee.py
@@ -0,0 +1,139 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
+
+
+class ArceeRetriever(BaseRetriever):
+ """Document retriever for Arcee's Domain Adapted Language Models (DALMs).
+
+ To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
+ or pass ``arcee_api_key`` as a named parameter.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.retrievers import ArceeRetriever
+
+ retriever = ArceeRetriever(
+ model="DALM-PubMed",
+ arcee_api_key="ARCEE-API-KEY"
+ )
+
+ documents = retriever.get_relevant_documents("AI-driven music therapy")
+ """
+
+ _client: Optional[ArceeWrapper] = None #: :meta private:
+ """Arcee client."""
+
+ arcee_api_key: SecretStr
+ """Arcee API Key"""
+
+ model: str
+ """Arcee DALM name"""
+
+ arcee_api_url: str = "https://api.arcee.ai"
+ """Arcee API URL"""
+
+ arcee_api_version: str = "v2"
+ """Arcee API Version"""
+
+ arcee_app_url: str = "https://app.arcee.ai"
+ """Arcee App URL"""
+
+ model_kwargs: Optional[Dict[str, Any]] = None
+ """Keyword arguments to pass to the model."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ underscore_attrs_are_private = True
+
+ def __init__(self, **data: Any) -> None:
+ """Initializes private fields."""
+
+ super().__init__(**data)
+
+ self._client = ArceeWrapper(
+ arcee_api_key=self.arcee_api_key.get_secret_value(),
+ arcee_api_url=self.arcee_api_url,
+ arcee_api_version=self.arcee_api_version,
+ model_kwargs=self.model_kwargs,
+ model_name=self.model,
+ )
+
+ self._client.validate_model_training_status()
+
+ @root_validator()
+ def validate_environments(cls, values: Dict) -> Dict:
+ """Validate Arcee environment variables."""
+
+ # validate env vars
+ values["arcee_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "arcee_api_key",
+ "ARCEE_API_KEY",
+ )
+ )
+
+ values["arcee_api_url"] = get_from_dict_or_env(
+ values,
+ "arcee_api_url",
+ "ARCEE_API_URL",
+ )
+
+ values["arcee_app_url"] = get_from_dict_or_env(
+ values,
+ "arcee_app_url",
+ "ARCEE_APP_URL",
+ )
+
+ values["arcee_api_version"] = get_from_dict_or_env(
+ values,
+ "arcee_api_version",
+ "ARCEE_API_VERSION",
+ )
+
+ # validate model kwargs
+ if values["model_kwargs"]:
+ kw = values["model_kwargs"]
+
+ # validate size
+ if kw.get("size") is not None:
+ if not kw.get("size") >= 0:
+ raise ValueError("`size` must not be negative.")
+
+ # validate filters
+ if kw.get("filters") is not None:
+ if not isinstance(kw.get("filters"), List):
+ raise ValueError("`filters` must be a list.")
+ for f in kw.get("filters"):
+ DALMFilter(**f)
+
+ return values
+
+ def _get_relevant_documents(
+ self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
+ ) -> List[Document]:
+ """Retrieve {size} contexts with your retriever for a given query
+
+ Args:
+ query: Query to submit to the model
+ size: The max number of context results to retrieve.
+ Defaults to 3. (Can be less if filters are provided).
+ filters: Filters to apply to the context dataset.
+ """
+
+ try:
+ if not self._client:
+ raise ValueError("Client is not initialized.")
+ return self._client.retrieve(query=query, **kwargs)
+ except Exception as e:
+ raise ValueError(f"Error while retrieving documents: {e}") from e
diff --git a/libs/community/langchain_community/retrievers/arxiv.py b/libs/community/langchain_community/retrievers/arxiv.py
new file mode 100644
index 00000000000..633d22b1a24
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/arxiv.py
@@ -0,0 +1,25 @@
+from typing import List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.utilities.arxiv import ArxivAPIWrapper
+
+
+class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
+ """`Arxiv` retriever.
+
+ It wraps load() in get_relevant_documents().
+ It uses all ArxivAPIWrapper arguments without any change.
+ """
+
+ get_full_documents: bool = False
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ if self.get_full_documents:
+ return self.load(query=query)
+ else:
+ return self.get_summaries_as_docs(query)
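A short sketch showing the two retrieval modes controlled by `get_full_documents` (the query string is illustrative; the underlying `arxiv` package must be installed).

.. code-block:: python

    from langchain_community.retrievers import ArxivRetriever

    # Default: summaries only, via get_summaries_as_docs().
    retriever = ArxivRetriever()
    summaries = retriever.get_relevant_documents("attention is all you need")

    # Full documents, via load().
    full_retriever = ArxivRetriever(get_full_documents=True)
    papers = full_retriever.get_relevant_documents("attention is all you need")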
diff --git a/libs/community/langchain_community/retrievers/azure_cognitive_search.py b/libs/community/langchain_community/retrievers/azure_cognitive_search.py
new file mode 100644
index 00000000000..fb91dd188f6
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/azure_cognitive_search.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import json
+from typing import Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import get_from_dict_or_env, get_from_env
+
+DEFAULT_URL_SUFFIX = "search.windows.net"
+"""Default URL Suffix for endpoint connection - commercial cloud"""
+
+
+class AzureCognitiveSearchRetriever(BaseRetriever):
+ """`Azure Cognitive Search` service retriever."""
+
+ service_name: str = ""
+ """Name of Azure Cognitive Search service"""
+ index_name: str = ""
+ """Name of Index inside Azure Cognitive Search service"""
+ api_key: str = ""
+ """API Key. Both Admin and Query keys work, but for reading data it's
+ recommended to use a Query key."""
+ api_version: str = "2020-06-30"
+ """API version"""
+ aiosession: Optional[aiohttp.ClientSession] = None
+ """ClientSession, in case we want to reuse connection for better performance."""
+ content_key: str = "content"
+ """Key in a retrieved result to set as the Document page_content."""
+ top_k: Optional[int] = None
+ """Number of results to retrieve. Set to None to retrieve all results."""
+
+ class Config:
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that service name, index name and api key exists in environment."""
+ values["service_name"] = get_from_dict_or_env(
+ values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
+ )
+ values["index_name"] = get_from_dict_or_env(
+ values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
+ )
+ values["api_key"] = get_from_dict_or_env(
+ values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
+ )
+ return values
+
+ def _build_search_url(self, query: str) -> str:
+ url_suffix = get_from_env(
+ "", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX
+ )
+ base_url = f"https://{self.service_name}.{url_suffix}/"
+ endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
+ top_param = f"&$top={self.top_k}" if self.top_k else ""
+ return base_url + endpoint_path + f"&search={query}" + top_param
+
+ @property
+ def _headers(self) -> Dict[str, str]:
+ return {
+ "Content-Type": "application/json",
+ "api-key": self.api_key,
+ }
+
+ def _search(self, query: str) -> List[dict]:
+ search_url = self._build_search_url(query)
+ response = requests.get(search_url, headers=self._headers)
+ if response.status_code != 200:
+ raise Exception(f"Error in search request: {response}")
+
+ return json.loads(response.text)["value"]
+
+ async def _asearch(self, query: str) -> List[dict]:
+ search_url = self._build_search_url(query)
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(search_url, headers=self._headers) as response:
+ response_json = await response.json()
+ else:
+ async with self.aiosession.get(
+ search_url, headers=self._headers
+ ) as response:
+ response_json = await response.json()
+
+ return response_json["value"]
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ search_results = self._search(query)
+
+ return [
+ Document(page_content=result.pop(self.content_key), metadata=result)
+ for result in search_results
+ ]
+
+ async def _aget_relevant_documents(
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ search_results = await self._asearch(query)
+
+ return [
+ Document(page_content=result.pop(self.content_key), metadata=result)
+ for result in search_results
+ ]
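A hedged sketch; the service name, index name, and key are placeholders and could equally be supplied through the `AZURE_COGNITIVE_SEARCH_*` environment variables read by the validator.

.. code-block:: python

    from langchain_community.retrievers import AzureCognitiveSearchRetriever

    retriever = AzureCognitiveSearchRetriever(
        service_name="my-search-service",  # placeholder
        index_name="my-index",             # placeholder
        api_key="...",                     # placeholder
        content_key="content",
        top_k=5,
    )
    docs = retriever.get_relevant_documents("langchain")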
diff --git a/libs/community/langchain_community/retrievers/bedrock.py b/libs/community/langchain_community/retrievers/bedrock.py
new file mode 100644
index 00000000000..507fd5b88f1
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/bedrock.py
@@ -0,0 +1,124 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.retrievers import BaseRetriever
+
+
+class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
+ numberOfResults: int = 4
+
+
+class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
+ vectorSearchConfiguration: VectorSearchConfig
+
+
+class AmazonKnowledgeBasesRetriever(BaseRetriever):
+ """A retriever class for `Amazon Bedrock Knowledge Bases`.
+
+ See https://aws.amazon.com/bedrock/knowledge-bases for more info.
+
+ Args:
+ knowledge_base_id: Knowledge Base ID.
+ region_name: The AWS region, e.g. `us-west-2`.
+ Falls back to the AWS_DEFAULT_REGION env variable or the region specified
+ in ~/.aws/config.
+ credentials_profile_name: The name of the profile in the ~/.aws/credentials
+ or ~/.aws/config files, which has either access keys or role information
+ specified. If not specified, the default credential profile or, if on an
+ EC2 instance, credentials from IMDS will be used.
+ client: boto3 client for bedrock agent runtime.
+ retrieval_config: Configuration for retrieval.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
+
+ retriever = AmazonKnowledgeBasesRetriever(
+ knowledge_base_id="",
+ retrieval_config={
+ "vectorSearchConfiguration": {
+ "numberOfResults": 4
+ }
+ },
+ )
+ """
+
+ knowledge_base_id: str
+ region_name: Optional[str] = None
+ credentials_profile_name: Optional[str] = None
+ endpoint_url: Optional[str] = None
+ client: Any
+ retrieval_config: RetrievalConfig
+
+ @root_validator(pre=True)
+ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if values.get("client") is not None:
+ return values
+
+ try:
+ import boto3
+ from botocore.client import Config
+ from botocore.exceptions import UnknownServiceError
+
+ if values.get("credentials_profile_name"):
+ session = boto3.Session(profile_name=values["credentials_profile_name"])
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {
+ "config": Config(
+ connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
+ )
+ }
+ if values.get("region_name"):
+ client_params["region_name"] = values["region_name"]
+
+ if values.get("endpoint_url"):
+ client_params["endpoint_url"] = values["endpoint_url"]
+
+ values["client"] = session.client("bedrock-agent-runtime", **client_params)
+
+ return values
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except UnknownServiceError as e:
+ raise ModuleNotFoundError(
+ "Ensure that you have installed the latest boto3 package "
+ "that contains the API for `bedrock-runtime-agent`."
+ ) from e
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ response = self.client.retrieve(
+ retrievalQuery={"text": query.strip()},
+ knowledgeBaseId=self.knowledge_base_id,
+ retrievalConfiguration=self.retrieval_config.dict(),
+ )
+ results = response["retrievalResults"]
+ documents = []
+ for result in results:
+ documents.append(
+ Document(
+ page_content=result["content"]["text"],
+ metadata={
+ "location": result["location"],
+ "score": result["score"] if "score" in result else 0,
+ },
+ )
+ )
+
+ return documents
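Complementing the docstring example, a sketch that also sets the optional region and credentials profile consumed by `create_client` (all identifiers are placeholders).

.. code-block:: python

    from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

    retriever = AmazonKnowledgeBasesRetriever(
        knowledge_base_id="ABCDEFGHIJ",        # placeholder
        region_name="us-west-2",
        credentials_profile_name="default",
        retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},
    )
    docs = retriever.get_relevant_documents("What is Amazon Bedrock?")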
diff --git a/libs/community/langchain_community/retrievers/bm25.py b/libs/community/langchain_community/retrievers/bm25.py
new file mode 100644
index 00000000000..c0e0b248313
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/bm25.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Dict, Iterable, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+def default_preprocessing_func(text: str) -> List[str]:
+ return text.split()
+
+
+class BM25Retriever(BaseRetriever):
+ """`BM25` retriever without Elasticsearch."""
+
+ vectorizer: Any
+ """ BM25 vectorizer."""
+ docs: List[Document]
+ """ List of documents."""
+ k: int = 4
+ """ Number of documents to return."""
+ preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
+ """ Preprocessing function to use on the text before BM25 vectorization."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: Iterable[str],
+ metadatas: Optional[Iterable[dict]] = None,
+ bm25_params: Optional[Dict[str, Any]] = None,
+ preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
+ **kwargs: Any,
+ ) -> BM25Retriever:
+ """
+ Create a BM25Retriever from a list of texts.
+ Args:
+ texts: A list of texts to vectorize.
+ metadatas: A list of metadata dicts to associate with each text.
+ bm25_params: Parameters to pass to the BM25 vectorizer.
+ preprocess_func: A function to preprocess each text before vectorization.
+ **kwargs: Any other arguments to pass to the retriever.
+
+ Returns:
+ A BM25Retriever instance.
+ """
+ try:
+ from rank_bm25 import BM25Okapi
+ except ImportError:
+ raise ImportError(
+ "Could not import rank_bm25, please install with `pip install "
+ "rank_bm25`."
+ )
+
+ texts_processed = [preprocess_func(t) for t in texts]
+ bm25_params = bm25_params or {}
+ vectorizer = BM25Okapi(texts_processed, **bm25_params)
+ metadatas = metadatas or ({} for _ in texts)
+ docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
+ return cls(
+ vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
+ )
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: Iterable[Document],
+ *,
+ bm25_params: Optional[Dict[str, Any]] = None,
+ preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
+ **kwargs: Any,
+ ) -> BM25Retriever:
+ """
+ Create a BM25Retriever from a list of Documents.
+ Args:
+ documents: A list of Documents to vectorize.
+ bm25_params: Parameters to pass to the BM25 vectorizer.
+ preprocess_func: A function to preprocess each text before vectorization.
+ **kwargs: Any other arguments to pass to the retriever.
+
+ Returns:
+ A BM25Retriever instance.
+ """
+ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
+ return cls.from_texts(
+ texts=texts,
+ bm25_params=bm25_params,
+ metadatas=metadatas,
+ preprocess_func=preprocess_func,
+ **kwargs,
+ )
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ processed_query = self.preprocess_func(query)
+ return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
+ return return_docs
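A self-contained sketch of `from_texts` (requires the optional `rank_bm25` dependency; the texts and metadata are illustrative).

.. code-block:: python

    from langchain_community.retrievers import BM25Retriever

    retriever = BM25Retriever.from_texts(
        texts=["foo", "bar", "world hello", "foo bar"],
        metadatas=[{"source": i} for i in range(4)],
        k=2,
    )
    docs = retriever.get_relevant_documents("foo")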
diff --git a/libs/community/langchain_community/retrievers/chaindesk.py b/libs/community/langchain_community/retrievers/chaindesk.py
new file mode 100644
index 00000000000..4c8aa2c582b
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/chaindesk.py
@@ -0,0 +1,92 @@
+from typing import Any, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class ChaindeskRetriever(BaseRetriever):
+ """`Chaindesk API` retriever."""
+
+ datastore_url: str
+ top_k: Optional[int]
+ api_key: Optional[str]
+
+ def __init__(
+ self,
+ datastore_url: str,
+ top_k: Optional[int] = None,
+ api_key: Optional[str] = None,
+ ):
+ # Initialize the pydantic model before assigning attributes directly.
+ super().__init__(datastore_url=datastore_url, api_key=api_key, top_k=top_k)
+ self.datastore_url = datastore_url
+ self.api_key = api_key
+ self.top_k = top_k
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ response = requests.post(
+ self.datastore_url,
+ json={
+ "query": query,
+ **({"topK": self.top_k} if self.top_k is not None else {}),
+ },
+ headers={
+ "Content-Type": "application/json",
+ **(
+ {"Authorization": f"Bearer {self.api_key}"}
+ if self.api_key is not None
+ else {}
+ ),
+ },
+ )
+ data = response.json()
+ return [
+ Document(
+ page_content=r["text"],
+ metadata={"source": r["source"], "score": r["score"]},
+ )
+ for r in data["results"]
+ ]
+
+ async def _aget_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: AsyncCallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ async with aiohttp.ClientSession() as session:
+ async with session.request(
+ "POST",
+ self.datastore_url,
+ json={
+ "query": query,
+ **({"topK": self.top_k} if self.top_k is not None else {}),
+ },
+ headers={
+ "Content-Type": "application/json",
+ **(
+ {"Authorization": f"Bearer {self.api_key}"}
+ if self.api_key is not None
+ else {}
+ ),
+ },
+ ) as response:
+ data = await response.json()
+ return [
+ Document(
+ page_content=r["text"],
+ metadata={"source": r["source"], "score": r["score"]},
+ )
+ for r in data["results"]
+ ]
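A minimal sketch; the datastore query URL and API key are placeholders for values issued by Chaindesk.

.. code-block:: python

    from langchain_community.retrievers import ChaindeskRetriever

    retriever = ChaindeskRetriever(
        datastore_url="https://app.chaindesk.ai/query/<datastore-id>",  # placeholder
        api_key="...",  # placeholder
        top_k=5,
    )
    docs = retriever.get_relevant_documents("What is LangChain?")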
diff --git a/libs/community/langchain_community/retrievers/chatgpt_plugin_retriever.py b/libs/community/langchain_community/retrievers/chatgpt_plugin_retriever.py
new file mode 100644
index 00000000000..13591336fc4
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/chatgpt_plugin_retriever.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from typing import List, Optional
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class ChatGPTPluginRetriever(BaseRetriever):
+ """`ChatGPT plugin` retriever."""
+
+ url: str
+ """URL of the ChatGPT plugin."""
+ bearer_token: str
+ """Bearer token for the ChatGPT plugin."""
+ top_k: int = 3
+ """Number of documents to return."""
+ filter: Optional[dict] = None
+ """Filter to apply to the results."""
+ aiosession: Optional[aiohttp.ClientSession] = None
+ """Aiohttp session to use for requests."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+ """Allow arbitrary types."""
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ url, json, headers = self._create_request(query)
+ response = requests.post(url, json=json, headers=headers)
+ results = response.json()["results"][0]["results"]
+ docs = []
+ for d in results:
+ content = d.pop("text")
+ metadata = d.pop("metadata", d)
+ if metadata.get("source_id"):
+ metadata["source"] = metadata.pop("source_id")
+ docs.append(Document(page_content=content, metadata=metadata))
+ return docs
+
+ async def _aget_relevant_documents(
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ url, json, headers = self._create_request(query)
+
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(url, headers=headers, json=json) as response:
+ res = await response.json()
+ else:
+ async with self.aiosession.post(
+ url, headers=headers, json=json
+ ) as response:
+ res = await response.json()
+
+ results = res["results"][0]["results"]
+ docs = []
+ for d in results:
+ content = d.pop("text")
+ metadata = d.pop("metadata", d)
+ if metadata.get("source_id"):
+ metadata["source"] = metadata.pop("source_id")
+ docs.append(Document(page_content=content, metadata=metadata))
+ return docs
+
+ def _create_request(self, query: str) -> tuple[str, dict, dict]:
+ url = f"{self.url}/query"
+ json = {
+ "queries": [
+ {
+ "query": query,
+ "filter": self.filter,
+ "top_k": self.top_k,
+ }
+ ]
+ }
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.bearer_token}",
+ }
+ return url, json, headers
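A sketch assuming a chatgpt-retrieval-plugin server is reachable at a placeholder URL with a placeholder bearer token.

.. code-block:: python

    from langchain_community.retrievers import ChatGPTPluginRetriever

    retriever = ChatGPTPluginRetriever(
        url="http://localhost:8000",  # placeholder plugin endpoint
        bearer_token="...",           # placeholder
        top_k=3,
    )
    docs = retriever.get_relevant_documents("alice's phone number")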
diff --git a/libs/community/langchain_community/retrievers/cohere_rag_retriever.py b/libs/community/langchain_community/retrievers/cohere_rag_retriever.py
new file mode 100644
index 00000000000..39dcc30f3b4
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/cohere_rag_retriever.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage
+from langchain_core.pydantic_v1 import Field
+from langchain_core.retrievers import BaseRetriever
+
+if TYPE_CHECKING:
+ from langchain_core.messages import BaseMessage
+
+
+def _get_docs(response: Any) -> List[Document]:
+ docs = [
+ Document(page_content=doc["snippet"], metadata=doc)
+ for doc in response.generation_info["documents"]
+ ]
+ docs.append(
+ Document(
+ page_content=response.message.content,
+ metadata={
+ "type": "model_response",
+ "citations": response.generation_info["citations"],
+ "search_results": response.generation_info["search_results"],
+ "search_queries": response.generation_info["search_queries"],
+ "token_count": response.generation_info["token_count"],
+ },
+ )
+ )
+ return docs
+
+
+class CohereRagRetriever(BaseRetriever):
+ """Cohere Chat API with RAG."""
+
+ connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
+ """
+ When specified, the model's reply will be enriched with information found by
+ querying each of the connectors (RAG). These will be returned as langchain
+ documents.
+
+ Currently only accepts {"id": "web-search"}.
+ """
+
+ llm: BaseChatModel
+ """Cohere ChatModel to use."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+ """Allow arbitrary types."""
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
+ ) -> List[Document]:
+ messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
+ res = self.llm.generate(
+ messages,
+ connectors=self.connectors,
+ callbacks=run_manager.get_child(),
+ **kwargs,
+ ).generations[0][0]
+ return _get_docs(res)
+
+ async def _aget_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: AsyncCallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
+ res = (
+ await self.llm.agenerate(
+ messages,
+ connectors=self.connectors,
+ callbacks=run_manager.get_child(),
+ **kwargs,
+ )
+ ).generations[0][0]
+ return _get_docs(res)
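A hedged sketch; it assumes a Cohere chat model such as `ChatCohere` is available elsewhere in this package to satisfy the `llm` field, and relies on the default web-search connector (the API key is a placeholder).

.. code-block:: python

    from langchain_community.chat_models import ChatCohere  # assumed available
    from langchain_community.retrievers import CohereRagRetriever

    retriever = CohereRagRetriever(llm=ChatCohere(cohere_api_key="..."))  # placeholder key
    docs = retriever.get_relevant_documents("What is Cohere?")

    # The last Document carries the model response plus citation metadata (see _get_docs).
    print(docs[-1].metadata["citations"])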
diff --git a/libs/community/langchain_community/retrievers/databerry.py b/libs/community/langchain_community/retrievers/databerry.py
new file mode 100644
index 00000000000..c1ea6277004
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/databerry.py
@@ -0,0 +1,74 @@
+from typing import List, Optional
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class DataberryRetriever(BaseRetriever):
+ """`Databerry API` retriever."""
+
+ datastore_url: str
+ top_k: Optional[int]
+ api_key: Optional[str]
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ response = requests.post(
+ self.datastore_url,
+ json={
+ "query": query,
+ **({"topK": self.top_k} if self.top_k is not None else {}),
+ },
+ headers={
+ "Content-Type": "application/json",
+ **(
+ {"Authorization": f"Bearer {self.api_key}"}
+ if self.api_key is not None
+ else {}
+ ),
+ },
+ )
+ data = response.json()
+ return [
+ Document(
+ page_content=r["text"],
+ metadata={"source": r["source"], "score": r["score"]},
+ )
+ for r in data["results"]
+ ]
+
+ async def _aget_relevant_documents(
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ async with aiohttp.ClientSession() as session:
+ async with session.request(
+ "POST",
+ self.datastore_url,
+ json={
+ "query": query,
+ **({"topK": self.top_k} if self.top_k is not None else {}),
+ },
+ headers={
+ "Content-Type": "application/json",
+ **(
+ {"Authorization": f"Bearer {self.api_key}"}
+ if self.api_key is not None
+ else {}
+ ),
+ },
+ ) as response:
+ data = await response.json()
+ return [
+ Document(
+ page_content=r["text"],
+ metadata={"source": r["source"], "score": r["score"]},
+ )
+ for r in data["results"]
+ ]
diff --git a/libs/community/langchain_community/retrievers/docarray.py b/libs/community/langchain_community/retrievers/docarray.py
new file mode 100644
index 00000000000..bb3db872073
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/docarray.py
@@ -0,0 +1,207 @@
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+
+class SearchType(str, Enum):
+ """Enumerator of the types of search to perform."""
+
+ similarity = "similarity"
+ mmr = "mmr"
+
+
+class DocArrayRetriever(BaseRetriever):
+ """`DocArray Document Indices` retriever.
+
+ Currently, it supports 5 backends:
+ InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
+ ElasticDocIndex, and WeaviateDocumentIndex.
+
+ Args:
+ index: One of the above-mentioned index instances
+ embeddings: Embedding model to represent text as vectors
+ search_field: Field to consider for searching in the documents.
+ Should be an embedding/vector/tensor.
+ content_field: Field that represents the main content in your document schema.
+ Will be used as a `page_content`. Everything else will go into `metadata`.
+ search_type: Type of search to perform (similarity / mmr)
+ filters: Filters applied for document retrieval.
+ top_k: Number of documents to return
+ """
+
+ index: Any
+ embeddings: Embeddings
+ search_field: str
+ content_field: str
+ search_type: SearchType = SearchType.similarity
+ top_k: int = 1
+ filters: Optional[Any] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ ) -> List[Document]:
+ """Get documents relevant for a query.
+
+ Args:
+ query: string to find relevant documents for
+
+ Returns:
+ List of relevant documents
+ """
+ query_emb = np.array(self.embeddings.embed_query(query))
+
+ if self.search_type == SearchType.similarity:
+ results = self._similarity_search(query_emb)
+ elif self.search_type == SearchType.mmr:
+ results = self._mmr_search(query_emb)
+ else:
+ raise ValueError(
+ f"Search type {self.search_type} does not exist. "
+ f"Choose either 'similarity' or 'mmr'."
+ )
+
+ return results
+
+ def _search(
+ self, query_emb: np.ndarray, top_k: int
+ ) -> List[Union[Dict[str, Any], Any]]:
+ """
+ Perform a search using the query embedding and return top_k documents.
+
+ Args:
+ query_emb: Query represented as an embedding
+ top_k: Number of documents to return
+
+ Returns:
+ A list of top_k documents matching the query
+ """
+
+ from docarray.index import ElasticDocIndex, WeaviateDocumentIndex
+
+ filter_args = {}
+ search_field = self.search_field
+ if isinstance(self.index, WeaviateDocumentIndex):
+ filter_args["where_filter"] = self.filters
+ search_field = ""
+ elif isinstance(self.index, ElasticDocIndex):
+ filter_args["query"] = self.filters
+ else:
+ filter_args["filter_query"] = self.filters
+
+ if self.filters:
+ query = (
+ self.index.build_query() # get empty query object
+ .find(
+ query=query_emb, search_field=search_field
+ ) # add vector similarity search
+ .filter(**filter_args) # add filter search
+ .build(limit=top_k) # build the query
+ )
+ # execute the combined query and return the results
+ docs = self.index.execute_query(query)
+ if hasattr(docs, "documents"):
+ docs = docs.documents
+ docs = docs[:top_k]
+ else:
+ docs = self.index.find(
+ query=query_emb, search_field=search_field, limit=top_k
+ ).documents
+ return docs
+
+ def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
+ """
+ Perform a similarity search.
+
+ Args:
+ query_emb: Query represented as an embedding
+
+ Returns:
+ A list of documents most similar to the query
+ """
+ docs = self._search(query_emb=query_emb, top_k=self.top_k)
+ results = [self._docarray_to_langchain_doc(doc) for doc in docs]
+ return results
+
+ def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
+ """
+ Perform a maximal marginal relevance (mmr) search.
+
+ Args:
+ query_emb: Query represented as an embedding
+
+ Returns:
+ A list of diverse documents related to the query
+ """
+ docs = self._search(query_emb=query_emb, top_k=20)
+
+ mmr_selected = maximal_marginal_relevance(
+ query_emb,
+ [
+ doc[self.search_field]
+ if isinstance(doc, dict)
+ else getattr(doc, self.search_field)
+ for doc in docs
+ ],
+ k=self.top_k,
+ )
+ results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
+ return results
+
+ def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
+ """
+ Convert a DocArray document (which also might be a dict)
+ to a langchain document format.
+
+ DocArray document can contain arbitrary fields, so the mapping is done
+ in the following way:
+
+ page_content <-> content_field
+ metadata <-> all other fields excluding
+ tensors and embeddings (so float, int, string)
+
+ Args:
+ doc: DocArray document
+
+ Returns:
+ Document in langchain format
+
+ Raises:
+ ValueError: If the document doesn't contain the content field
+ """
+
+ fields = doc.keys() if isinstance(doc, dict) else doc.__fields__
+
+ if self.content_field not in fields:
+ raise ValueError(
+ f"Document does not contain the content field - {self.content_field}."
+ )
+ lc_doc = Document(
+ page_content=doc[self.content_field]
+ if isinstance(doc, dict)
+ else getattr(doc, self.content_field)
+ )
+
+ for name in fields:
+ value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
+ if (
+ isinstance(value, (str, int, float, bool))
+ and name != self.content_field
+ ):
+ lc_doc.metadata[name] = value
+
+ return lc_doc
diff --git a/libs/community/langchain_community/retrievers/elastic_search_bm25.py b/libs/community/langchain_community/retrievers/elastic_search_bm25.py
new file mode 100644
index 00000000000..a4f1f539d28
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/elastic_search_bm25.py
@@ -0,0 +1,137 @@
+"""Wrapper around Elasticsearch vector database."""
+
+from __future__ import annotations
+
+import uuid
+from typing import Any, Iterable, List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class ElasticSearchBM25Retriever(BaseRetriever):
+ """`Elasticsearch` retriever that uses `BM25`.
+
+ To connect to an Elasticsearch instance that requires login credentials,
+ including Elastic Cloud, use the Elasticsearch URL format
+ https://username:password@es_host:9243. For example, to connect to Elastic
+ Cloud, create the Elasticsearch URL with the required authentication details and
+    pass it to the ElasticSearchBM25Retriever.create factory method as the named
+    parameter elasticsearch_url.
+
+ You can obtain your Elastic Cloud URL and login credentials by logging in to the
+ Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
+ navigating to the "Deployments" page.
+
+ To obtain your Elastic Cloud password for the default "elastic" user:
+
+ 1. Log in to the Elastic Cloud console at https://cloud.elastic.co
+ 2. Go to "Security" > "Users"
+ 3. Locate the "elastic" user and click "Edit"
+ 4. Click "Reset password"
+ 5. Follow the prompts to reset the password
+
+ The format for Elastic Cloud URLs is
+ https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
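+
+    Example (a minimal sketch; the URL and index name below are placeholders for
+    your own Elasticsearch instance, and `create` builds a new index):
+        .. code-block:: python
+
+            from langchain_community.retrievers.elastic_search_bm25 import (
+                ElasticSearchBM25Retriever,
+            )
+
+            retriever = ElasticSearchBM25Retriever.create(
+                elasticsearch_url="http://localhost:9200",
+                index_name="langchain-bm25-demo",
+            )
+            retriever.add_texts(["foo", "foo bar", "hello world"])
+            docs = retriever.get_relevant_documents("foo")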
+ """
+
+ client: Any
+ """Elasticsearch client."""
+ index_name: str
+ """Name of the index to use in Elasticsearch."""
+
+ @classmethod
+ def create(
+ cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
+ ) -> ElasticSearchBM25Retriever:
+ """
+        Create an ElasticSearchBM25Retriever from an Elasticsearch URL and index name.
+
+ Args:
+ elasticsearch_url: URL of the Elasticsearch instance to connect to.
+ index_name: Name of the index to use in Elasticsearch.
+ k1: BM25 parameter k1.
+ b: BM25 parameter b.
+
+        Returns:
+            An ElasticSearchBM25Retriever backed by a newly created index.
+
+ """
+ from elasticsearch import Elasticsearch
+
+ # Create an Elasticsearch client instance
+ es = Elasticsearch(elasticsearch_url)
+
+ # Define the index settings and mappings
+ settings = {
+ "analysis": {"analyzer": {"default": {"type": "standard"}}},
+ "similarity": {
+ "custom_bm25": {
+ "type": "BM25",
+ "k1": k1,
+ "b": b,
+ }
+ },
+ }
+ mappings = {
+ "properties": {
+ "content": {
+ "type": "text",
+ "similarity": "custom_bm25", # Use the custom BM25 similarity
+ }
+ }
+ }
+
+ # Create the index with the specified settings and mappings
+ es.indices.create(index=index_name, mappings=mappings, settings=settings)
+ return cls(client=es, index_name=index_name)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ refresh_indices: bool = True,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the retriever.
+
+ Args:
+ texts: Iterable of strings to add to the retriever.
+ refresh_indices: bool to refresh ElasticSearch indices
+
+ Returns:
+ List of ids from adding the texts into the retriever.
+ """
+ try:
+ from elasticsearch.helpers import bulk
+ except ImportError:
+ raise ValueError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+ requests = []
+ ids = []
+ for i, text in enumerate(texts):
+ _id = str(uuid.uuid4())
+ request = {
+ "_op_type": "index",
+ "_index": self.index_name,
+ "content": text,
+ "_id": _id,
+ }
+ ids.append(_id)
+ requests.append(request)
+ bulk(self.client, requests)
+
+ if refresh_indices:
+ self.client.indices.refresh(index=self.index_name)
+ return ids
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ query_dict = {"query": {"match": {"content": query}}}
+ res = self.client.search(index=self.index_name, body=query_dict)
+
+ docs = []
+ for r in res["hits"]["hits"]:
+ docs.append(Document(page_content=r["_source"]["content"]))
+ return docs
diff --git a/libs/community/langchain_community/retrievers/embedchain.py b/libs/community/langchain_community/retrievers/embedchain.py
new file mode 100644
index 00000000000..0fae1b0f65a
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/embedchain.py
@@ -0,0 +1,71 @@
+"""Wrapper around Embedchain Retriever."""
+
+from __future__ import annotations
+
+from typing import Any, Iterable, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class EmbedchainRetriever(BaseRetriever):
+ """`Embedchain` retriever."""
+
+ client: Any
+ """Embedchain Pipeline."""
+
+ @classmethod
+ def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever:
+ """
+        Create an EmbedchainRetriever from a YAML configuration file.
+
+ Args:
+ yaml_path: Path to the YAML configuration file. If not provided,
+ a default configuration is used.
+
+ Returns:
+ An instance of EmbedchainRetriever.
+
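+        Example (a minimal sketch; the URL below is a placeholder for any source
+        supported by Embedchain):
+            .. code-block:: python
+
+                retriever = EmbedchainRetriever.create()
+                retriever.add_texts(["https://example.com/some-page"])
+                docs = retriever.get_relevant_documents("What does the page say?")
+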
+ """
+ from embedchain import Pipeline
+
+ # Create an Embedchain Pipeline instance
+ if yaml_path:
+ client = Pipeline.from_config(yaml_path=yaml_path)
+ else:
+ client = Pipeline()
+ return cls(client=client)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the retriever.
+
+ Args:
+ texts: Iterable of strings/URLs to add to the retriever.
+
+ Returns:
+ List of ids from adding the texts into the retriever.
+ """
+ ids = []
+ for text in texts:
+ _id = self.client.add(text)
+ ids.append(_id)
+ return ids
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ res = self.client.search(query)
+
+ docs = []
+ for r in res:
+ docs.append(
+ Document(
+ page_content=r["context"],
+ metadata={"source": r["source"], "document_id": r["document_id"]},
+ )
+ )
+ return docs
diff --git a/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py b/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py
new file mode 100644
index 00000000000..ce71a2db1d4
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py
@@ -0,0 +1,120 @@
+"""Retriever wrapper for Google Cloud Document AI Warehouse."""
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.cloud.contentwarehouse_v1 import (
+ DocumentServiceClient,
+ RequestMetadata,
+ SearchDocumentsRequest,
+ )
+ from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
+ SearchDocumentsPager,
+ )
+
+
+class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
+ """A retriever based on Document AI Warehouse.
+
+    Documents should be created and uploaded in a separate flow. This retriever
+    only uses the provided Document AI schema_id to search for relevant documents.
+
+ More info: https://cloud.google.com/document-ai-warehouse.
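+
+    Example (a minimal sketch; the project number and user id below are
+    placeholders, and `user_ldap` must be passed with every query):
+        .. code-block:: python
+
+            retriever = GoogleDocumentAIWarehouseRetriever(
+                project_number="123456789012"
+            )
+            docs = retriever.get_relevant_documents(
+                "How are our sales this quarter?", user_ldap="sample-user"
+            )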
+ """
+
+ location: str = "us"
+ """Google Cloud location where Document AI Warehouse is placed."""
+ project_number: str
+ """Google Cloud project number, should contain digits only."""
+ schema_id: Optional[str] = None
+ """Document AI Warehouse schema to query against.
+ If nothing is provided, all documents in the project will be searched."""
+ qa_size_limit: int = 5
+ """The limit on the number of documents returned."""
+ client: "DocumentServiceClient" = None #: :meta private:
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validates the environment."""
+ try: # noqa: F401
+ from google.cloud.contentwarehouse_v1 import DocumentServiceClient
+ except ImportError as exc:
+ raise ImportError(
+ "google.cloud.contentwarehouse is not installed."
+ "Please install it with pip install google-cloud-contentwarehouse"
+ ) from exc
+
+ values["project_number"] = get_from_dict_or_env(
+ values, "project_number", "PROJECT_NUMBER"
+ )
+ values["client"] = DocumentServiceClient(
+ client_info=get_client_info(module="document-ai-warehouse")
+ )
+ return values
+
+ def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
+ from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo
+
+ user_info = UserInfo(id=f"user:{user_ldap}")
+ return RequestMetadata(user_info=user_info)
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
+ ) -> List[Document]:
+ request = self._prepare_search_request(query, **kwargs)
+ response = self.client.search_documents(request=request)
+ return self._parse_search_response(response=response)
+
+ def _prepare_search_request(
+ self, query: str, **kwargs: Any
+ ) -> "SearchDocumentsRequest":
+ from google.cloud.contentwarehouse_v1 import (
+ DocumentQuery,
+ SearchDocumentsRequest,
+ )
+
+ try:
+ user_ldap = kwargs["user_ldap"]
+ except KeyError:
+ raise ValueError("Argument user_ldap should be provided!")
+
+ request_metadata = self._prepare_request_metadata(user_ldap=user_ldap)
+ schemas = []
+ if self.schema_id:
+ schemas.append(
+ self.client.document_schema_path(
+ project=self.project_number,
+ location=self.location,
+ document_schema=self.schema_id,
+ )
+ )
+ return SearchDocumentsRequest(
+ parent=self.client.common_location_path(self.project_number, self.location),
+ request_metadata=request_metadata,
+ document_query=DocumentQuery(
+ query=query, is_nl_query=True, document_schema_names=schemas
+ ),
+ qa_size_limit=self.qa_size_limit,
+ )
+
+ def _parse_search_response(
+ self, response: "SearchDocumentsPager"
+ ) -> List[Document]:
+ documents = []
+ for doc in response.matching_documents:
+ metadata = {
+ "title": doc.document.title,
+ "source": doc.document.raw_document_path,
+ }
+ documents.append(
+ Document(page_content=doc.search_text_snippet, metadata=metadata)
+ )
+ return documents
diff --git a/libs/community/langchain_community/retrievers/google_vertex_ai_search.py b/libs/community/langchain_community/retrievers/google_vertex_ai_search.py
new file mode 100644
index 00000000000..2ed46927014
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/google_vertex_ai_search.py
@@ -0,0 +1,469 @@
+"""Retriever wrapper for Google Vertex AI Search."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud.discoveryengine_v1beta import (
+ ConversationalSearchServiceClient,
+ SearchRequest,
+ SearchResult,
+ SearchServiceClient,
+ )
+
+
+class _BaseGoogleVertexAISearchRetriever(BaseModel):
+ project_id: str
+ """Google Cloud Project ID."""
+ data_store_id: str
+ """Vertex AI Search data store ID."""
+ location_id: str = "global"
+ """Vertex AI Search data store location."""
+ serving_config_id: str = "default_config"
+ """Vertex AI Search serving config ID."""
+ credentials: Any = None
+ """The default custom credentials (google.auth.credentials.Credentials) to use
+ when making API calls. If not provided, credentials will be ascertained from
+ the environment."""
+ engine_data_type: int = Field(default=0, ge=0, le=2)
+ """ Defines the Vertex AI Search data type
+ 0 - Unstructured data
+ 1 - Structured data
+ 2 - Website data
+ """
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validates the environment."""
+ try:
+ from google.cloud import discoveryengine_v1beta # noqa: F401
+ except ImportError as exc:
+ raise ImportError(
+ "google.cloud.discoveryengine is not installed."
+ "Please install it with pip install "
+ "google-cloud-discoveryengine>=0.11.0"
+ ) from exc
+ try:
+ from google.api_core.exceptions import InvalidArgument # noqa: F401
+ except ImportError as exc:
+ raise ImportError(
+ "google.api_core.exceptions is not installed. "
+ "Please install it with pip install google-api-core"
+ ) from exc
+
+ values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
+
+ try:
+ # For backwards compatibility
+ search_engine_id = get_from_dict_or_env(
+ values, "search_engine_id", "SEARCH_ENGINE_ID"
+ )
+
+ if search_engine_id:
+ import warnings
+
+ warnings.warn(
+ "The `search_engine_id` parameter is deprecated. Use `data_store_id` instead.", # noqa: E501
+ DeprecationWarning,
+ )
+ values["data_store_id"] = search_engine_id
+ except: # noqa: E722
+ pass
+
+ values["data_store_id"] = get_from_dict_or_env(
+ values, "data_store_id", "DATA_STORE_ID"
+ )
+
+ return values
+
+ @property
+ def client_options(self) -> "ClientOptions":
+ from google.api_core.client_options import ClientOptions
+
+ return ClientOptions(
+ api_endpoint=f"{self.location_id}-discoveryengine.googleapis.com"
+ if self.location_id != "global"
+ else None
+ )
+
+ def _convert_structured_search_response(
+ self, results: Sequence[SearchResult]
+ ) -> List[Document]:
+ """Converts a sequence of search results to a list of LangChain documents."""
+ import json
+
+ from google.protobuf.json_format import MessageToDict
+
+ documents: List[Document] = []
+
+ for result in results:
+ document_dict = MessageToDict(
+ result.document._pb, preserving_proto_field_name=True
+ )
+
+ documents.append(
+ Document(
+ page_content=json.dumps(document_dict.get("struct_data", {})),
+ metadata={"id": document_dict["id"], "name": document_dict["name"]},
+ )
+ )
+
+ return documents
+
+ def _convert_unstructured_search_response(
+ self, results: Sequence[SearchResult], chunk_type: str
+ ) -> List[Document]:
+ """Converts a sequence of search results to a list of LangChain documents."""
+ from google.protobuf.json_format import MessageToDict
+
+ documents: List[Document] = []
+
+ for result in results:
+ document_dict = MessageToDict(
+ result.document._pb, preserving_proto_field_name=True
+ )
+ derived_struct_data = document_dict.get("derived_struct_data")
+ if not derived_struct_data:
+ continue
+
+ doc_metadata = document_dict.get("struct_data", {})
+ doc_metadata["id"] = document_dict["id"]
+
+ if chunk_type not in derived_struct_data:
+ continue
+
+ for chunk in derived_struct_data[chunk_type]:
+ doc_metadata["source"] = derived_struct_data.get("link", "")
+
+ if chunk_type == "extractive_answers":
+ doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
+
+ documents.append(
+ Document(
+ page_content=chunk.get("content", ""), metadata=doc_metadata
+ )
+ )
+
+ return documents
+
+ def _convert_website_search_response(
+ self, results: Sequence[SearchResult], chunk_type: str
+ ) -> List[Document]:
+ """Converts a sequence of search results to a list of LangChain documents."""
+ from google.protobuf.json_format import MessageToDict
+
+ documents: List[Document] = []
+ chunk_type = "extractive_answers"
+
+ for result in results:
+ document_dict = MessageToDict(
+ result.document._pb, preserving_proto_field_name=True
+ )
+ derived_struct_data = document_dict.get("derived_struct_data")
+ if not derived_struct_data:
+ continue
+
+ doc_metadata = document_dict.get("struct_data", {})
+ doc_metadata["id"] = document_dict["id"]
+ doc_metadata["source"] = derived_struct_data.get("link", "")
+
+ if chunk_type not in derived_struct_data:
+ continue
+
+ text_field = "snippet" if chunk_type == "snippets" else "content"
+
+ for chunk in derived_struct_data[chunk_type]:
+ documents.append(
+ Document(
+ page_content=chunk.get(text_field, ""), metadata=doc_metadata
+ )
+ )
+
+ if not documents:
+ print(f"No {chunk_type} could be found.")
+ if chunk_type == "extractive_answers":
+ print(
+ "Make sure that your data store is using Advanced Website "
+ "Indexing.\n"
+ "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501
+ )
+
+ return documents
+
+
+class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
+ """`Google Vertex AI Search` retriever.
+
+ For a detailed explanation of the Vertex AI Search concepts
+ and configuration parameters, refer to the product documentation.
+ https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
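+
+    Example (a minimal sketch; the project and data store IDs below are
+    placeholders for your own Vertex AI Search setup):
+        .. code-block:: python
+
+            retriever = GoogleVertexAISearchRetriever(
+                project_id="my-gcp-project",
+                data_store_id="my-data-store-id",
+            )
+            docs = retriever.get_relevant_documents("Who founded the company?")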
+ """
+
+ filter: Optional[str] = None
+ """Filter expression."""
+ get_extractive_answers: bool = False
+ """If True return Extractive Answers, otherwise return Extractive Segments or Snippets.""" # noqa: E501
+ max_documents: int = Field(default=5, ge=1, le=100)
+ """The maximum number of documents to return."""
+ max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
+ """The maximum number of extractive answers returned in each search result.
+ At most 5 answers will be returned for each SearchResult.
+ """
+ max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
+ """The maximum number of extractive segments returned in each search result.
+ Currently one segment will be returned for each SearchResult.
+ """
+ query_expansion_condition: int = Field(default=1, ge=0, le=2)
+ """Specification to determine under which conditions query expansion should occur.
+ 0 - Unspecified query expansion condition. In this case, server behavior defaults
+ to disabled
+ 1 - Disabled query expansion. Only the exact search query is used, even if
+ SearchResponse.total_size is zero.
+ 2 - Automatic query expansion built by the Search API.
+ """
+ spell_correction_mode: int = Field(default=2, ge=0, le=2)
+ """Specification to determine under which conditions query expansion should occur.
+ 0 - Unspecified spell correction mode. In this case, server behavior defaults
+ to auto.
+ 1 - Suggestion only. Search API will try to find a spell suggestion if there is any
+ and put in the `SearchResponse.corrected_query`.
+ The spell suggestion will not be used as the search query.
+ 2 - Automatic spell correction built by the Search API.
+ Search will be based on the corrected query if found.
+ """
+
+ _client: SearchServiceClient
+ _serving_config: str
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.ignore
+ arbitrary_types_allowed = True
+ underscore_attrs_are_private = True
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initializes private fields."""
+ try:
+ from google.cloud.discoveryengine_v1beta import SearchServiceClient
+ except ImportError as exc:
+ raise ImportError(
+ "google.cloud.discoveryengine is not installed."
+ "Please install it with pip install google-cloud-discoveryengine"
+ ) from exc
+
+ super().__init__(**kwargs)
+
+ # For more information, refer to:
+ # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
+ self._client = SearchServiceClient(
+ credentials=self.credentials,
+ client_options=self.client_options,
+ client_info=get_client_info(module="vertex-ai-search"),
+ )
+
+ self._serving_config = self._client.serving_config_path(
+ project=self.project_id,
+ location=self.location_id,
+ data_store=self.data_store_id,
+ serving_config=self.serving_config_id,
+ )
+
+ def _create_search_request(self, query: str) -> SearchRequest:
+ """Prepares a SearchRequest object."""
+ from google.cloud.discoveryengine_v1beta import SearchRequest
+
+ query_expansion_spec = SearchRequest.QueryExpansionSpec(
+ condition=self.query_expansion_condition,
+ )
+
+ spell_correction_spec = SearchRequest.SpellCorrectionSpec(
+ mode=self.spell_correction_mode
+ )
+
+ if self.engine_data_type == 0:
+ if self.get_extractive_answers:
+ extractive_content_spec = (
+ SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
+ max_extractive_answer_count=self.max_extractive_answer_count,
+ )
+ )
+ else:
+ extractive_content_spec = (
+ SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
+ max_extractive_segment_count=self.max_extractive_segment_count,
+ )
+ )
+ content_search_spec = SearchRequest.ContentSearchSpec(
+ extractive_content_spec=extractive_content_spec
+ )
+ elif self.engine_data_type == 1:
+ content_search_spec = None
+ elif self.engine_data_type == 2:
+ content_search_spec = SearchRequest.ContentSearchSpec(
+ extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
+ max_extractive_answer_count=self.max_extractive_answer_count,
+ ),
+ snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec(
+ return_snippet=True
+ ),
+ )
+ else:
+ raise NotImplementedError(
+ "Only data store type 0 (Unstructured), 1 (Structured),"
+ "or 2 (Website) are supported currently."
+ + f" Got {self.engine_data_type}"
+ )
+
+ return SearchRequest(
+ query=query,
+ filter=self.filter,
+ serving_config=self._serving_config,
+ page_size=self.max_documents,
+ content_search_spec=content_search_spec,
+ query_expansion_spec=query_expansion_spec,
+ spell_correction_spec=spell_correction_spec,
+ )
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ """Get documents relevant for a query."""
+ from google.api_core.exceptions import InvalidArgument
+
+ search_request = self._create_search_request(query)
+
+ try:
+ response = self._client.search(search_request)
+ except InvalidArgument as exc:
+ raise type(exc)(
+ exc.message
+ + " This might be due to engine_data_type not set correctly."
+ )
+
+ if self.engine_data_type == 0:
+ chunk_type = (
+ "extractive_answers"
+ if self.get_extractive_answers
+ else "extractive_segments"
+ )
+ documents = self._convert_unstructured_search_response(
+ response.results, chunk_type
+ )
+ elif self.engine_data_type == 1:
+ documents = self._convert_structured_search_response(response.results)
+ elif self.engine_data_type == 2:
+ chunk_type = (
+ "extractive_answers" if self.get_extractive_answers else "snippets"
+ )
+ documents = self._convert_website_search_response(
+ response.results, chunk_type
+ )
+ else:
+ raise NotImplementedError(
+ "Only data store type 0 (Unstructured), 1 (Structured),"
+ "or 2 (Website) are supported currently."
+ + f" Got {self.engine_data_type}"
+ )
+
+ return documents
+
+
+class GoogleVertexAIMultiTurnSearchRetriever(
+ BaseRetriever, _BaseGoogleVertexAISearchRetriever
+):
+ """`Google Vertex AI Search` retriever for multi-turn conversations."""
+
+ conversation_id: str = "-"
+ """Vertex AI Search Conversation ID."""
+
+ _client: ConversationalSearchServiceClient
+ _serving_config: str
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.ignore
+ arbitrary_types_allowed = True
+ underscore_attrs_are_private = True
+
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
+ from google.cloud.discoveryengine_v1beta import (
+ ConversationalSearchServiceClient,
+ )
+
+ self._client = ConversationalSearchServiceClient(
+ credentials=self.credentials,
+ client_options=self.client_options,
+ client_info=get_client_info(module="vertex-ai-search"),
+ )
+
+ self._serving_config = self._client.serving_config_path(
+ project=self.project_id,
+ location=self.location_id,
+ data_store=self.data_store_id,
+ serving_config=self.serving_config_id,
+ )
+
+ if self.engine_data_type == 1:
+ raise NotImplementedError(
+ "Data store type 1 (Structured)"
+ "is not currently supported for multi-turn search."
+ + f" Got {self.engine_data_type}"
+ )
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ """Get documents relevant for a query."""
+ from google.cloud.discoveryengine_v1beta import (
+ ConverseConversationRequest,
+ TextInput,
+ )
+
+ request = ConverseConversationRequest(
+ name=self._client.conversation_path(
+ self.project_id,
+ self.location_id,
+ self.data_store_id,
+ self.conversation_id,
+ ),
+ serving_config=self._serving_config,
+ query=TextInput(input=query),
+ )
+ response = self._client.converse_conversation(request)
+
+ if self.engine_data_type == 2:
+ return self._convert_website_search_response(
+ response.search_results, "extractive_answers"
+ )
+
+ return self._convert_unstructured_search_response(
+ response.search_results, "extractive_answers"
+ )
+
+
+class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
+ """`Google Vertex Search API` retriever alias for backwards compatibility.
+ DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
+ """
+
+ def __init__(self, **data: Any):
+ import warnings
+
+ warnings.warn(
+ "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
+ DeprecationWarning,
+ )
+
+ super().__init__(**data)
diff --git a/libs/community/langchain_community/retrievers/kay.py b/libs/community/langchain_community/retrievers/kay.py
new file mode 100644
index 00000000000..ef594157b1b
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/kay.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from typing import Any, List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class KayAiRetriever(BaseRetriever):
+ """
+ Retriever for Kay.ai datasets.
+
+    To work properly, it expects the KAY_API_KEY environment variable to be set.
+ You can get one for free at https://kay.ai/.
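+
+    Example (a minimal sketch; the dataset id and data types follow the ones
+    described in `create`, and the query is only illustrative):
+        .. code-block:: python
+
+            retriever = KayAiRetriever.create(
+                dataset_id="company", data_types=["10-K", "10-Q"]
+            )
+            docs = retriever.get_relevant_documents("What are Apple's risk factors?")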
+ """
+
+ client: Any
+ num_contexts: int
+
+ @classmethod
+ def create(
+ cls,
+ dataset_id: str,
+ data_types: List[str],
+ num_contexts: int = 6,
+ ) -> KayAiRetriever:
+ """
+        Create a KayAiRetriever given a Kay dataset id and a list of data sources.
+
+ Args:
+ dataset_id: A dataset id category in Kay, like "company"
+ data_types: A list of datasources present within a dataset. For
+ "company" the corresponding datasources could be
+ ["10-K", "10-Q", "8-K", "PressRelease"].
+ num_contexts: The number of documents to retrieve on each query.
+ Defaults to 6.
+ """
+ try:
+ from kay.rag.retrievers import KayRetriever
+ except ImportError:
+ raise ImportError(
+ "Could not import kay python package. Please install it with "
+ "`pip install kay`.",
+ )
+
+ client = KayRetriever(dataset_id, data_types)
+ return cls(client=client, num_contexts=num_contexts)
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ ctxs = self.client.query(query=query, num_context=self.num_contexts)
+ docs = []
+ for ctx in ctxs:
+ page_content = ctx.pop("chunk_embed_text", None)
+ if page_content is None:
+ continue
+ docs.append(Document(page_content=page_content, metadata={**ctx}))
+ return docs
diff --git a/libs/community/langchain_community/retrievers/kendra.py b/libs/community/langchain_community/retrievers/kendra.py
new file mode 100644
index 00000000000..ec1554c0abf
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/kendra.py
@@ -0,0 +1,423 @@
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator, validator
+from langchain_core.retrievers import BaseRetriever
+
+
+def clean_excerpt(excerpt: str) -> str:
+ """Clean an excerpt from Kendra.
+
+ Args:
+ excerpt: The excerpt to clean.
+
+ Returns:
+ The cleaned excerpt.
+
+ """
+ if not excerpt:
+ return excerpt
+ res = re.sub(r"\s+", " ", excerpt).replace("...", "")
+ return res
+
+
+def combined_text(item: "ResultItem") -> str:
+ """Combine a ResultItem title and excerpt into a single string.
+
+ Args:
+ item: the ResultItem of a Kendra search.
+
+ Returns:
+ A combined text of the title and excerpt of the given item.
+
+ """
+ text = ""
+ title = item.get_title()
+ if title:
+ text += f"Document Title: {title}\n"
+ excerpt = clean_excerpt(item.get_excerpt())
+ if excerpt:
+ text += f"Document Excerpt: \n{excerpt}\n"
+ return text
+
+
+DocumentAttributeValueType = Union[str, int, List[str], None]
+"""Possible types of a DocumentAttributeValue.
+
+Dates are also represented as str.
+"""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class Highlight(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Information that highlights the keywords in the excerpt."""
+
+ BeginOffset: int
+ """The zero-based location in the excerpt where the highlight starts."""
+ EndOffset: int
+ """The zero-based location in the excerpt where the highlight ends."""
+ TopAnswer: Optional[bool]
+ """Indicates whether the result is the best one."""
+ Type: Optional[str]
+ """The highlight type: STANDARD or THESAURUS_SYNONYM."""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class TextWithHighLights(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Text with highlights."""
+
+ Text: str
+ """The text."""
+ Highlights: Optional[Any]
+ """The highlights."""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class AdditionalResultAttributeValue( # type: ignore[call-arg]
+ BaseModel, extra=Extra.allow
+):
+ """Value of an additional result attribute."""
+
+ TextWithHighlightsValue: TextWithHighLights
+ """The text with highlights value."""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class AdditionalResultAttribute(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Additional result attribute."""
+
+ Key: str
+ """The key of the attribute."""
+ ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
+ """The type of the value."""
+ Value: AdditionalResultAttributeValue
+ """The value of the attribute."""
+
+ def get_value_text(self) -> str:
+ return self.Value.TextWithHighlightsValue.Text
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class DocumentAttributeValue(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Value of a document attribute."""
+
+ DateValue: Optional[str]
+ """The date expressed as an ISO 8601 string."""
+ LongValue: Optional[int]
+ """The long value."""
+ StringListValue: Optional[List[str]]
+ """The string list value."""
+ StringValue: Optional[str]
+ """The string value."""
+
+ @property
+ def value(self) -> DocumentAttributeValueType:
+ """The only defined document attribute value or None.
+ According to Amazon Kendra, you can only provide one
+ value for a document attribute.
+ """
+ if self.DateValue:
+ return self.DateValue
+ if self.LongValue:
+ return self.LongValue
+ if self.StringListValue:
+ return self.StringListValue
+ if self.StringValue:
+ return self.StringValue
+
+ return None
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class DocumentAttribute(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """Document attribute."""
+
+ Key: str
+ """The key of the attribute."""
+ Value: DocumentAttributeValue
+ """The value of the attribute."""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
+ """Base class of a result item."""
+
+ Id: Optional[str]
+ """The ID of the relevant result item."""
+ DocumentId: Optional[str]
+ """The document ID."""
+ DocumentURI: Optional[str]
+ """The document URI."""
+ DocumentAttributes: Optional[List[DocumentAttribute]] = []
+ """The document attributes."""
+
+ @abstractmethod
+ def get_title(self) -> str:
+ """Document title."""
+
+ @abstractmethod
+ def get_excerpt(self) -> str:
+ """Document excerpt or passage original content as retrieved by Kendra."""
+
+ def get_additional_metadata(self) -> dict:
+ """Document additional metadata dict.
+ This returns any extra metadata except these:
+ * result_id
+ * document_id
+ * source
+ * title
+ * excerpt
+ * document_attributes
+ """
+ return {}
+
+ def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
+ """Document attributes dict."""
+ return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
+
+ def to_doc(
+ self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
+ ) -> Document:
+ """Converts this item to a Document."""
+ page_content = page_content_formatter(self)
+ metadata = self.get_additional_metadata()
+ metadata.update(
+ {
+ "result_id": self.Id,
+ "document_id": self.DocumentId,
+ "source": self.DocumentURI,
+ "title": self.get_title(),
+ "excerpt": self.get_excerpt(),
+ "document_attributes": self.get_document_attributes_dict(),
+ }
+ )
+
+ return Document(page_content=page_content, metadata=metadata)
+
+
+class QueryResultItem(ResultItem):
+ """Query API result item."""
+
+ DocumentTitle: TextWithHighLights
+ """The document title."""
+ FeedbackToken: Optional[str]
+ """Identifies a particular result from a particular query."""
+ Format: Optional[str]
+ """
+ If the Type is ANSWER, then format is either:
+ * TABLE: a table excerpt is returned in TableExcerpt;
+ * TEXT: a text excerpt is returned in DocumentExcerpt.
+ """
+ Type: Optional[str]
+ """Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER"""
+ AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
+ """One or more additional attributes associated with the result."""
+ DocumentExcerpt: Optional[TextWithHighLights]
+ """Excerpt of the document text."""
+
+ def get_title(self) -> str:
+ return self.DocumentTitle.Text
+
+ def get_attribute_value(self) -> str:
+ if not self.AdditionalAttributes:
+ return ""
+ if not self.AdditionalAttributes[0]:
+ return ""
+ else:
+ return self.AdditionalAttributes[0].get_value_text()
+
+ def get_excerpt(self) -> str:
+ if (
+ self.AdditionalAttributes
+ and self.AdditionalAttributes[0].Key == "AnswerText"
+ ):
+ excerpt = self.get_attribute_value()
+ elif self.DocumentExcerpt:
+ excerpt = self.DocumentExcerpt.Text
+ else:
+ excerpt = ""
+
+ return excerpt
+
+ def get_additional_metadata(self) -> dict:
+ additional_metadata = {"type": self.Type}
+ return additional_metadata
+
+
+class RetrieveResultItem(ResultItem):
+ """Retrieve API result item."""
+
+ DocumentTitle: Optional[str]
+ """The document title."""
+ Content: Optional[str]
+ """The content of the item."""
+
+ def get_title(self) -> str:
+ return self.DocumentTitle or ""
+
+ def get_excerpt(self) -> str:
+ return self.Content or ""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class QueryResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """`Amazon Kendra Query API` search result.
+
+ It is composed of:
+ * Relevant suggested answers: either a text excerpt or table excerpt.
+ * Matching FAQs or questions-answer from your FAQ file.
+ * Documents including an excerpt of each document with its title.
+ """
+
+ ResultItems: List[QueryResultItem]
+ """The result items."""
+
+
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class RetrieveResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
+ """`Amazon Kendra Retrieve API` search result.
+
+ It is composed of:
+ * relevant passages or text excerpts given an input query.
+ """
+
+ QueryId: str
+ """The ID of the query."""
+ ResultItems: List[RetrieveResultItem]
+ """The result items."""
+
+
+class AmazonKendraRetriever(BaseRetriever):
+ """`Amazon Kendra Index` retriever.
+
+ Args:
+ index_id: Kendra index id
+
+        region_name: The AWS region, e.g., `us-west-2`.
+            Falls back to the AWS_DEFAULT_REGION env variable
+            or the region specified in ~/.aws/config.
+
+ credentials_profile_name: The name of the profile in the ~/.aws/credentials
+ or ~/.aws/config files, which has either access keys or role information
+ specified. If not specified, the default credential profile or, if on an
+ EC2 instance, credentials from IMDS will be used.
+
+        top_k: Number of results to return
+
+ attribute_filter: Additional filtering of results based on metadata
+ See: https://docs.aws.amazon.com/kendra/latest/APIReference
+
+ page_content_formatter: generates the Document page_content
+ allowing access to all result item attributes. By default, it uses
+ the item's title and excerpt.
+
+ client: boto3 client for Kendra
+
+ user_context: Provides information about the user context
+ See: https://docs.aws.amazon.com/kendra/latest/APIReference
+
+ Example:
+ .. code-block:: python
+
+ retriever = AmazonKendraRetriever(
+ index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
+ )
+
+ """
+
+ index_id: str
+ region_name: Optional[str] = None
+ credentials_profile_name: Optional[str] = None
+ top_k: int = 3
+ attribute_filter: Optional[Dict] = None
+ page_content_formatter: Callable[[ResultItem], str] = combined_text
+ client: Any
+ user_context: Optional[Dict] = None
+
+ @validator("top_k")
+ def validate_top_k(cls, value: int) -> int:
+ if value < 0:
+ raise ValueError(f"top_k ({value}) cannot be negative.")
+ return value
+
+ @root_validator(pre=True)
+ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if values.get("client") is not None:
+ return values
+
+ try:
+ import boto3
+
+ if values.get("credentials_profile_name"):
+ session = boto3.Session(profile_name=values["credentials_profile_name"])
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {}
+ if values.get("region_name"):
+ client_params["region_name"] = values["region_name"]
+
+ values["client"] = session.client("kendra", **client_params)
+
+ return values
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ def _kendra_query(self, query: str) -> Sequence[ResultItem]:
+ kendra_kwargs = {
+ "IndexId": self.index_id,
+ "QueryText": query.strip(),
+ "PageSize": self.top_k,
+ }
+ if self.attribute_filter is not None:
+ kendra_kwargs["AttributeFilter"] = self.attribute_filter
+ if self.user_context is not None:
+ kendra_kwargs["UserContext"] = self.user_context
+
+ response = self.client.retrieve(**kendra_kwargs)
+ r_result = RetrieveResult.parse_obj(response)
+ if r_result.ResultItems:
+ return r_result.ResultItems
+
+ # Retrieve API returned 0 results, fall back to Query API
+ response = self.client.query(**kendra_kwargs)
+ q_result = QueryResult.parse_obj(response)
+ return q_result.ResultItems
+
+ def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
+ top_docs = [
+ item.to_doc(self.page_content_formatter)
+ for item in result_items[: self.top_k]
+ ]
+ return top_docs
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ ) -> List[Document]:
+ """Run search on Kendra index and get top k documents
+
+ Example:
+ .. code-block:: python
+
+ docs = retriever.get_relevant_documents('This is my query')
+
+ """
+ result_items = self._kendra_query(query)
+ top_k_docs = self._get_top_k_docs(result_items)
+ return top_k_docs
diff --git a/libs/community/langchain_community/retrievers/knn.py b/libs/community/langchain_community/retrievers/knn.py
new file mode 100644
index 00000000000..045d11cc1d3
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/knn.py
@@ -0,0 +1,81 @@
+"""KNN Retriever.
+Largely based on
+https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
+
+from __future__ import annotations
+
+import concurrent.futures
+from typing import Any, List, Optional
+
+import numpy as np
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.retrievers import BaseRetriever
+
+
+def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
+ """
+ Create an index of embeddings for a list of contexts.
+
+ Args:
+ contexts: List of contexts to embed.
+ embeddings: Embeddings model to use.
+
+ Returns:
+ Index of embeddings.
+ """
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ return np.array(list(executor.map(embeddings.embed_query, contexts)))
+
+
+class KNNRetriever(BaseRetriever):
+ """`KNN` retriever."""
+
+ embeddings: Embeddings
+ """Embeddings model to use."""
+ index: Any
+ """Index of embeddings."""
+ texts: List[str]
+ """List of texts to index."""
+ k: int = 4
+ """Number of results to return."""
+ relevancy_threshold: Optional[float] = None
+ """Threshold for relevancy."""
+
+ class Config:
+        """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def from_texts(
+ cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
+ ) -> KNNRetriever:
+ index = create_index(texts, embeddings)
+ return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ query_embeds = np.array(self.embeddings.embed_query(query))
+ # calc L2 norm
+ index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
+ query_embeds = query_embeds / np.sqrt((query_embeds**2).sum())
+
+ similarities = index_embeds.dot(query_embeds)
+ sorted_ix = np.argsort(-similarities)
+
+ denominator = np.max(similarities) - np.min(similarities) + 1e-6
+ normalized_similarities = (similarities - np.min(similarities)) / denominator
+
+ top_k_results = [
+ Document(page_content=self.texts[row])
+ for row in sorted_ix[0 : self.k]
+ if (
+ self.relevancy_threshold is None
+ or normalized_similarities[row] >= self.relevancy_threshold
+ )
+ ]
+ return top_k_results
diff --git a/libs/community/langchain_community/retrievers/llama_index.py b/libs/community/langchain_community/retrievers/llama_index.py
new file mode 100644
index 00000000000..6063efe5178
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/llama_index.py
@@ -0,0 +1,86 @@
+from typing import Any, Dict, List, cast
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Field
+from langchain_core.retrievers import BaseRetriever
+
+
+class LlamaIndexRetriever(BaseRetriever):
+ """`LlamaIndex` retriever.
+
+    It is used for question-answering with sources over
+    a LlamaIndex data structure.
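+
+    Example (a minimal sketch; `index` is assumed to be an existing LlamaIndex
+    index object built elsewhere):
+        .. code-block:: python
+
+            retriever = LlamaIndexRetriever(index=index)
+            docs = retriever.get_relevant_documents("What did the author work on?")
+    """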
+
+ index: Any
+ """LlamaIndex index to query."""
+ query_kwargs: Dict = Field(default_factory=dict)
+ """Keyword arguments to pass to the query method."""
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ """Get documents relevant for a query."""
+ try:
+ from llama_index.indices.base import BaseGPTIndex
+ from llama_index.response.schema import Response
+ except ImportError:
+ raise ImportError(
+ "You need to install `pip install llama-index` to use this retriever."
+ )
+ index = cast(BaseGPTIndex, self.index)
+
+ response = index.query(query, response_mode="no_text", **self.query_kwargs)
+ response = cast(Response, response)
+ # parse source nodes
+ docs = []
+ for source_node in response.source_nodes:
+ metadata = source_node.extra_info or {}
+ docs.append(
+ Document(page_content=source_node.source_text, metadata=metadata)
+ )
+ return docs
+
+
+class LlamaIndexGraphRetriever(BaseRetriever):
+ """`LlamaIndex` graph data structure retriever.
+
+    It is used for question-answering with sources over a LlamaIndex
+ graph data structure."""
+
+ graph: Any
+ """LlamaIndex graph to query."""
+ query_configs: List[Dict] = Field(default_factory=list)
+ """List of query configs to pass to the query method."""
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ """Get documents relevant for a query."""
+ try:
+ from llama_index.composability.graph import (
+ QUERY_CONFIG_TYPE,
+ ComposableGraph,
+ )
+ from llama_index.response.schema import Response
+ except ImportError:
+ raise ImportError(
+ "You need to install `pip install llama-index` to use this retriever."
+ )
+ graph = cast(ComposableGraph, self.graph)
+
+ # for now, inject response_mode="no_text" into query configs
+ for query_config in self.query_configs:
+ query_config["response_mode"] = "no_text"
+ query_configs = cast(List[QUERY_CONFIG_TYPE], self.query_configs)
+ response = graph.query(query, query_configs=query_configs)
+ response = cast(Response, response)
+
+ # parse source nodes
+ docs = []
+ for source_node in response.source_nodes:
+ metadata = source_node.extra_info or {}
+ docs.append(
+ Document(page_content=source_node.source_text, metadata=metadata)
+ )
+ return docs
diff --git a/libs/community/langchain_community/retrievers/metal.py b/libs/community/langchain_community/retrievers/metal.py
new file mode 100644
index 00000000000..6eefd8312e3
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/metal.py
@@ -0,0 +1,42 @@
+from typing import Any, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+
+
+class MetalRetriever(BaseRetriever):
+ """`Metal API` retriever."""
+
+ client: Any
+ """The Metal client to use."""
+ params: Optional[dict] = None
+ """The parameters to pass to the Metal client."""
+
+ @root_validator(pre=True)
+ def validate_client(cls, values: dict) -> dict:
+ """Validate that the client is of the correct type."""
+ from metal_sdk.metal import Metal
+
+ if "client" in values:
+ client = values["client"]
+ if not isinstance(client, Metal):
+ raise ValueError(
+ "Got unexpected client, should be of type metal_sdk.metal.Metal. "
+ f"Instead, got {type(client)}"
+ )
+
+ values["params"] = values.get("params", {})
+
+ return values
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ results = self.client.search({"text": query}, **self.params)
+ final_results = []
+ for r in results["data"]:
+ metadata = {k: v for k, v in r.items() if k != "text"}
+ final_results.append(Document(page_content=r["text"], metadata=metadata))
+ return final_results
diff --git a/libs/community/langchain_community/retrievers/milvus.py b/libs/community/langchain_community/retrievers/milvus.py
new file mode 100644
index 00000000000..1d759155701
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/milvus.py
@@ -0,0 +1,80 @@
+"""Milvus Retriever"""
+import warnings
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.vectorstores.milvus import Milvus
+
+# TODO: Update to MilvusClient + Hybrid Search when available
+
+
+class MilvusRetriever(BaseRetriever):
+ """`Milvus API` retriever."""
+
+ embedding_function: Embeddings
+ collection_name: str = "LangChainCollection"
+ connection_args: Optional[Dict[str, Any]] = None
+ consistency_level: str = "Session"
+ search_params: Optional[dict] = None
+
+ store: Milvus
+ retriever: BaseRetriever
+
+ @root_validator(pre=True)
+ def create_retriever(cls, values: Dict) -> Dict:
+ """Create the Milvus store and retriever."""
+ values["store"] = Milvus(
+ values["embedding_function"],
+ values["collection_name"],
+ values["connection_args"],
+ values["consistency_level"],
+ )
+ values["retriever"] = values["store"].as_retriever(
+ search_kwargs={"param": values["search_params"]}
+ )
+ return values
+
+ def add_texts(
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
+ ) -> None:
+ """Add text to the Milvus store
+
+ Args:
+ texts (List[str]): The text
+ metadatas (List[dict]): Metadata dicts, must line up with existing store
+ """
+ self.store.add_texts(texts, metadatas)
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ return self.retriever.get_relevant_documents(
+ query, run_manager=run_manager.get_child(), **kwargs
+ )
+
+
+def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
+ """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.
+
+ Args:
+ *args:
+ **kwargs:
+
+ Returns:
+ MilvusRetriever
+ """
+ warnings.warn(
+ "MilvusRetreiver will be deprecated in the future. "
+ "Please use MilvusRetriever ('i' before 'e') instead.",
+ DeprecationWarning,
+ )
+ return MilvusRetriever(*args, **kwargs)
diff --git a/libs/community/langchain_community/retrievers/outline.py b/libs/community/langchain_community/retrievers/outline.py
new file mode 100644
index 00000000000..03b1118125d
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/outline.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.utilities.outline import OutlineAPIWrapper
+
+
+class OutlineRetriever(BaseRetriever, OutlineAPIWrapper):
+ """Retriever for Outline API.
+
+    It wraps run() as get_relevant_documents().
+ It uses all OutlineAPIWrapper arguments without any change.
+ """
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ return self.run(query=query)
diff --git a/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py b/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py
new file mode 100644
index 00000000000..333d30a0893
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py
@@ -0,0 +1,180 @@
+"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
+
+import hashlib
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.retrievers import BaseRetriever
+
+
+def hash_text(text: str) -> str:
+ """Hash a text using SHA256.
+
+ Args:
+ text: Text to hash.
+
+ Returns:
+ Hashed text.
+ """
+ return str(hashlib.sha256(text.encode("utf-8")).hexdigest())
+
+
+def create_index(
+ contexts: List[str],
+ index: Any,
+ embeddings: Embeddings,
+ sparse_encoder: Any,
+ ids: Optional[List[str]] = None,
+ metadatas: Optional[List[dict]] = None,
+ namespace: Optional[str] = None,
+) -> None:
+ """Create an index from a list of contexts.
+
+ It modifies the index argument in-place!
+
+ Args:
+ contexts: List of contexts to embed.
+ index: Index to use.
+ embeddings: Embeddings model to use.
+ sparse_encoder: Sparse encoder to use.
+ ids: List of ids to use for the documents.
+        metadatas: List of metadata to use for the documents.
+        namespace: Namespace value for the index partition.
+ """
+ batch_size = 32
+ _iterator = range(0, len(contexts), batch_size)
+ try:
+ from tqdm.auto import tqdm
+
+ _iterator = tqdm(_iterator)
+ except ImportError:
+ pass
+
+ if ids is None:
+ # create unique ids using hash of the text
+ ids = [hash_text(context) for context in contexts]
+
+ for i in _iterator:
+ # find end of batch
+ i_end = min(i + batch_size, len(contexts))
+ # extract batch
+ context_batch = contexts[i:i_end]
+ batch_ids = ids[i:i_end]
+ metadata_batch = (
+ metadatas[i:i_end] if metadatas else [{} for _ in context_batch]
+ )
+ # add context passages as metadata
+ meta = [
+ {"context": context, **metadata}
+ for context, metadata in zip(context_batch, metadata_batch)
+ ]
+
+ # create dense vectors
+ dense_embeds = embeddings.embed_documents(context_batch)
+ # create sparse vectors
+ sparse_embeds = sparse_encoder.encode_documents(context_batch)
+ for s in sparse_embeds:
+ s["values"] = [float(s1) for s1 in s["values"]]
+
+ vectors = []
+ # loop through the data and create dictionaries for upserts
+ for doc_id, sparse, dense, metadata in zip(
+ batch_ids, sparse_embeds, dense_embeds, meta
+ ):
+ vectors.append(
+ {
+ "id": doc_id,
+ "sparse_values": sparse,
+ "values": dense,
+ "metadata": metadata,
+ }
+ )
+
+ # upload the documents to the new hybrid index
+ index.upsert(vectors, namespace=namespace)
+
+
+class PineconeHybridSearchRetriever(BaseRetriever):
+ """`Pinecone Hybrid Search` retriever."""
+
+ embeddings: Embeddings
+ """Embeddings model to use."""
+ """description"""
+ sparse_encoder: Any
+ """Sparse encoder to use."""
+ index: Any
+ """Pinecone index to use."""
+ top_k: int = 4
+ """Number of documents to return."""
+ alpha: float = 0.5
+ """Alpha value for hybrid search."""
+ namespace: Optional[str] = None
+ """Namespace value for index partition."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ def add_texts(
+ self,
+ texts: List[str],
+ ids: Optional[List[str]] = None,
+ metadatas: Optional[List[dict]] = None,
+ namespace: Optional[str] = None,
+ ) -> None:
+ create_index(
+ texts,
+ self.index,
+ self.embeddings,
+ self.sparse_encoder,
+ ids=ids,
+ metadatas=metadatas,
+ namespace=namespace,
+ )
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ try:
+ from pinecone_text.hybrid import hybrid_convex_scale # noqa:F401
+ from pinecone_text.sparse.base_sparse_encoder import (
+ BaseSparseEncoder, # noqa:F401
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import pinecone_text python package. "
+ "Please install it with `pip install pinecone_text`."
+ )
+ return values
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ from pinecone_text.hybrid import hybrid_convex_scale
+
+ sparse_vec = self.sparse_encoder.encode_queries(query)
+ # convert the question into a dense vector
+ dense_vec = self.embeddings.embed_query(query)
+ # scale alpha with hybrid_scale
+ dense_vec, sparse_vec = hybrid_convex_scale(dense_vec, sparse_vec, self.alpha)
+ sparse_vec["values"] = [float(s1) for s1 in sparse_vec["values"]]
+ # query pinecone with the query parameters
+ result = self.index.query(
+ vector=dense_vec,
+ sparse_vector=sparse_vec,
+ top_k=self.top_k,
+ include_metadata=True,
+ namespace=self.namespace,
+ )
+ final_result = []
+ for res in result["matches"]:
+ context = res["metadata"].pop("context")
+ final_result.append(
+ Document(page_content=context, metadata=res["metadata"])
+ )
+        # return the search results as Documents
+ return final_result
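
As a quick sanity check on how the pieces above fit together, here is a minimal usage sketch. It assumes a pre-existing Pinecone index created with the dotproduct metric, plus an embeddings model and a pinecone_text sparse encoder; the credentials and index name are placeholders:

    import pinecone  # pinecone-client v2 style initialization (assumed)
    from pinecone_text.sparse import BM25Encoder

    from langchain_community.embeddings import OpenAIEmbeddings  # any Embeddings works
    from langchain_community.retrievers.pinecone_hybrid_search import (
        PineconeHybridSearchRetriever,
    )

    pinecone.init(api_key="...", environment="...")  # placeholder credentials
    index = pinecone.Index("langchain-hybrid-demo")  # assumed dotproduct index

    retriever = PineconeHybridSearchRetriever(
        embeddings=OpenAIEmbeddings(),
        sparse_encoder=BM25Encoder.default(),  # BM25 weights fit on MS MARCO
        index=index,
        top_k=4,
        alpha=0.5,  # weight of the dense score in the convex combination
    )
    retriever.add_texts(["foo", "bar", "world hello"])
    docs = retriever.get_relevant_documents("foo")
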
diff --git a/libs/community/langchain_community/retrievers/pubmed.py b/libs/community/langchain_community/retrievers/pubmed.py
new file mode 100644
index 00000000000..d68e85b80b0
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/pubmed.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.utilities.pubmed import PubMedAPIWrapper
+
+
+class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
+ """`PubMed API` retriever.
+
+    It wraps load_docs() as get_relevant_documents().
+    It uses all PubMedAPIWrapper arguments without any change.
+ """
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ return self.load_docs(query=query)
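
A minimal usage sketch for the retriever above, assuming the PubMedAPIWrapper's dependencies are installed; since the retriever inherits the wrapper's fields, it can typically be constructed with defaults:

    from langchain_community.retrievers.pubmed import PubMedRetriever

    retriever = PubMedRetriever()
    docs = retriever.get_relevant_documents("covid vaccination")
    print(docs[0].page_content[:200])
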
diff --git a/libs/community/langchain_community/retrievers/pupmed.py b/libs/community/langchain_community/retrievers/pupmed.py
new file mode 100644
index 00000000000..b4318034b27
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/pupmed.py
@@ -0,0 +1,5 @@
+# Kept for backwards compatibility with the misspelled "pupmed" module name.
+from langchain_community.retrievers.pubmed import PubMedRetriever
+
+__all__ = [
+ "PubMedRetriever",
+]
diff --git a/libs/community/langchain_community/retrievers/remote_retriever.py b/libs/community/langchain_community/retrievers/remote_retriever.py
new file mode 100644
index 00000000000..f384385557b
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/remote_retriever.py
@@ -0,0 +1,56 @@
+from typing import List, Optional
+
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class RemoteLangChainRetriever(BaseRetriever):
+ """`LangChain API` retriever."""
+
+ url: str
+ """URL of the remote LangChain API."""
+ headers: Optional[dict] = None
+ """Headers to use for the request."""
+ input_key: str = "message"
+ """Key to use for the input in the request."""
+ response_key: str = "response"
+ """Key to use for the response in the request."""
+ page_content_key: str = "page_content"
+ """Key to use for the page content in the response."""
+ metadata_key: str = "metadata"
+ """Key to use for the metadata in the response."""
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ response = requests.post(
+ self.url, json={self.input_key: query}, headers=self.headers
+ )
+ result = response.json()
+ return [
+ Document(
+ page_content=r[self.page_content_key], metadata=r[self.metadata_key]
+ )
+ for r in result[self.response_key]
+ ]
+
+ async def _aget_relevant_documents(
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ async with aiohttp.ClientSession() as session:
+ async with session.request(
+ "POST", self.url, headers=self.headers, json={self.input_key: query}
+ ) as response:
+ result = await response.json()
+ return [
+ Document(
+ page_content=r[self.page_content_key], metadata=r[self.metadata_key]
+ )
+ for r in result[self.response_key]
+ ]
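
A usage sketch against a hypothetical endpoint that accepts {"message": ...} and returns {"response": [{"page_content": ..., "metadata": {...}}, ...]}; the URL and header values are placeholders:

    from langchain_community.retrievers.remote_retriever import RemoteLangChainRetriever

    retriever = RemoteLangChainRetriever(
        url="http://localhost:8000/retrieve",     # hypothetical endpoint
        headers={"Authorization": "Bearer ..."},  # optional
    )
    docs = retriever.get_relevant_documents("What is LangChain?")
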
diff --git a/libs/community/langchain_community/retrievers/svm.py b/libs/community/langchain_community/retrievers/svm.py
new file mode 100644
index 00000000000..96753ecac19
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/svm.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+import concurrent.futures
+from typing import Any, Iterable, List, Optional
+
+import numpy as np
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.retrievers import BaseRetriever
+
+
+def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
+ """
+ Create an index of embeddings for a list of contexts.
+
+ Args:
+ contexts: List of contexts to embed.
+ embeddings: Embeddings model to use.
+
+ Returns:
+ Index of embeddings.
+ """
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ return np.array(list(executor.map(embeddings.embed_query, contexts)))
+
+
+class SVMRetriever(BaseRetriever):
+ """`SVM` retriever.
+
+ Largely based on
+ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
+ """
+
+ embeddings: Embeddings
+ """Embeddings model to use."""
+ index: Any
+ """Index of embeddings."""
+ texts: List[str]
+ """List of texts to index."""
+ metadatas: Optional[List[dict]] = None
+ """List of metadatas corresponding with each text."""
+ k: int = 4
+ """Number of results to return."""
+ relevancy_threshold: Optional[float] = None
+ """Threshold for relevancy."""
+
+ class Config:
+        """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embeddings: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> SVMRetriever:
+ index = create_index(texts, embeddings)
+ return cls(
+ embeddings=embeddings,
+ index=index,
+ texts=texts,
+ metadatas=metadatas,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: Iterable[Document],
+ embeddings: Embeddings,
+ **kwargs: Any,
+ ) -> SVMRetriever:
+ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
+ return cls.from_texts(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
+ )
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ try:
+ from sklearn import svm
+ except ImportError:
+ raise ImportError(
+ "Could not import scikit-learn, please install with `pip install "
+ "scikit-learn`."
+ )
+
+ query_embeds = np.array(self.embeddings.embed_query(query))
+ x = np.concatenate([query_embeds[None, ...], self.index])
+ y = np.zeros(x.shape[0])
+ y[0] = 1
+
+ clf = svm.LinearSVC(
+ class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
+ )
+ clf.fit(x, y)
+
+ similarities = clf.decision_function(x)
+ sorted_ix = np.argsort(-similarities)
+
+        # svm.LinearSVC in scikit-learn is non-deterministic.
+        # If a text is identical to the query, there is no guarantee
+        # the query ends up at the first index, so perform a simple swap.
+        # This is safe because anything left of the query's position
+        # is equivalent to it.
+ zero_index = np.where(sorted_ix == 0)[0][0]
+ if zero_index != 0:
+ sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0]
+
+ denominator = np.max(similarities) - np.min(similarities) + 1e-6
+ normalized_similarities = (similarities - np.min(similarities)) / denominator
+
+ top_k_results = []
+ for row in sorted_ix[1 : self.k + 1]:
+ if (
+ self.relevancy_threshold is None
+ or normalized_similarities[row] >= self.relevancy_threshold
+ ):
+ metadata = self.metadatas[row - 1] if self.metadatas else {}
+ doc = Document(page_content=self.texts[row - 1], metadata=metadata)
+ top_k_results.append(doc)
+ return top_k_results
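
A minimal usage sketch, assuming scikit-learn is installed; FakeEmbeddings stands in for a real embeddings model to keep the snippet self-contained:

    from langchain_community.embeddings import FakeEmbeddings  # any Embeddings implementation works
    from langchain_community.retrievers.svm import SVMRetriever

    retriever = SVMRetriever.from_texts(
        ["foo", "bar", "world hello foo"],
        FakeEmbeddings(size=128),  # swap in a real embeddings model in practice
        k=2,
    )
    docs = retriever.get_relevant_documents("foo")
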
diff --git a/libs/community/langchain_community/retrievers/tavily_search_api.py b/libs/community/langchain_community/retrievers/tavily_search_api.py
new file mode 100644
index 00000000000..317b5774ed0
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/tavily_search_api.py
@@ -0,0 +1,84 @@
+import os
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class SearchDepth(Enum):
+ """Search depth as enumerator."""
+
+ BASIC = "basic"
+ ADVANCED = "advanced"
+
+
+class TavilySearchAPIRetriever(BaseRetriever):
+ """Tavily Search API retriever."""
+
+ k: int = 10
+ include_generated_answer: bool = False
+ include_raw_content: bool = False
+ include_images: bool = False
+ search_depth: SearchDepth = SearchDepth.BASIC
+ include_domains: Optional[List[str]] = None
+ exclude_domains: Optional[List[str]] = None
+ kwargs: Optional[Dict[str, Any]] = {}
+ api_key: Optional[str] = None
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ try:
+ from tavily import Client
+ except ImportError:
+ raise ImportError(
+ "Tavily python package not found. "
+ "Please install it with `pip install tavily-python`."
+ )
+
+ tavily = Client(api_key=self.api_key or os.environ["TAVILY_API_KEY"])
+ max_results = self.k if not self.include_generated_answer else self.k - 1
+ response = tavily.search(
+ query=query,
+ max_results=max_results,
+ search_depth=self.search_depth.value,
+ include_answer=self.include_generated_answer,
+ include_domains=self.include_domains,
+ exclude_domains=self.exclude_domains,
+ include_raw_content=self.include_raw_content,
+ include_images=self.include_images,
+ **self.kwargs,
+ )
+ docs = [
+ Document(
+ page_content=result.get("content", "")
+ if not self.include_raw_content
+ else result.get("raw_content", ""),
+ metadata={
+ "title": result.get("title", ""),
+ "source": result.get("url", ""),
+ **{
+ k: v
+ for k, v in result.items()
+ if k not in ("content", "title", "url", "raw_content")
+ },
+ "images": response.get("images"),
+ },
+ )
+ for result in response.get("results")
+ ]
+ if self.include_generated_answer:
+ docs = [
+ Document(
+ page_content=response.get("answer", ""),
+ metadata={
+ "title": "Suggested Answer",
+ "source": "https://tavily.com/",
+ },
+ ),
+ *docs,
+ ]
+
+ return docs
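
A usage sketch for the retriever above; the API key is a placeholder and must come from Tavily:

    import os

    from langchain_community.retrievers.tavily_search_api import (
        SearchDepth,
        TavilySearchAPIRetriever,
    )

    os.environ["TAVILY_API_KEY"] = "..."  # placeholder; or pass api_key= directly

    retriever = TavilySearchAPIRetriever(k=3, search_depth=SearchDepth.ADVANCED)
    docs = retriever.get_relevant_documents("what year was breath of the wild released?")
    for doc in docs:
        print(doc.metadata["title"], "-", doc.metadata["source"])
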
diff --git a/libs/community/langchain_community/retrievers/tfidf.py b/libs/community/langchain_community/retrievers/tfidf.py
new file mode 100644
index 00000000000..9a033265b34
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/tfidf.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+import pickle
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class TFIDFRetriever(BaseRetriever):
+ """`TF-IDF` retriever.
+
+ Largely based on
+ https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
+ """
+
+ vectorizer: Any
+ """TF-IDF vectorizer."""
+ docs: List[Document]
+ """Documents."""
+ tfidf_array: Any
+ """TF-IDF array."""
+ k: int = 4
+ """Number of documents to return."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: Iterable[str],
+ metadatas: Optional[Iterable[dict]] = None,
+ tfidf_params: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> TFIDFRetriever:
+ try:
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ except ImportError:
+ raise ImportError(
+ "Could not import scikit-learn, please install with `pip install "
+ "scikit-learn`."
+ )
+
+ tfidf_params = tfidf_params or {}
+ vectorizer = TfidfVectorizer(**tfidf_params)
+ tfidf_array = vectorizer.fit_transform(texts)
+ metadatas = metadatas or ({} for _ in texts)
+ docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
+ return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: Iterable[Document],
+ *,
+ tfidf_params: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> TFIDFRetriever:
+ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
+ return cls.from_texts(
+ texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
+ )
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ from sklearn.metrics.pairwise import cosine_similarity
+
+        query_vec = self.vectorizer.transform(
+            [query]
+        )  # sparse matrix of shape (1, n_features)
+        results = cosine_similarity(self.tfidf_array, query_vec).reshape(
+            (-1,)
+        )  # cosine similarity with each doc, shape (n_docs,)
+ return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
+ return return_docs
+
+ def save_local(
+ self,
+ folder_path: str,
+ file_name: str = "tfidf_vectorizer",
+ ) -> None:
+ try:
+ import joblib
+ except ImportError:
+ raise ImportError(
+ "Could not import joblib, please install with `pip install joblib`."
+ )
+
+ path = Path(folder_path)
+ path.mkdir(exist_ok=True, parents=True)
+
+ # Save vectorizer with joblib dump.
+ joblib.dump(self.vectorizer, path / f"{file_name}.joblib")
+
+ # Save docs and tfidf array as pickle.
+ with open(path / f"{file_name}.pkl", "wb") as f:
+ pickle.dump((self.docs, self.tfidf_array), f)
+
+ @classmethod
+ def load_local(
+ cls,
+ folder_path: str,
+ file_name: str = "tfidf_vectorizer",
+ ) -> TFIDFRetriever:
+ try:
+ import joblib
+ except ImportError:
+ raise ImportError(
+ "Could not import joblib, please install with `pip install joblib`."
+ )
+
+ path = Path(folder_path)
+
+ # Load vectorizer with joblib load.
+ vectorizer = joblib.load(path / f"{file_name}.joblib")
+
+ # Load docs and tfidf array as pickle.
+ with open(path / f"{file_name}.pkl", "rb") as f:
+ docs, tfidf_array = pickle.load(f)
+
+ return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)
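
A minimal usage sketch covering construction, retrieval, and persistence; scikit-learn and joblib are assumed to be installed, and the folder name is a placeholder:

    from langchain_community.retrievers.tfidf import TFIDFRetriever

    retriever = TFIDFRetriever.from_texts(
        ["foo", "bar", "world hello foo bar"],
        tfidf_params={"ngram_range": (1, 2)},
        k=2,
    )
    docs = retriever.get_relevant_documents("foo")

    # Persist the fitted vectorizer, docs and TF-IDF array, then reload them.
    retriever.save_local("tfidf_index")
    restored = TFIDFRetriever.load_local("tfidf_index")
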
diff --git a/libs/community/langchain_community/retrievers/vespa_retriever.py b/libs/community/langchain_community/retrievers/vespa_retriever.py
new file mode 100644
index 00000000000..6f5eb66aa0e
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/vespa_retriever.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List, Literal, Optional, Sequence, Union
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class VespaRetriever(BaseRetriever):
+ """`Vespa` retriever."""
+
+ app: Any
+ """Vespa application to query."""
+ body: Dict
+ """Body of the query."""
+ content_field: str
+ """Name of the content field."""
+ metadata_fields: Sequence[str]
+ """Names of the metadata fields."""
+
+ def _query(self, body: Dict) -> List[Document]:
+ response = self.app.query(body)
+
+ if not str(response.status_code).startswith("2"):
+ raise RuntimeError(
+ "Could not retrieve data from Vespa. Error code: {}".format(
+ response.status_code
+ )
+ )
+
+ root = response.json["root"]
+ if "errors" in root:
+ raise RuntimeError(json.dumps(root["errors"]))
+
+ docs = []
+ for child in response.hits:
+ page_content = child["fields"].pop(self.content_field, "")
+ if self.metadata_fields == "*":
+ metadata = child["fields"]
+ else:
+ metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
+ metadata["id"] = child["id"]
+ docs.append(Document(page_content=page_content, metadata=metadata))
+ return docs
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ body = self.body.copy()
+ body["query"] = query
+ return self._query(body)
+
+ def get_relevant_documents_with_filter(
+ self, query: str, *, _filter: Optional[str] = None
+ ) -> List[Document]:
+ body = self.body.copy()
+ _filter = f" and {_filter}" if _filter else ""
+ body["yql"] = body["yql"] + _filter
+ body["query"] = query
+ return self._query(body)
+
+ @classmethod
+ def from_params(
+ cls,
+ url: str,
+ content_field: str,
+ *,
+ k: Optional[int] = None,
+ metadata_fields: Union[Sequence[str], Literal["*"]] = (),
+ sources: Union[Sequence[str], Literal["*"], None] = None,
+ _filter: Optional[str] = None,
+ yql: Optional[str] = None,
+ **kwargs: Any,
+ ) -> VespaRetriever:
+ """Instantiate retriever from params.
+
+ Args:
+ url (str): Vespa app URL.
+ content_field (str): Field in results to return as Document page_content.
+ k (Optional[int]): Number of Documents to return. Defaults to None.
+            metadata_fields (Sequence[str] or "*"): Fields in results to include in
+                document metadata. Defaults to empty tuple ().
+ sources (Sequence[str] or "*" or None): Sources to retrieve
+ from. Defaults to None.
+ _filter (Optional[str]): Document filter condition expressed in YQL.
+ Defaults to None.
+ yql (Optional[str]): Full YQL query to be used. Should not be specified
+ if _filter or sources are specified. Defaults to None.
+ kwargs (Any): Keyword arguments added to query body.
+
+ Returns:
+ VespaRetriever: Instantiated VespaRetriever.
+ """
+ try:
+ from vespa.application import Vespa
+ except ImportError:
+ raise ImportError(
+ "pyvespa is not installed, please install with `pip install pyvespa`"
+ )
+ app = Vespa(url)
+ body = kwargs.copy()
+ if yql and (sources or _filter):
+            raise ValueError(
+                "yql should not be specified together with sources or _filter."
+            )
+ else:
+ if metadata_fields == "*":
+ _fields = "*"
+ body["summary"] = "short"
+ else:
+ _fields = ", ".join([content_field] + list(metadata_fields or []))
+ _sources = ", ".join(sources) if isinstance(sources, Sequence) else "*"
+ _filter = f" and {_filter}" if _filter else ""
+ yql = f"select {_fields} from sources {_sources} where userQuery(){_filter}"
+ body["yql"] = yql
+ if k:
+ body["hits"] = k
+ return cls(
+ app=app,
+ body=body,
+ content_field=content_field,
+ metadata_fields=metadata_fields,
+ )
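
A usage sketch for from_params; the app URL and field names below are placeholders for whatever your Vespa application exposes:

    from langchain_community.retrievers.vespa_retriever import VespaRetriever

    retriever = VespaRetriever.from_params(
        "https://doc-search.vespa.oath.cloud",  # assumed Vespa app URL
        "content",                              # assumed content field
        k=4,
        metadata_fields=["path", "title"],      # assumed metadata fields
    )
    docs = retriever.get_relevant_documents("what is vespa?")
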
diff --git a/libs/community/langchain_community/retrievers/weaviate_hybrid_search.py b/libs/community/langchain_community/retrievers/weaviate_hybrid_search.py
new file mode 100644
index 00000000000..b216e27b65a
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/weaviate_hybrid_search.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, cast
+from uuid import uuid4
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+
+
+class WeaviateHybridSearchRetriever(BaseRetriever):
+ """`Weaviate hybrid search` retriever.
+
+ See the documentation:
+ https://weaviate.io/blog/hybrid-search-explained
+ """
+
+ client: Any
+    """The Weaviate client to use."""
+ index_name: str
+ """The name of the index to use."""
+ text_key: str
+ """The name of the text key to use."""
+ alpha: float = 0.5
+ """The weight of the text key in the hybrid search."""
+ k: int = 4
+ """The number of results to return."""
+ attributes: List[str]
+ """The attributes to return in the results."""
+ create_schema_if_missing: bool = True
+ """Whether to create the schema if it doesn't exist."""
+
+ @root_validator(pre=True)
+ def validate_client(
+ cls,
+ values: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ try:
+ import weaviate
+ except ImportError:
+ raise ImportError(
+ "Could not import weaviate python package. "
+ "Please install it with `pip install weaviate-client`."
+ )
+ if not isinstance(values["client"], weaviate.Client):
+ client = values["client"]
+ raise ValueError(
+ f"client should be an instance of weaviate.Client, got {type(client)}"
+ )
+ if values.get("attributes") is None:
+ values["attributes"] = []
+
+ cast(List, values["attributes"]).append(values["text_key"])
+
+ if values.get("create_schema_if_missing", True):
+ class_obj = {
+ "class": values["index_name"],
+ "properties": [{"name": values["text_key"], "dataType": ["text"]}],
+ "vectorizer": "text2vec-openai",
+ }
+
+ if not values["client"].schema.exists(values["index_name"]):
+ values["client"].schema.create_class(class_obj)
+
+ return values
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ # added text_key
+ def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
+ """Upload documents to Weaviate."""
+ from weaviate.util import get_valid_uuid
+
+ with self.client.batch as batch:
+ ids = []
+ for i, doc in enumerate(docs):
+ metadata = doc.metadata or {}
+ data_properties = {self.text_key: doc.page_content, **metadata}
+
+                # If the UUID of one of the objects already exists,
+                # the existing object will be replaced by the new object.
+ if "uuids" in kwargs:
+ _id = kwargs["uuids"][i]
+ else:
+ _id = get_valid_uuid(uuid4())
+
+ batch.add_data_object(data_properties, self.index_name, _id)
+ ids.append(_id)
+ return ids
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ where_filter: Optional[Dict[str, object]] = None,
+ score: bool = False,
+ hybrid_search_kwargs: Optional[Dict[str, object]] = None,
+ ) -> List[Document]:
+ """Look up similar documents in Weaviate.
+
+        query: The query to use to search for relevant documents
+        with Weaviate hybrid search.
+
+ where_filter: A filter to apply to the query.
+ https://weaviate.io/developers/weaviate/guides/querying/#filtering
+
+ score: Whether to include the score, and score explanation
+ in the returned Documents meta_data.
+
+ hybrid_search_kwargs: Used to pass additional arguments
+ to the .with_hybrid() method.
+            The primary use cases for this are:
+            1) Search specific properties only -
+            specify which properties are used during the hybrid search portion.
+            Note: this is not the same as the attributes (self.attributes) returned.
+ Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
+ https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only
+
+ 2) Weight boosted searched properties -
+ Boost the weight of certain properties during the hybrid search portion.
+ Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
+ https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties
+
+ 3) Search with a custom vector - Define a different vector
+ to be used during the hybrid search portion.
+ Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
+ https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector
+
+            4) Use a fusion ranking method
+            Example - from weaviate.gql.get import HybridFusion
+            hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
+ https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method
+ """
+ query_obj = self.client.query.get(self.index_name, self.attributes)
+ if where_filter:
+ query_obj = query_obj.with_where(where_filter)
+
+ if score:
+ query_obj = query_obj.with_additional(["score", "explainScore"])
+
+ if hybrid_search_kwargs is None:
+ hybrid_search_kwargs = {}
+
+ result = (
+ query_obj.with_hybrid(query, alpha=self.alpha, **hybrid_search_kwargs)
+ .with_limit(self.k)
+ .do()
+ )
+ if "errors" in result:
+ raise ValueError(f"Error during query: {result['errors']}")
+
+ docs = []
+
+ for res in result["data"]["Get"][self.index_name]:
+ text = res.pop(self.text_key)
+ docs.append(Document(page_content=text, metadata=res))
+ return docs
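
A minimal usage sketch, assuming a running Weaviate instance. Note that the schema auto-created by the validator uses the text2vec-openai vectorizer, so that module must be enabled in Weaviate (or pass create_schema_if_missing=False and manage the schema yourself):

    import weaviate
    from langchain_core.documents import Document

    from langchain_community.retrievers.weaviate_hybrid_search import (
        WeaviateHybridSearchRetriever,
    )

    client = weaviate.Client("http://localhost:8080")  # assumed local Weaviate instance

    retriever = WeaviateHybridSearchRetriever(
        client=client,
        index_name="LangChain",
        text_key="text",
        attributes=[],  # text_key is appended automatically by the validator
        alpha=0.5,
        k=4,
    )
    retriever.add_documents([Document(page_content="foo", metadata={"topic": "demo"})])
    docs = retriever.get_relevant_documents("foo")
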
diff --git a/libs/community/langchain_community/retrievers/wikipedia.py b/libs/community/langchain_community/retrievers/wikipedia.py
new file mode 100644
index 00000000000..cd7e38180ce
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/wikipedia.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+
+
+class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
+ """`Wikipedia API` retriever.
+
+    It wraps load() as get_relevant_documents().
+    It uses all WikipediaAPIWrapper arguments without any change.
+ """
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ return self.load(query=query)
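
A minimal usage sketch, assuming the wikipedia package used by WikipediaAPIWrapper is installed:

    from langchain_community.retrievers.wikipedia import WikipediaRetriever

    retriever = WikipediaRetriever()
    docs = retriever.get_relevant_documents("large language model")
    print(docs[0].metadata.get("title"), docs[0].page_content[:200])
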
diff --git a/libs/community/langchain_community/retrievers/you.py b/libs/community/langchain_community/retrievers/you.py
new file mode 100644
index 00000000000..b65f2aad78f
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/you.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import get_from_dict_or_env
+
+
+class YouRetriever(BaseRetriever):
+ """`You` retriever that uses You.com's search API.
+
+    Connecting to the You.com API requires an API key, which
+    you can get by emailing api@you.com.
+    The docs are available at https://documentation.you.com.
+
+    Set the environment variable `YDC_API_KEY` for the retriever to operate.
+ """
+
+ ydc_api_key: str
+ k: Optional[int] = None
+ n_hits: Optional[int] = None
+ n_snippets_per_hit: Optional[int] = None
+ endpoint_type: str = "web"
+
+ @root_validator(pre=True)
+ def validate_client(
+ cls,
+ values: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ values["ydc_api_key"] = get_from_dict_or_env(
+ values, "ydc_api_key", "YDC_API_KEY"
+ )
+ return values
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ import requests
+
+ headers = {"X-API-Key": self.ydc_api_key}
+ if self.endpoint_type == "web":
+ results = requests.get(
+ f"https://api.ydc-index.io/search?query={query}",
+ headers=headers,
+ ).json()
+
+ docs = []
+ n_hits = self.n_hits or len(results["hits"])
+ for hit in results["hits"][:n_hits]:
+ n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
+ for snippet in hit["snippets"][:n_snippets_per_hit]:
+ docs.append(Document(page_content=snippet))
+ if self.k is not None and len(docs) >= self.k:
+ return docs
+ return docs
+ elif self.endpoint_type == "snippet":
+ results = requests.get(
+ f"https://api.ydc-index.io/snippet_search?query={query}",
+ headers=headers,
+ ).json()
+ return [Document(page_content=snippet) for snippet in results]
+ else:
+            raise RuntimeError(f"Invalid endpoint type provided: {self.endpoint_type}")
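
A minimal usage sketch; the API key is a placeholder and must be issued by You.com:

    import os

    from langchain_community.retrievers.you import YouRetriever

    os.environ["YDC_API_KEY"] = "..."  # placeholder

    retriever = YouRetriever(k=5)  # default "web" endpoint
    docs = retriever.get_relevant_documents("LangChain retrievers")
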
diff --git a/libs/community/langchain_community/retrievers/zep.py b/libs/community/langchain_community/retrievers/zep.py
new file mode 100644
index 00000000000..2060788e4cd
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/zep.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+
+if TYPE_CHECKING:
+ from zep_python.memory import MemorySearchResult
+
+
+class SearchScope(str, Enum):
+ """Which documents to search. Messages or Summaries?"""
+
+ messages = "messages"
+ """Search chat history messages."""
+ summary = "summary"
+ """Search chat history summaries."""
+
+
+class SearchType(str, Enum):
+ """Enumerator of the types of search to perform."""
+
+ similarity = "similarity"
+ """Similarity search."""
+ mmr = "mmr"
+ """Maximal Marginal Relevance reranking of similarity search."""
+
+
+class ZepRetriever(BaseRetriever):
+ """`Zep` MemoryStore Retriever.
+
+ Search your user's long-term chat history with Zep.
+
+ Zep offers both simple semantic search and Maximal Marginal Relevance (MMR)
+ reranking of search results.
+
+ Note: You will need to provide the user's `session_id` to use this retriever.
+
+ Args:
+ url: URL of your Zep server (required)
+ api_key: Your Zep API key (optional)
+ session_id: Identifies your user or a user's session (required)
+ top_k: Number of documents to return (default: 3, optional)
+ search_type: Type of search to perform (similarity / mmr) (default: similarity,
+ optional)
+ mmr_lambda: Lambda value for MMR search. Defaults to 0.5 (optional)
+
+ Zep - Fast, scalable building blocks for LLM Apps
+ =========
+ Zep is an open source platform for productionizing LLM apps. Go from a prototype
+ built in LangChain or LlamaIndex, or a custom app, to production in minutes without
+ rewriting code.
+
+ For server installation instructions, see:
+ https://docs.getzep.com/deployment/quickstart/
+ """
+
+ zep_client: Optional[Any] = None
+ """Zep client."""
+ url: str
+ """URL of your Zep server."""
+ api_key: Optional[str] = None
+ """Your Zep API key."""
+ session_id: str
+ """Zep session ID."""
+ top_k: Optional[int]
+ """Number of items to return."""
+ search_scope: SearchScope = SearchScope.messages
+ """Which documents to search. Messages or Summaries?"""
+ search_type: SearchType = SearchType.similarity
+ """Type of search to perform (similarity / mmr)"""
+ mmr_lambda: Optional[float] = None
+ """Lambda value for MMR search."""
+
+ @root_validator(pre=True)
+ def create_client(cls, values: dict) -> dict:
+ try:
+ from zep_python import ZepClient
+ except ImportError:
+ raise ImportError(
+ "Could not import zep-python package. "
+ "Please install it with `pip install zep-python`."
+ )
+ values["zep_client"] = values.get(
+ "zep_client",
+ ZepClient(base_url=values["url"], api_key=values.get("api_key")),
+ )
+ return values
+
+ def _messages_search_result_to_doc(
+ self, results: List[MemorySearchResult]
+ ) -> List[Document]:
+ return [
+ Document(
+ page_content=r.message.pop("content"),
+ metadata={"score": r.dist, **r.message},
+ )
+ for r in results
+ if r.message
+ ]
+
+ def _summary_search_result_to_doc(
+ self, results: List[MemorySearchResult]
+ ) -> List[Document]:
+ return [
+ Document(
+ page_content=r.summary.content,
+ metadata={
+ "score": r.dist,
+ "uuid": r.summary.uuid,
+ "created_at": r.summary.created_at,
+ "token_count": r.summary.token_count,
+ },
+ )
+ for r in results
+ if r.summary
+ ]
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> List[Document]:
+ from zep_python.memory import MemorySearchPayload
+
+ if not self.zep_client:
+ raise RuntimeError("Zep client not initialized.")
+
+ payload = MemorySearchPayload(
+ text=query,
+ metadata=metadata,
+ search_scope=self.search_scope,
+ search_type=self.search_type,
+ mmr_lambda=self.mmr_lambda,
+ )
+
+ results: List[MemorySearchResult] = self.zep_client.memory.search_memory(
+ self.session_id, payload, limit=self.top_k
+ )
+
+ if self.search_scope == SearchScope.summary:
+ return self._summary_search_result_to_doc(results)
+
+ return self._messages_search_result_to_doc(results)
+
+ async def _aget_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: AsyncCallbackManagerForRetrieverRun,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> List[Document]:
+ from zep_python.memory import MemorySearchPayload
+
+ if not self.zep_client:
+ raise RuntimeError("Zep client not initialized.")
+
+ payload = MemorySearchPayload(
+ text=query,
+ metadata=metadata,
+ search_scope=self.search_scope,
+ search_type=self.search_type,
+ mmr_lambda=self.mmr_lambda,
+ )
+
+ results: List[MemorySearchResult] = await self.zep_client.memory.asearch_memory(
+ self.session_id, payload, limit=self.top_k
+ )
+
+ if self.search_scope == SearchScope.summary:
+ return self._summary_search_result_to_doc(results)
+
+ return self._messages_search_result_to_doc(results)
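
A usage sketch, assuming a running Zep server and an existing chat session; the URL, key, and session id are placeholders:

    from langchain_community.retrievers.zep import SearchScope, SearchType, ZepRetriever

    retriever = ZepRetriever(
        url="http://localhost:8000",    # placeholder Zep server URL
        api_key="...",                  # optional
        session_id="user-123-session",  # the chat session to search
        top_k=5,
        search_scope=SearchScope.messages,
        search_type=SearchType.mmr,
        mmr_lambda=0.5,
    )
    docs = retriever.get_relevant_documents("travel plans to Paris")
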
diff --git a/libs/community/langchain_community/retrievers/zilliz.py b/libs/community/langchain_community/retrievers/zilliz.py
new file mode 100644
index 00000000000..43cd2c4e8cd
--- /dev/null
+++ b/libs/community/langchain_community/retrievers/zilliz.py
@@ -0,0 +1,86 @@
+import warnings
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+
+from langchain_community.vectorstores.zilliz import Zilliz
+
+# TODO: Update to ZillizClient + Hybrid Search when available
+
+
+class ZillizRetriever(BaseRetriever):
+ """`Zilliz API` retriever."""
+
+ embedding_function: Embeddings
+ """The underlying embedding function from which documents will be retrieved."""
+ collection_name: str = "LangChainCollection"
+ """The name of the collection in Zilliz."""
+ connection_args: Optional[Dict[str, Any]] = None
+ """The connection arguments for the Zilliz client."""
+ consistency_level: str = "Session"
+ """The consistency level for the Zilliz client."""
+ search_params: Optional[dict] = None
+ """The search parameters for the Zilliz client."""
+ store: Zilliz
+ """The underlying Zilliz store."""
+ retriever: BaseRetriever
+ """The underlying retriever."""
+
+ @root_validator(pre=True)
+ def create_client(cls, values: dict) -> dict:
+ values["store"] = Zilliz(
+ values["embedding_function"],
+ values["collection_name"],
+ values["connection_args"],
+ values["consistency_level"],
+ )
+ values["retriever"] = values["store"].as_retriever(
+ search_kwargs={"param": values["search_params"]}
+ )
+ return values
+
+ def add_texts(
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
+ ) -> None:
+        """Add texts to the Zilliz store.
+
+        Args:
+            texts (List[str]): The texts to add.
+            metadatas (List[dict]): Metadata dicts; must line up with the texts.
+        """
+ self.store.add_texts(texts, metadatas)
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ return self.retriever.get_relevant_documents(
+ query, run_manager=run_manager.get_child(), **kwargs
+ )
+
+
+def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
+ """Deprecated ZillizRetreiver.
+
+ Please use ZillizRetriever ('i' before 'e') instead.
+
+ Args:
+ *args:
+ **kwargs:
+
+ Returns:
+ ZillizRetriever
+ """
+ warnings.warn(
+ "ZillizRetreiver will be deprecated in the future. "
+ "Please use ZillizRetriever ('i' before 'e') instead.",
+ DeprecationWarning,
+ )
+ return ZillizRetriever(*args, **kwargs)
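
A usage sketch for ZillizRetriever; the connection_args keys below are assumptions for a Zilliz Cloud deployment, and FakeEmbeddings stands in for a real embeddings model:

    from langchain_community.embeddings import FakeEmbeddings  # any Embeddings works
    from langchain_community.retrievers.zilliz import ZillizRetriever

    retriever = ZillizRetriever(
        embedding_function=FakeEmbeddings(size=128),
        collection_name="LangChainCollection",
        connection_args={"uri": "https://<instance>.zillizcloud.com", "token": "..."},  # assumed
        consistency_level="Session",
        search_params=None,
    )
    retriever.add_texts(["foo", "bar"])
    docs = retriever.get_relevant_documents("foo")
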
diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py
new file mode 100644
index 00000000000..7af3d5b3000
--- /dev/null
+++ b/libs/community/langchain_community/storage/__init__.py
@@ -0,0 +1,19 @@
+"""Implementations of key-value stores and storage helpers.
+
+This module provides implementations of various key-value stores that conform
+to a simple key-value interface.
+
+The primary goal of these stores is to support caching.
+"""
+
+from langchain_community.storage.redis import RedisStore
+from langchain_community.storage.upstash_redis import (
+ UpstashRedisByteStore,
+ UpstashRedisStore,
+)
+
+__all__ = [
+ "RedisStore",
+ "UpstashRedisByteStore",
+ "UpstashRedisStore",
+]
diff --git a/libs/community/langchain_community/storage/exceptions.py b/libs/community/langchain_community/storage/exceptions.py
new file mode 100644
index 00000000000..d7231de65c4
--- /dev/null
+++ b/libs/community/langchain_community/storage/exceptions.py
@@ -0,0 +1,5 @@
+from langchain_core.exceptions import LangChainException
+
+
+class InvalidKeyException(LangChainException):
+ """Raised when a key is invalid; e.g., uses incorrect characters."""
diff --git a/libs/community/langchain_community/storage/redis.py b/libs/community/langchain_community/storage/redis.py
new file mode 100644
index 00000000000..0cccbf071be
--- /dev/null
+++ b/libs/community/langchain_community/storage/redis.py
@@ -0,0 +1,141 @@
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
+
+from langchain_core.stores import ByteStore
+
+from langchain_community.utilities.redis import get_client
+
+
+class RedisStore(ByteStore):
+ """BaseStore implementation using Redis as the underlying store.
+
+ Examples:
+ Create a RedisStore instance and perform operations on it:
+
+ .. code-block:: python
+
+ # Instantiate the RedisStore with a Redis connection
+ from langchain_community.storage import RedisStore
+ from langchain_community.utilities.redis import get_client
+
+ client = get_client('redis://localhost:6379')
+ redis_store = RedisStore(client)
+
+ # Set values for keys
+ redis_store.mset([("key1", b"value1"), ("key2", b"value2")])
+
+ # Get values for keys
+ values = redis_store.mget(["key1", "key2"])
+ # [b"value1", b"value2"]
+
+ # Delete keys
+ redis_store.mdelete(["key1"])
+
+ # Iterate over keys
+ for key in redis_store.yield_keys():
+ print(key)
+ """
+
+ def __init__(
+ self,
+ *,
+ client: Any = None,
+ redis_url: Optional[str] = None,
+ client_kwargs: Optional[dict] = None,
+ ttl: Optional[int] = None,
+ namespace: Optional[str] = None,
+ ) -> None:
+ """Initialize the RedisStore with a Redis connection.
+
+ Must provide either a Redis client or a redis_url with optional client_kwargs.
+
+ Args:
+ client: A Redis connection instance
+ redis_url: redis url
+ client_kwargs: Keyword arguments to pass to the Redis client
+ ttl: time to expire keys in seconds if provided,
+ if None keys will never expire
+ namespace: if provided, all keys will be prefixed with this namespace
+ """
+ try:
+ from redis import Redis
+ except ImportError as e:
+ raise ImportError(
+ "The RedisStore requires the redis library to be installed. "
+ "pip install redis"
+ ) from e
+
+        if (client and redis_url) or (client and client_kwargs):
+ raise ValueError(
+ "Either a Redis client or a redis_url with optional client_kwargs "
+ "must be provided, but not both."
+ )
+
+ if client:
+ if not isinstance(client, Redis):
+ raise TypeError(
+ f"Expected Redis client, got {type(client).__name__} instead."
+ )
+ _client = client
+ else:
+ if not redis_url:
+ raise ValueError(
+ "Either a Redis client or a redis_url must be provided."
+ )
+ _client = get_client(redis_url, **(client_kwargs or {}))
+
+ self.client = _client
+
+ if not isinstance(ttl, int) and ttl is not None:
+ raise TypeError(f"Expected int or None, got {type(ttl)} instead.")
+
+ self.ttl = ttl
+ self.namespace = namespace
+
+ def _get_prefixed_key(self, key: str) -> str:
+ """Get the key with the namespace prefix.
+
+ Args:
+ key (str): The original key.
+
+ Returns:
+ str: The key with the namespace prefix.
+ """
+ delimiter = "/"
+ if self.namespace:
+ return f"{self.namespace}{delimiter}{key}"
+ return key
+
+ def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
+ """Get the values associated with the given keys."""
+ return cast(
+ List[Optional[bytes]],
+ self.client.mget([self._get_prefixed_key(key) for key in keys]),
+ )
+
+ def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
+ """Set the given key-value pairs."""
+ pipe = self.client.pipeline()
+
+ for key, value in key_value_pairs:
+ pipe.set(self._get_prefixed_key(key), value, ex=self.ttl)
+ pipe.execute()
+
+ def mdelete(self, keys: Sequence[str]) -> None:
+ """Delete the given keys."""
+ _keys = [self._get_prefixed_key(key) for key in keys]
+ self.client.delete(*_keys)
+
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+ """Yield keys in the store."""
+ if prefix:
+ pattern = self._get_prefixed_key(prefix)
+ else:
+ pattern = self._get_prefixed_key("*")
+ scan_iter = cast(Iterator[bytes], self.client.scan_iter(match=pattern))
+ for key in scan_iter:
+ decoded_key = key.decode("utf-8")
+ if self.namespace:
+ relative_key = decoded_key[len(self.namespace) + 1 :]
+ yield relative_key
+ else:
+ yield decoded_key
diff --git a/libs/community/langchain_community/storage/upstash_redis.py b/libs/community/langchain_community/storage/upstash_redis.py
new file mode 100644
index 00000000000..7fc49b49c76
--- /dev/null
+++ b/libs/community/langchain_community/storage/upstash_redis.py
@@ -0,0 +1,174 @@
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
+
+from langchain_core._api.deprecation import deprecated
+from langchain_core.stores import BaseStore, ByteStore
+
+
+class _UpstashRedisStore(BaseStore[str, str]):
+ """BaseStore implementation using Upstash Redis as the underlying store."""
+
+ def __init__(
+ self,
+ *,
+ client: Any = None,
+ url: Optional[str] = None,
+ token: Optional[str] = None,
+ ttl: Optional[int] = None,
+ namespace: Optional[str] = None,
+ ) -> None:
+ """Initialize the UpstashRedisStore with HTTP API.
+
+ Must provide either an Upstash Redis client or a url.
+
+ Args:
+ client: An Upstash Redis instance
+ url: UPSTASH_REDIS_REST_URL
+ token: UPSTASH_REDIS_REST_TOKEN
+ ttl: time to expire keys in seconds if provided,
+ if None keys will never expire
+ namespace: if provided, all keys will be prefixed with this namespace
+ """
+ try:
+ from upstash_redis import Redis
+ except ImportError as e:
+ raise ImportError(
+ "UpstashRedisStore requires the upstash_redis library to be installed. "
+ "pip install upstash_redis"
+ ) from e
+
+ if client and url:
+ raise ValueError(
+ "Either an Upstash Redis client or a url must be provided, not both."
+ )
+
+ if client:
+ if not isinstance(client, Redis):
+ raise TypeError(
+ f"Expected Upstash Redis client, got {type(client).__name__}."
+ )
+ _client = client
+ else:
+ if not url or not token:
+ raise ValueError(
+ "Either an Upstash Redis client or url and token must be provided."
+ )
+ _client = Redis(url=url, token=token)
+
+ self.client = _client
+
+ if not isinstance(ttl, int) and ttl is not None:
+ raise TypeError(f"Expected int or None, got {type(ttl)} instead.")
+
+ self.ttl = ttl
+ self.namespace = namespace
+
+ def _get_prefixed_key(self, key: str) -> str:
+ """Get the key with the namespace prefix.
+
+ Args:
+ key (str): The original key.
+
+ Returns:
+ str: The key with the namespace prefix.
+ """
+ delimiter = "/"
+ if self.namespace:
+ return f"{self.namespace}{delimiter}{key}"
+ return key
+
+ def mget(self, keys: Sequence[str]) -> List[Optional[str]]:
+ """Get the values associated with the given keys."""
+
+ keys = [self._get_prefixed_key(key) for key in keys]
+ return cast(
+ List[Optional[str]],
+ self.client.mget(*keys),
+ )
+
+ def mset(self, key_value_pairs: Sequence[Tuple[str, str]]) -> None:
+ """Set the given key-value pairs."""
+ for key, value in key_value_pairs:
+ self.client.set(self._get_prefixed_key(key), value, ex=self.ttl)
+
+ def mdelete(self, keys: Sequence[str]) -> None:
+ """Delete the given keys."""
+ _keys = [self._get_prefixed_key(key) for key in keys]
+ self.client.delete(*_keys)
+
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+ """Yield keys in the store."""
+ if prefix:
+ pattern = self._get_prefixed_key(prefix)
+ else:
+ pattern = self._get_prefixed_key("*")
+
+ cursor, keys = self.client.scan(0, match=pattern)
+ for key in keys:
+ if self.namespace:
+ relative_key = key[len(self.namespace) + 1 :]
+ yield relative_key
+ else:
+ yield key
+
+ while cursor != 0:
+ cursor, keys = self.client.scan(cursor, match=pattern)
+ for key in keys:
+ if self.namespace:
+ relative_key = key[len(self.namespace) + 1 :]
+ yield relative_key
+ else:
+ yield key
+
+
+@deprecated("0.0.335", alternative="UpstashRedisByteStore")
+class UpstashRedisStore(_UpstashRedisStore):
+ """
+ BaseStore implementation using Upstash Redis
+ as the underlying store to store strings.
+
+ Deprecated in favor of the more generic UpstashRedisByteStore.
+ """
+
+
+class UpstashRedisByteStore(ByteStore):
+ """
+ BaseStore implementation using Upstash Redis
+ as the underlying store to store raw bytes.
+ """
+
+ def __init__(
+ self,
+ *,
+ client: Any = None,
+ url: Optional[str] = None,
+ token: Optional[str] = None,
+ ttl: Optional[int] = None,
+ namespace: Optional[str] = None,
+ ) -> None:
+ self.underlying_store = _UpstashRedisStore(
+ client=client, url=url, token=token, ttl=ttl, namespace=namespace
+ )
+
+ def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
+ """Get the values associated with the given keys."""
+ return [
+ value.encode("utf-8") if value is not None else None
+ for value in self.underlying_store.mget(keys)
+ ]
+
+ def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
+ """Set the given key-value pairs."""
+ self.underlying_store.mset(
+ [
+ (k, v.decode("utf-8")) if v is not None else None
+ for k, v in key_value_pairs
+ ]
+ )
+
+ def mdelete(self, keys: Sequence[str]) -> None:
+ """Delete the given keys."""
+ self.underlying_store.mdelete(keys)
+
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+ """Yield keys in the store."""
+ yield from self.underlying_store.yield_keys(prefix=prefix)
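
A usage sketch mirroring the RedisStore example above, but backed by Upstash's HTTP API; the REST URL and token are placeholders from your Upstash console:

    from langchain_community.storage.upstash_redis import UpstashRedisByteStore

    store = UpstashRedisByteStore(
        url="https://<your-instance>.upstash.io",  # UPSTASH_REDIS_REST_URL (placeholder)
        token="...",                               # UPSTASH_REDIS_REST_TOKEN (placeholder)
        ttl=3600,
        namespace="embedding-cache",
    )
    store.mset([("key1", b"value1"), ("key2", b"value2")])
    print(store.mget(["key1", "key2"]))  # [b"value1", b"value2"]
    store.mdelete(["key1"])
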
diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py
new file mode 100644
index 00000000000..a92353c477b
--- /dev/null
+++ b/libs/community/langchain_community/tools/__init__.py
@@ -0,0 +1,1125 @@
+"""**Tools** are classes that an Agent uses to interact with the world.
+
+Each tool has a **description**. The agent uses the description to choose the
+right tool for the job.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ ToolMetaclass --> BaseTool --> Tool # Examples: AIPluginTool, BaseGraphQLTool
+ # Examples: BraveSearch, HumanInputRun
+
+**Main helpers:**
+
+.. code-block::
+
+ CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
+"""
+from typing import Any
+
+from langchain_core.tools import BaseTool, StructuredTool, Tool, tool
+
+# Used for internal purposes
+_DEPRECATED_TOOLS = {"PythonAstREPLTool", "PythonREPLTool"}
+
+
+def _import_ainetwork_app() -> Any:
+ from langchain_community.tools.ainetwork.app import AINAppOps
+
+ return AINAppOps
+
+
+def _import_ainetwork_owner() -> Any:
+ from langchain_community.tools.ainetwork.owner import AINOwnerOps
+
+ return AINOwnerOps
+
+
+def _import_ainetwork_rule() -> Any:
+ from langchain_community.tools.ainetwork.rule import AINRuleOps
+
+ return AINRuleOps
+
+
+def _import_ainetwork_transfer() -> Any:
+ from langchain_community.tools.ainetwork.transfer import AINTransfer
+
+ return AINTransfer
+
+
+def _import_ainetwork_value() -> Any:
+ from langchain_community.tools.ainetwork.value import AINValueOps
+
+ return AINValueOps
+
+
+def _import_arxiv_tool() -> Any:
+ from langchain_community.tools.arxiv.tool import ArxivQueryRun
+
+ return ArxivQueryRun
+
+
+def _import_azure_cognitive_services_AzureCogsFormRecognizerTool() -> Any:
+ from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsFormRecognizerTool,
+ )
+
+ return AzureCogsFormRecognizerTool
+
+
+def _import_azure_cognitive_services_AzureCogsImageAnalysisTool() -> Any:
+ from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsImageAnalysisTool,
+ )
+
+ return AzureCogsImageAnalysisTool
+
+
+def _import_azure_cognitive_services_AzureCogsSpeech2TextTool() -> Any:
+ from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsSpeech2TextTool,
+ )
+
+ return AzureCogsSpeech2TextTool
+
+
+def _import_azure_cognitive_services_AzureCogsText2SpeechTool() -> Any:
+ from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsText2SpeechTool,
+ )
+
+ return AzureCogsText2SpeechTool
+
+
+def _import_azure_cognitive_services_AzureCogsTextAnalyticsHealthTool() -> Any:
+ from langchain_community.tools.azure_cognitive_services import (
+ AzureCogsTextAnalyticsHealthTool,
+ )
+
+ return AzureCogsTextAnalyticsHealthTool
+
+
+def _import_bing_search_tool_BingSearchResults() -> Any:
+ from langchain_community.tools.bing_search.tool import BingSearchResults
+
+ return BingSearchResults
+
+
+def _import_bing_search_tool_BingSearchRun() -> Any:
+ from langchain_community.tools.bing_search.tool import BingSearchRun
+
+ return BingSearchRun
+
+
+def _import_brave_search_tool() -> Any:
+ from langchain_community.tools.brave_search.tool import BraveSearch
+
+ return BraveSearch
+
+
+def _import_ddg_search_tool_DuckDuckGoSearchResults() -> Any:
+ from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchResults
+
+ return DuckDuckGoSearchResults
+
+
+def _import_ddg_search_tool_DuckDuckGoSearchRun() -> Any:
+ from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
+
+ return DuckDuckGoSearchRun
+
+
+def _import_edenai_EdenAiExplicitImageTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiExplicitImageTool
+
+ return EdenAiExplicitImageTool
+
+
+def _import_edenai_EdenAiObjectDetectionTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiObjectDetectionTool
+
+ return EdenAiObjectDetectionTool
+
+
+def _import_edenai_EdenAiParsingIDTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiParsingIDTool
+
+ return EdenAiParsingIDTool
+
+
+def _import_edenai_EdenAiParsingInvoiceTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiParsingInvoiceTool
+
+ return EdenAiParsingInvoiceTool
+
+
+def _import_edenai_EdenAiSpeechToTextTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiSpeechToTextTool
+
+ return EdenAiSpeechToTextTool
+
+
+def _import_edenai_EdenAiTextModerationTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiTextModerationTool
+
+ return EdenAiTextModerationTool
+
+
+def _import_edenai_EdenAiTextToSpeechTool() -> Any:
+ from langchain_community.tools.edenai import EdenAiTextToSpeechTool
+
+ return EdenAiTextToSpeechTool
+
+
+def _import_edenai_EdenaiTool() -> Any:
+ from langchain_community.tools.edenai import EdenaiTool
+
+ return EdenaiTool
+
+
+def _import_eleven_labs_text2speech() -> Any:
+ from langchain_community.tools.eleven_labs.text2speech import (
+ ElevenLabsText2SpeechTool,
+ )
+
+ return ElevenLabsText2SpeechTool
+
+
+def _import_file_management_CopyFileTool() -> Any:
+ from langchain_community.tools.file_management import CopyFileTool
+
+ return CopyFileTool
+
+
+def _import_file_management_DeleteFileTool() -> Any:
+ from langchain_community.tools.file_management import DeleteFileTool
+
+ return DeleteFileTool
+
+
+def _import_file_management_FileSearchTool() -> Any:
+ from langchain_community.tools.file_management import FileSearchTool
+
+ return FileSearchTool
+
+
+def _import_file_management_ListDirectoryTool() -> Any:
+ from langchain_community.tools.file_management import ListDirectoryTool
+
+ return ListDirectoryTool
+
+
+def _import_file_management_MoveFileTool() -> Any:
+ from langchain_community.tools.file_management import MoveFileTool
+
+ return MoveFileTool
+
+
+def _import_file_management_ReadFileTool() -> Any:
+ from langchain_community.tools.file_management import ReadFileTool
+
+ return ReadFileTool
+
+
+def _import_file_management_WriteFileTool() -> Any:
+ from langchain_community.tools.file_management import WriteFileTool
+
+ return WriteFileTool
+
+
+def _import_gmail_GmailCreateDraft() -> Any:
+ from langchain_community.tools.gmail import GmailCreateDraft
+
+ return GmailCreateDraft
+
+
+def _import_gmail_GmailGetMessage() -> Any:
+ from langchain_community.tools.gmail import GmailGetMessage
+
+ return GmailGetMessage
+
+
+def _import_gmail_GmailGetThread() -> Any:
+ from langchain_community.tools.gmail import GmailGetThread
+
+ return GmailGetThread
+
+
+def _import_gmail_GmailSearch() -> Any:
+ from langchain_community.tools.gmail import GmailSearch
+
+ return GmailSearch
+
+
+def _import_gmail_GmailSendMessage() -> Any:
+ from langchain_community.tools.gmail import GmailSendMessage
+
+ return GmailSendMessage
+
+
+def _import_google_cloud_texttospeech() -> Any:
+ from langchain_community.tools.google_cloud.texttospeech import (
+ GoogleCloudTextToSpeechTool,
+ )
+
+ return GoogleCloudTextToSpeechTool
+
+
+def _import_google_places_tool() -> Any:
+ from langchain_community.tools.google_places.tool import GooglePlacesTool
+
+ return GooglePlacesTool
+
+
+def _import_google_search_tool_GoogleSearchResults() -> Any:
+ from langchain_community.tools.google_search.tool import GoogleSearchResults
+
+ return GoogleSearchResults
+
+
+def _import_google_search_tool_GoogleSearchRun() -> Any:
+ from langchain_community.tools.google_search.tool import GoogleSearchRun
+
+ return GoogleSearchRun
+
+
+def _import_google_serper_tool_GoogleSerperResults() -> Any:
+ from langchain_community.tools.google_serper.tool import GoogleSerperResults
+
+ return GoogleSerperResults
+
+
+def _import_google_serper_tool_GoogleSerperRun() -> Any:
+ from langchain_community.tools.google_serper.tool import GoogleSerperRun
+
+ return GoogleSerperRun
+
+
+def _import_searchapi_tool_SearchAPIResults() -> Any:
+ from langchain_community.tools.searchapi.tool import SearchAPIResults
+
+ return SearchAPIResults
+
+
+def _import_searchapi_tool_SearchAPIRun() -> Any:
+ from langchain_community.tools.searchapi.tool import SearchAPIRun
+
+ return SearchAPIRun
+
+
+def _import_graphql_tool() -> Any:
+ from langchain_community.tools.graphql.tool import BaseGraphQLTool
+
+ return BaseGraphQLTool
+
+
+def _import_human_tool() -> Any:
+ from langchain_community.tools.human.tool import HumanInputRun
+
+ return HumanInputRun
+
+
+def _import_ifttt() -> Any:
+ from langchain_community.tools.ifttt import IFTTTWebhook
+
+ return IFTTTWebhook
+
+
+def _import_interaction_tool() -> Any:
+ from langchain_community.tools.interaction.tool import StdInInquireTool
+
+ return StdInInquireTool
+
+
+def _import_jira_tool() -> Any:
+ from langchain_community.tools.jira.tool import JiraAction
+
+ return JiraAction
+
+
+def _import_json_tool_JsonGetValueTool() -> Any:
+ from langchain_community.tools.json.tool import JsonGetValueTool
+
+ return JsonGetValueTool
+
+
+def _import_json_tool_JsonListKeysTool() -> Any:
+ from langchain_community.tools.json.tool import JsonListKeysTool
+
+ return JsonListKeysTool
+
+
+def _import_merriam_webster_tool() -> Any:
+ from langchain_community.tools.merriam_webster.tool import MerriamWebsterQueryRun
+
+ return MerriamWebsterQueryRun
+
+
+def _import_metaphor_search() -> Any:
+ from langchain_community.tools.metaphor_search import MetaphorSearchResults
+
+ return MetaphorSearchResults
+
+
+def _import_nasa_tool() -> Any:
+ from langchain_community.tools.nasa.tool import NasaAction
+
+ return NasaAction
+
+
+def _import_office365_create_draft_message() -> Any:
+ from langchain_community.tools.office365.create_draft_message import (
+ O365CreateDraftMessage,
+ )
+
+ return O365CreateDraftMessage
+
+
+def _import_office365_events_search() -> Any:
+ from langchain_community.tools.office365.events_search import O365SearchEvents
+
+ return O365SearchEvents
+
+
+def _import_office365_messages_search() -> Any:
+ from langchain_community.tools.office365.messages_search import O365SearchEmails
+
+ return O365SearchEmails
+
+
+def _import_office365_send_event() -> Any:
+ from langchain_community.tools.office365.send_event import O365SendEvent
+
+ return O365SendEvent
+
+
+def _import_office365_send_message() -> Any:
+ from langchain_community.tools.office365.send_message import O365SendMessage
+
+ return O365SendMessage
+
+
+def _import_office365_utils() -> Any:
+ from langchain_community.tools.office365.utils import authenticate
+
+ return authenticate
+
+
+def _import_openapi_utils_api_models() -> Any:
+ from langchain_community.tools.openapi.utils.api_models import APIOperation
+
+ return APIOperation
+
+
+def _import_openapi_utils_openapi_utils() -> Any:
+ from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
+
+ return OpenAPISpec
+
+
+def _import_openweathermap_tool() -> Any:
+ from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun
+
+ return OpenWeatherMapQueryRun
+
+
+def _import_playwright_ClickTool() -> Any:
+ from langchain_community.tools.playwright import ClickTool
+
+ return ClickTool
+
+
+def _import_playwright_CurrentWebPageTool() -> Any:
+ from langchain_community.tools.playwright import CurrentWebPageTool
+
+ return CurrentWebPageTool
+
+
+def _import_playwright_ExtractHyperlinksTool() -> Any:
+ from langchain_community.tools.playwright import ExtractHyperlinksTool
+
+ return ExtractHyperlinksTool
+
+
+def _import_playwright_ExtractTextTool() -> Any:
+ from langchain_community.tools.playwright import ExtractTextTool
+
+ return ExtractTextTool
+
+
+def _import_playwright_GetElementsTool() -> Any:
+ from langchain_community.tools.playwright import GetElementsTool
+
+ return GetElementsTool
+
+
+def _import_playwright_NavigateBackTool() -> Any:
+ from langchain_community.tools.playwright import NavigateBackTool
+
+ return NavigateBackTool
+
+
+def _import_playwright_NavigateTool() -> Any:
+ from langchain_community.tools.playwright import NavigateTool
+
+ return NavigateTool
+
+
+def _import_plugin() -> Any:
+ from langchain_community.tools.plugin import AIPluginTool
+
+ return AIPluginTool
+
+
+def _import_powerbi_tool_InfoPowerBITool() -> Any:
+ from langchain_community.tools.powerbi.tool import InfoPowerBITool
+
+ return InfoPowerBITool
+
+
+def _import_powerbi_tool_ListPowerBITool() -> Any:
+ from langchain_community.tools.powerbi.tool import ListPowerBITool
+
+ return ListPowerBITool
+
+
+def _import_powerbi_tool_QueryPowerBITool() -> Any:
+ from langchain_community.tools.powerbi.tool import QueryPowerBITool
+
+ return QueryPowerBITool
+
+
+def _import_pubmed_tool() -> Any:
+ from langchain_community.tools.pubmed.tool import PubmedQueryRun
+
+ return PubmedQueryRun
+
+
+def _import_python_tool_PythonAstREPLTool() -> Any:
+ raise ImportError(
+ "This tool has been moved to langchain experiment. "
+ "This tool has access to a python REPL. "
+ "For best practices make sure to sandbox this tool. "
+ "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
+ "To keep using this code as is, install langchain experimental and "
+ "update relevant imports replacing 'langchain' with 'langchain_experimental'"
+ )
+
+
+def _import_python_tool_PythonREPLTool() -> Any:
+ raise ImportError(
+ "This tool has been moved to langchain experiment. "
+ "This tool has access to a python REPL. "
+ "For best practices make sure to sandbox this tool. "
+ "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
+ "To keep using this code as is, install langchain experimental and "
+ "update relevant imports replacing 'langchain' with 'langchain_experimental'"
+ )
+
+
+def _import_reddit_search_RedditSearchRun() -> Any:
+ from langchain_community.tools.reddit_search.tool import RedditSearchRun
+
+ return RedditSearchRun
+
+
+def _import_render() -> Any:
+ from langchain_community.tools.render import format_tool_to_openai_function
+
+ return format_tool_to_openai_function
+
+
+def _import_requests_tool_BaseRequestsTool() -> Any:
+ from langchain_community.tools.requests.tool import BaseRequestsTool
+
+ return BaseRequestsTool
+
+
+def _import_requests_tool_RequestsDeleteTool() -> Any:
+ from langchain_community.tools.requests.tool import RequestsDeleteTool
+
+ return RequestsDeleteTool
+
+
+def _import_requests_tool_RequestsGetTool() -> Any:
+ from langchain_community.tools.requests.tool import RequestsGetTool
+
+ return RequestsGetTool
+
+
+def _import_requests_tool_RequestsPatchTool() -> Any:
+ from langchain_community.tools.requests.tool import RequestsPatchTool
+
+ return RequestsPatchTool
+
+
+def _import_requests_tool_RequestsPostTool() -> Any:
+ from langchain_community.tools.requests.tool import RequestsPostTool
+
+ return RequestsPostTool
+
+
+def _import_requests_tool_RequestsPutTool() -> Any:
+ from langchain_community.tools.requests.tool import RequestsPutTool
+
+ return RequestsPutTool
+
+
+def _import_steam_webapi_tool() -> Any:
+ from langchain_community.tools.steam.tool import SteamWebAPIQueryRun
+
+ return SteamWebAPIQueryRun
+
+
+def _import_scenexplain_tool() -> Any:
+ from langchain_community.tools.scenexplain.tool import SceneXplainTool
+
+ return SceneXplainTool
+
+
+def _import_searx_search_tool_SearxSearchResults() -> Any:
+ from langchain_community.tools.searx_search.tool import SearxSearchResults
+
+ return SearxSearchResults
+
+
+def _import_searx_search_tool_SearxSearchRun() -> Any:
+ from langchain_community.tools.searx_search.tool import SearxSearchRun
+
+ return SearxSearchRun
+
+
+def _import_shell_tool() -> Any:
+ from langchain_community.tools.shell.tool import ShellTool
+
+ return ShellTool
+
+
+def _import_slack_get_channel() -> Any:
+ from langchain_community.tools.slack.get_channel import SlackGetChannel
+
+ return SlackGetChannel
+
+
+def _import_slack_get_message() -> Any:
+ from langchain_community.tools.slack.get_message import SlackGetMessage
+
+ return SlackGetMessage
+
+
+def _import_slack_schedule_message() -> Any:
+ from langchain_community.tools.slack.schedule_message import SlackScheduleMessage
+
+ return SlackScheduleMessage
+
+
+def _import_slack_send_message() -> Any:
+ from langchain_community.tools.slack.send_message import SlackSendMessage
+
+ return SlackSendMessage
+
+
+def _import_sleep_tool() -> Any:
+ from langchain_community.tools.sleep.tool import SleepTool
+
+ return SleepTool
+
+
+def _import_spark_sql_tool_BaseSparkSQLTool() -> Any:
+ from langchain_community.tools.spark_sql.tool import BaseSparkSQLTool
+
+ return BaseSparkSQLTool
+
+
+def _import_spark_sql_tool_InfoSparkSQLTool() -> Any:
+ from langchain_community.tools.spark_sql.tool import InfoSparkSQLTool
+
+ return InfoSparkSQLTool
+
+
+def _import_spark_sql_tool_ListSparkSQLTool() -> Any:
+ from langchain_community.tools.spark_sql.tool import ListSparkSQLTool
+
+ return ListSparkSQLTool
+
+
+def _import_spark_sql_tool_QueryCheckerTool() -> Any:
+ from langchain_community.tools.spark_sql.tool import QueryCheckerTool
+
+ return QueryCheckerTool
+
+
+def _import_spark_sql_tool_QuerySparkSQLTool() -> Any:
+ from langchain_community.tools.spark_sql.tool import QuerySparkSQLTool
+
+ return QuerySparkSQLTool
+
+
+def _import_sql_database_tool_BaseSQLDatabaseTool() -> Any:
+ from langchain_community.tools.sql_database.tool import BaseSQLDatabaseTool
+
+ return BaseSQLDatabaseTool
+
+
+def _import_sql_database_tool_InfoSQLDatabaseTool() -> Any:
+ from langchain_community.tools.sql_database.tool import InfoSQLDatabaseTool
+
+ return InfoSQLDatabaseTool
+
+
+def _import_sql_database_tool_ListSQLDatabaseTool() -> Any:
+ from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool
+
+ return ListSQLDatabaseTool
+
+
+def _import_sql_database_tool_QuerySQLCheckerTool() -> Any:
+ from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool
+
+ return QuerySQLCheckerTool
+
+
+def _import_sql_database_tool_QuerySQLDataBaseTool() -> Any:
+ from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+
+ return QuerySQLDataBaseTool
+
+
+def _import_stackexchange_tool() -> Any:
+ from langchain_community.tools.stackexchange.tool import StackExchangeTool
+
+ return StackExchangeTool
+
+
+def _import_steamship_image_generation() -> Any:
+ from langchain_community.tools.steamship_image_generation import (
+ SteamshipImageGenerationTool,
+ )
+
+ return SteamshipImageGenerationTool
+
+
+def _import_vectorstore_tool_VectorStoreQATool() -> Any:
+ from langchain_community.tools.vectorstore.tool import VectorStoreQATool
+
+ return VectorStoreQATool
+
+
+def _import_vectorstore_tool_VectorStoreQAWithSourcesTool() -> Any:
+ from langchain_community.tools.vectorstore.tool import VectorStoreQAWithSourcesTool
+
+ return VectorStoreQAWithSourcesTool
+
+
+def _import_wikipedia_tool() -> Any:
+ from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
+
+ return WikipediaQueryRun
+
+
+def _import_wolfram_alpha_tool() -> Any:
+ from langchain_community.tools.wolfram_alpha.tool import WolframAlphaQueryRun
+
+ return WolframAlphaQueryRun
+
+
+def _import_yahoo_finance_news() -> Any:
+ from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
+
+ return YahooFinanceNewsTool
+
+
+def _import_youtube_search() -> Any:
+ from langchain_community.tools.youtube.search import YouTubeSearchTool
+
+ return YouTubeSearchTool
+
+
+def _import_zapier_tool_ZapierNLAListActions() -> Any:
+ from langchain_community.tools.zapier.tool import ZapierNLAListActions
+
+ return ZapierNLAListActions
+
+
+def _import_zapier_tool_ZapierNLARunAction() -> Any:
+ from langchain_community.tools.zapier.tool import ZapierNLARunAction
+
+ return ZapierNLARunAction
+
+
+def _import_bearly_tool() -> Any:
+ from langchain_community.tools.bearly.tool import BearlyInterpreterTool
+
+ return BearlyInterpreterTool
+
+
+def _import_e2b_data_analysis() -> Any:
+ from langchain_community.tools.e2b_data_analysis.tool import E2BDataAnalysisTool
+
+ return E2BDataAnalysisTool
+
+
+def __getattr__(name: str) -> Any:
+ if name == "AINAppOps":
+ return _import_ainetwork_app()
+ elif name == "AINOwnerOps":
+ return _import_ainetwork_owner()
+ elif name == "AINRuleOps":
+ return _import_ainetwork_rule()
+ elif name == "AINTransfer":
+ return _import_ainetwork_transfer()
+ elif name == "AINValueOps":
+ return _import_ainetwork_value()
+ elif name == "ArxivQueryRun":
+ return _import_arxiv_tool()
+ elif name == "AzureCogsFormRecognizerTool":
+ return _import_azure_cognitive_services_AzureCogsFormRecognizerTool()
+ elif name == "AzureCogsImageAnalysisTool":
+ return _import_azure_cognitive_services_AzureCogsImageAnalysisTool()
+ elif name == "AzureCogsSpeech2TextTool":
+ return _import_azure_cognitive_services_AzureCogsSpeech2TextTool()
+ elif name == "AzureCogsText2SpeechTool":
+ return _import_azure_cognitive_services_AzureCogsText2SpeechTool()
+ elif name == "AzureCogsTextAnalyticsHealthTool":
+ return _import_azure_cognitive_services_AzureCogsTextAnalyticsHealthTool()
+ elif name == "BingSearchResults":
+ return _import_bing_search_tool_BingSearchResults()
+ elif name == "BingSearchRun":
+ return _import_bing_search_tool_BingSearchRun()
+ elif name == "BraveSearch":
+ return _import_brave_search_tool()
+ elif name == "DuckDuckGoSearchResults":
+ return _import_ddg_search_tool_DuckDuckGoSearchResults()
+ elif name == "DuckDuckGoSearchRun":
+ return _import_ddg_search_tool_DuckDuckGoSearchRun()
+ elif name == "EdenAiExplicitImageTool":
+ return _import_edenai_EdenAiExplicitImageTool()
+ elif name == "EdenAiObjectDetectionTool":
+ return _import_edenai_EdenAiObjectDetectionTool()
+ elif name == "EdenAiParsingIDTool":
+ return _import_edenai_EdenAiParsingIDTool()
+ elif name == "EdenAiParsingInvoiceTool":
+ return _import_edenai_EdenAiParsingInvoiceTool()
+ elif name == "EdenAiSpeechToTextTool":
+ return _import_edenai_EdenAiSpeechToTextTool()
+ elif name == "EdenAiTextModerationTool":
+ return _import_edenai_EdenAiTextModerationTool()
+ elif name == "EdenAiTextToSpeechTool":
+ return _import_edenai_EdenAiTextToSpeechTool()
+ elif name == "EdenaiTool":
+ return _import_edenai_EdenaiTool()
+ elif name == "ElevenLabsText2SpeechTool":
+ return _import_eleven_labs_text2speech()
+ elif name == "CopyFileTool":
+ return _import_file_management_CopyFileTool()
+ elif name == "DeleteFileTool":
+ return _import_file_management_DeleteFileTool()
+ elif name == "FileSearchTool":
+ return _import_file_management_FileSearchTool()
+ elif name == "ListDirectoryTool":
+ return _import_file_management_ListDirectoryTool()
+ elif name == "MoveFileTool":
+ return _import_file_management_MoveFileTool()
+ elif name == "ReadFileTool":
+ return _import_file_management_ReadFileTool()
+ elif name == "WriteFileTool":
+ return _import_file_management_WriteFileTool()
+ elif name == "GmailCreateDraft":
+ return _import_gmail_GmailCreateDraft()
+ elif name == "GmailGetMessage":
+ return _import_gmail_GmailGetMessage()
+ elif name == "GmailGetThread":
+ return _import_gmail_GmailGetThread()
+ elif name == "GmailSearch":
+ return _import_gmail_GmailSearch()
+ elif name == "GmailSendMessage":
+ return _import_gmail_GmailSendMessage()
+ elif name == "GoogleCloudTextToSpeechTool":
+ return _import_google_cloud_texttospeech()
+ elif name == "GooglePlacesTool":
+ return _import_google_places_tool()
+ elif name == "GoogleSearchResults":
+ return _import_google_search_tool_GoogleSearchResults()
+ elif name == "GoogleSearchRun":
+ return _import_google_search_tool_GoogleSearchRun()
+ elif name == "GoogleSerperResults":
+ return _import_google_serper_tool_GoogleSerperResults()
+ elif name == "GoogleSerperRun":
+ return _import_google_serper_tool_GoogleSerperRun()
+ elif name == "SearchAPIResults":
+ return _import_searchapi_tool_SearchAPIResults()
+ elif name == "SearchAPIRun":
+ return _import_searchapi_tool_SearchAPIRun()
+ elif name == "BaseGraphQLTool":
+ return _import_graphql_tool()
+ elif name == "HumanInputRun":
+ return _import_human_tool()
+ elif name == "IFTTTWebhook":
+ return _import_ifttt()
+ elif name == "StdInInquireTool":
+ return _import_interaction_tool()
+ elif name == "JiraAction":
+ return _import_jira_tool()
+ elif name == "JsonGetValueTool":
+ return _import_json_tool_JsonGetValueTool()
+ elif name == "JsonListKeysTool":
+ return _import_json_tool_JsonListKeysTool()
+ elif name == "MerriamWebsterQueryRun":
+ return _import_merriam_webster_tool()
+ elif name == "MetaphorSearchResults":
+ return _import_metaphor_search()
+ elif name == "NasaAction":
+ return _import_nasa_tool()
+ elif name == "O365CreateDraftMessage":
+ return _import_office365_create_draft_message()
+ elif name == "O365SearchEvents":
+ return _import_office365_events_search()
+ elif name == "O365SearchEmails":
+ return _import_office365_messages_search()
+ elif name == "O365SendEvent":
+ return _import_office365_send_event()
+ elif name == "O365SendMessage":
+ return _import_office365_send_message()
+ elif name == "authenticate":
+ return _import_office365_utils()
+ elif name == "APIOperation":
+ return _import_openapi_utils_api_models()
+ elif name == "OpenAPISpec":
+ return _import_openapi_utils_openapi_utils()
+ elif name == "OpenWeatherMapQueryRun":
+ return _import_openweathermap_tool()
+ elif name == "ClickTool":
+ return _import_playwright_ClickTool()
+ elif name == "CurrentWebPageTool":
+ return _import_playwright_CurrentWebPageTool()
+ elif name == "ExtractHyperlinksTool":
+ return _import_playwright_ExtractHyperlinksTool()
+ elif name == "ExtractTextTool":
+ return _import_playwright_ExtractTextTool()
+ elif name == "GetElementsTool":
+ return _import_playwright_GetElementsTool()
+ elif name == "NavigateBackTool":
+ return _import_playwright_NavigateBackTool()
+ elif name == "NavigateTool":
+ return _import_playwright_NavigateTool()
+ elif name == "AIPluginTool":
+ return _import_plugin()
+ elif name == "InfoPowerBITool":
+ return _import_powerbi_tool_InfoPowerBITool()
+ elif name == "ListPowerBITool":
+ return _import_powerbi_tool_ListPowerBITool()
+ elif name == "QueryPowerBITool":
+ return _import_powerbi_tool_QueryPowerBITool()
+ elif name == "PubmedQueryRun":
+ return _import_pubmed_tool()
+ elif name == "PythonAstREPLTool":
+ return _import_python_tool_PythonAstREPLTool()
+ elif name == "PythonREPLTool":
+ return _import_python_tool_PythonREPLTool()
+ elif name == "RedditSearchRun":
+ return _import_reddit_search_RedditSearchRun()
+ elif name == "format_tool_to_openai_function":
+ return _import_render()
+ elif name == "BaseRequestsTool":
+ return _import_requests_tool_BaseRequestsTool()
+ elif name == "RequestsDeleteTool":
+ return _import_requests_tool_RequestsDeleteTool()
+ elif name == "RequestsGetTool":
+ return _import_requests_tool_RequestsGetTool()
+ elif name == "RequestsPatchTool":
+ return _import_requests_tool_RequestsPatchTool()
+ elif name == "RequestsPostTool":
+ return _import_requests_tool_RequestsPostTool()
+ elif name == "RequestsPutTool":
+ return _import_requests_tool_RequestsPutTool()
+ elif name == "SteamWebAPIQueryRun":
+ return _import_steam_webapi_tool()
+ elif name == "SceneXplainTool":
+ return _import_scenexplain_tool()
+ elif name == "SearxSearchResults":
+ return _import_searx_search_tool_SearxSearchResults()
+ elif name == "SearxSearchRun":
+ return _import_searx_search_tool_SearxSearchRun()
+ elif name == "ShellTool":
+ return _import_shell_tool()
+ elif name == "SlackGetChannel":
+ return _import_slack_get_channel
+ elif name == "SlackGetMessage":
+ return _import_slack_get_message
+ elif name == "SlackScheduleMessage":
+ return _import_slack_schedule_message
+ elif name == "SlackSendMessage":
+ return _import_slack_send_message
+ elif name == "SleepTool":
+ return _import_sleep_tool()
+ elif name == "BaseSparkSQLTool":
+ return _import_spark_sql_tool_BaseSparkSQLTool()
+ elif name == "InfoSparkSQLTool":
+ return _import_spark_sql_tool_InfoSparkSQLTool()
+ elif name == "ListSparkSQLTool":
+ return _import_spark_sql_tool_ListSparkSQLTool()
+ elif name == "QueryCheckerTool":
+ return _import_spark_sql_tool_QueryCheckerTool()
+ elif name == "QuerySparkSQLTool":
+ return _import_spark_sql_tool_QuerySparkSQLTool()
+ elif name == "BaseSQLDatabaseTool":
+ return _import_sql_database_tool_BaseSQLDatabaseTool()
+ elif name == "InfoSQLDatabaseTool":
+ return _import_sql_database_tool_InfoSQLDatabaseTool()
+ elif name == "ListSQLDatabaseTool":
+ return _import_sql_database_tool_ListSQLDatabaseTool()
+ elif name == "QuerySQLCheckerTool":
+ return _import_sql_database_tool_QuerySQLCheckerTool()
+ elif name == "QuerySQLDataBaseTool":
+ return _import_sql_database_tool_QuerySQLDataBaseTool()
+ elif name == "StackExchangeTool":
+ return _import_stackexchange_tool()
+ elif name == "SteamshipImageGenerationTool":
+ return _import_steamship_image_generation()
+ elif name == "VectorStoreQATool":
+ return _import_vectorstore_tool_VectorStoreQATool()
+ elif name == "VectorStoreQAWithSourcesTool":
+ return _import_vectorstore_tool_VectorStoreQAWithSourcesTool()
+ elif name == "WikipediaQueryRun":
+ return _import_wikipedia_tool()
+ elif name == "WolframAlphaQueryRun":
+ return _import_wolfram_alpha_tool()
+ elif name == "YahooFinanceNewsTool":
+ return _import_yahoo_finance_news()
+ elif name == "YouTubeSearchTool":
+ return _import_youtube_search()
+ elif name == "ZapierNLAListActions":
+ return _import_zapier_tool_ZapierNLAListActions()
+ elif name == "ZapierNLARunAction":
+ return _import_zapier_tool_ZapierNLARunAction()
+ elif name == "BearlyInterpreterTool":
+ return _import_bearly_tool()
+ elif name == "E2BDataAnalysisTool":
+ return _import_e2b_data_analysis()
+ else:
+ raise AttributeError(f"Could not find: {name}")
+
+
+__all__ = [
+ "AINAppOps",
+ "AINOwnerOps",
+ "AINRuleOps",
+ "AINTransfer",
+ "AINValueOps",
+ "AIPluginTool",
+ "APIOperation",
+ "ArxivQueryRun",
+ "AzureCogsFormRecognizerTool",
+ "AzureCogsImageAnalysisTool",
+ "AzureCogsSpeech2TextTool",
+ "AzureCogsText2SpeechTool",
+ "AzureCogsTextAnalyticsHealthTool",
+ "BaseGraphQLTool",
+ "BaseRequestsTool",
+ "BaseSQLDatabaseTool",
+ "BaseSparkSQLTool",
+ "BaseTool",
+ "BearlyInterpreterTool",
+ "BingSearchResults",
+ "BingSearchRun",
+ "BraveSearch",
+ "ClickTool",
+ "CopyFileTool",
+ "CurrentWebPageTool",
+ "DeleteFileTool",
+ "DuckDuckGoSearchResults",
+ "DuckDuckGoSearchRun",
+ "E2BDataAnalysisTool",
+ "EdenAiExplicitImageTool",
+ "EdenAiObjectDetectionTool",
+ "EdenAiParsingIDTool",
+ "EdenAiParsingInvoiceTool",
+ "EdenAiSpeechToTextTool",
+ "EdenAiTextModerationTool",
+ "EdenAiTextToSpeechTool",
+ "EdenaiTool",
+ "ElevenLabsText2SpeechTool",
+ "ExtractHyperlinksTool",
+ "ExtractTextTool",
+ "FileSearchTool",
+ "GetElementsTool",
+ "GmailCreateDraft",
+ "GmailGetMessage",
+ "GmailGetThread",
+ "GmailSearch",
+ "GmailSendMessage",
+ "GoogleCloudTextToSpeechTool",
+ "GooglePlacesTool",
+ "GoogleSearchResults",
+ "GoogleSearchRun",
+ "GoogleSerperResults",
+ "GoogleSerperRun",
+ "SearchAPIResults",
+ "SearchAPIRun",
+ "HumanInputRun",
+ "IFTTTWebhook",
+ "InfoPowerBITool",
+ "InfoSQLDatabaseTool",
+ "InfoSparkSQLTool",
+ "JiraAction",
+ "JsonGetValueTool",
+ "JsonListKeysTool",
+ "ListDirectoryTool",
+ "ListPowerBITool",
+ "ListSQLDatabaseTool",
+ "ListSparkSQLTool",
+ "MerriamWebsterQueryRun",
+ "MetaphorSearchResults",
+ "MoveFileTool",
+ "NasaAction",
+ "NavigateBackTool",
+ "NavigateTool",
+ "O365CreateDraftMessage",
+ "O365SearchEmails",
+ "O365SearchEvents",
+ "O365SendEvent",
+ "O365SendMessage",
+ "OpenAPISpec",
+ "OpenWeatherMapQueryRun",
+ "PubmedQueryRun",
+ "RedditSearchRun",
+ "QueryCheckerTool",
+ "QueryPowerBITool",
+ "QuerySQLCheckerTool",
+ "QuerySQLDataBaseTool",
+ "QuerySparkSQLTool",
+ "ReadFileTool",
+ "RequestsDeleteTool",
+ "RequestsGetTool",
+ "RequestsPatchTool",
+ "RequestsPostTool",
+ "RequestsPutTool",
+ "SteamWebAPIQueryRun",
+ "SceneXplainTool",
+ "SearxSearchResults",
+ "SearxSearchRun",
+ "ShellTool",
+ "SlackGetChannel",
+ "SlackGetMessage",
+ "SlackScheduleMessage",
+ "SlackSendMessage",
+ "SleepTool",
+ "StdInInquireTool",
+ "StackExchangeTool",
+ "SteamshipImageGenerationTool",
+ "StructuredTool",
+ "Tool",
+ "VectorStoreQATool",
+ "VectorStoreQAWithSourcesTool",
+ "WikipediaQueryRun",
+ "WolframAlphaQueryRun",
+ "WriteFileTool",
+ "YahooFinanceNewsTool",
+ "YouTubeSearchTool",
+ "ZapierNLAListActions",
+ "ZapierNLARunAction",
+ "authenticate",
+ "format_tool_to_openai_function",
+ "tool",
+]
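The `__getattr__` hook above implements PEP 562 lazy imports for `langchain_community.tools`: a tool's module is only imported the first time its name is accessed, and unknown names fall through to the `AttributeError` branch. A minimal sketch of what that looks like from the caller's side (the Wikipedia tool is only an example; instantiating it still requires its optional dependencies):

# Attribute access on the package triggers __getattr__, which performs the real import.
from langchain_community.tools import WikipediaQueryRun  # resolved lazily, on first access

import langchain_community.tools as tools

try:
    tools.NoSuchTool  # not in the mapping above
except AttributeError as err:
    print(err)  # Could not find: NoSuchTool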
diff --git a/libs/community/langchain_community/tools/ainetwork/app.py b/libs/community/langchain_community/tools/ainetwork/app.py
new file mode 100644
index 00000000000..faef6120f9f
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/app.py
@@ -0,0 +1,102 @@
+import builtins
+import json
+from enum import Enum
+from typing import List, Optional, Type, Union
+
+from langchain_core.callbacks import AsyncCallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.ainetwork.base import AINBaseTool
+
+
+class AppOperationType(str, Enum):
+ """Type of app operation as enumerator."""
+
+ SET_ADMIN = "SET_ADMIN"
+ GET_CONFIG = "GET_CONFIG"
+
+
+class AppSchema(BaseModel):
+ """Schema for app operations."""
+
+ type: AppOperationType = Field(...)
+ appName: str = Field(..., description="Name of the application on the blockchain")
+ address: Optional[Union[str, List[str]]] = Field(
+ None,
+ description=(
+ "A single address or a list of addresses. Default: current session's "
+ "address"
+ ),
+ )
+
+
+class AINAppOps(AINBaseTool):
+ """Tool for app operations."""
+
+ name: str = "AINappOps"
+ description: str = """
+Create an app in the AINetwork Blockchain database by creating the /apps/<appName> path.
+An address set as `admin` can grant `owner` rights to other addresses (refer to `AINownerOps` for more details).
+Also, `admin` is initialized to have all `owner` permissions and `rule` allowed for that path.
+
+## appName Rule
+- [a-z_0-9]+
+
+## address Rules
+- 0x[0-9a-fA-F]{40}
+- Defaults to the current session's address
+- Multiple addresses can be specified if needed
+
+## SET_ADMIN Example 1
+- type: SET_ADMIN
+- appName: ain_project
+
+### Result:
+1. Path /apps/ain_project created.
+2. Current session's address registered as admin.
+
+## SET_ADMIN Example 2
+- type: SET_ADMIN
+- appName: test_project
+- address: [<address_1>, <address_2>]
+
+### Result:
+1. Path /apps/test_project created.
+2. <address_1> and <address_2> registered as admin.
+
+""" # noqa: E501
+ args_schema: Type[BaseModel] = AppSchema
+
+ async def _arun(
+ self,
+ type: AppOperationType,
+ appName: str,
+ address: Optional[Union[str, List[str]]] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ from ain.types import ValueOnlyTransactionInput
+ from ain.utils import getTimestamp
+
+ try:
+ if type is AppOperationType.SET_ADMIN:
+ if address is None:
+ address = self.interface.wallet.defaultAccount.address
+ if isinstance(address, str):
+ address = [address]
+
+ res = await self.interface.db.ref(
+ f"/manage_app/{appName}/create/{getTimestamp()}"
+ ).setValue(
+ transactionInput=ValueOnlyTransactionInput(
+ value={"admin": {address: True for address in address}}
+ )
+ )
+ elif type is AppOperationType.GET_CONFIG:
+ res = await self.interface.db.ref(
+ f"/manage_app/{appName}/config"
+ ).getValue()
+ else:
+ raise ValueError(f"Unsupported 'type': {type}.")
+ return json.dumps(res, ensure_ascii=False)
+ except Exception as e:
+ return f"{builtins.type(e).__name__}: {str(e)}"
diff --git a/libs/community/langchain_community/tools/ainetwork/base.py b/libs/community/langchain_community/tools/ainetwork/base.py
new file mode 100644
index 00000000000..2f941275769
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/base.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import asyncio
+import threading
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.ainetwork.utils import authenticate
+
+if TYPE_CHECKING:
+ from ain.ain import Ain
+
+
+class OperationType(str, Enum):
+ """Type of operation as enumerator."""
+
+ SET = "SET"
+ GET = "GET"
+
+
+class AINBaseTool(BaseTool):
+ """Base class for the AINetwork tools."""
+
+ interface: Ain = Field(default_factory=authenticate)
+ """The interface object for the AINetwork Blockchain."""
+
+ def _run(
+ self,
+ *args: Any,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ loop = asyncio.get_event_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ if loop.is_closed():
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ if loop.is_running():
+ result_container = []
+
+ def thread_target() -> None:
+ nonlocal result_container
+ new_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(new_loop)
+ try:
+ result_container.append(
+ new_loop.run_until_complete(self._arun(*args, **kwargs))
+ )
+ except Exception as e:
+ result_container.append(e)
+ finally:
+ new_loop.close()
+
+ thread = threading.Thread(target=thread_target)
+ thread.start()
+ thread.join()
+ result = result_container[0]
+ if isinstance(result, Exception):
+ raise result
+ return result
+
+ else:
+ result = loop.run_until_complete(self._arun(*args, **kwargs))
+ loop.close()
+ return result
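`AINBaseTool._run` is a sync-to-async bridge: it reuses or creates an event loop, and when a loop is already running it executes `_arun` on a fresh loop in a worker thread so the caller never hits "this event loop is already running". The same pattern in isolation, as a self-contained sketch independent of the tool classes:

import asyncio
import threading
from typing import Any, Coroutine


async def work() -> str:
    await asyncio.sleep(0)
    return "done"


def run_sync(coro: Coroutine[Any, Any, str]) -> str:
    """Run a coroutine from synchronous code, even if an event loop is already running."""
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if loop.is_running():
        box: list = []

        def target() -> None:
            # A dedicated loop on a worker thread sidesteps the already-running loop.
            box.append(asyncio.new_event_loop().run_until_complete(coro))

        thread = threading.Thread(target=target)
        thread.start()
        thread.join()
        return box[0]

    return loop.run_until_complete(coro)


print(run_sync(work()))  # done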
diff --git a/libs/community/langchain_community/tools/ainetwork/owner.py b/libs/community/langchain_community/tools/ainetwork/owner.py
new file mode 100644
index 00000000000..a89134f2a05
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/owner.py
@@ -0,0 +1,115 @@
+import builtins
+import json
+from typing import List, Optional, Type, Union
+
+from langchain_core.callbacks import AsyncCallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.ainetwork.base import AINBaseTool, OperationType
+
+
+class RuleSchema(BaseModel):
+ """Schema for owner operations."""
+
+ type: OperationType = Field(...)
+ path: str = Field(..., description="Blockchain reference path")
+ address: Optional[Union[str, List[str]]] = Field(
+ None, description="A single address or a list of addresses"
+ )
+ write_owner: Optional[bool] = Field(
+ False, description="Authority to edit the `owner` property of the path"
+ )
+ write_rule: Optional[bool] = Field(
+ False, description="Authority to edit `write rule` for the path"
+ )
+ write_function: Optional[bool] = Field(
+ False, description="Authority to `set function` for the path"
+ )
+ branch_owner: Optional[bool] = Field(
+ False, description="Authority to initialize `owner` of sub-paths"
+ )
+
+
+class AINOwnerOps(AINBaseTool):
+ """Tool for owner operations."""
+
+ name: str = "AINownerOps"
+ description: str = """
+Rules for `owner` in AINetwork Blockchain database.
+An address set as `owner` can modify permissions according to its granted authorities
+
+## Path Rule
+- (/[a-zA-Z_0-9]+)+
+- Permission checks ascend from the most specific (child) path to broader (parent) paths until an `owner` is located.
+
+## Address Rules
+- 0x[0-9a-fA-F]{40}: 40-digit hexadecimal address
+- *: All addresses permitted
+- Defaults to the current session's address
+
+## SET
+- `SET` alters permissions for specific addresses, while other addresses remain unaffected.
+- When removing an address from `owner`, set all of that address's authorities to false.
+- The message `write_owner permission evaluated false` is returned if the operation fails.
+
+### Example
+- type: SET
+- path: /apps/langchain
+- address: [<address_1>, <address_2>]
+- write_owner: True
+- write_rule: True
+- write_function: True
+- branch_owner: True
+
+## GET
+- Provides all addresses with `owner` permissions and their authorities in the path.
+
+### Example
+- type: GET
+- path: /apps/langchain
+""" # noqa: E501
+ args_schema: Type[BaseModel] = RuleSchema
+
+ async def _arun(
+ self,
+ type: OperationType,
+ path: str,
+ address: Optional[Union[str, List[str]]] = None,
+ write_owner: Optional[bool] = None,
+ write_rule: Optional[bool] = None,
+ write_function: Optional[bool] = None,
+ branch_owner: Optional[bool] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ from ain.types import ValueOnlyTransactionInput
+
+ try:
+ if type is OperationType.SET:
+ if address is None:
+ address = self.interface.wallet.defaultAccount.address
+ if isinstance(address, str):
+ address = [address]
+ res = await self.interface.db.ref(path).setOwner(
+ transactionInput=ValueOnlyTransactionInput(
+ value={
+ ".owner": {
+ "owners": {
+ address: {
+ "write_owner": write_owner or False,
+ "write_rule": write_rule or False,
+ "write_function": write_function or False,
+ "branch_owner": branch_owner or False,
+ }
+ for address in address
+ }
+ }
+ }
+ )
+ )
+ elif type is OperationType.GET:
+ res = await self.interface.db.ref(path).getOwner()
+ else:
+ raise ValueError(f"Unsupported 'type': {type}.")
+ return json.dumps(res, ensure_ascii=False)
+ except Exception as e:
+ return f"{builtins.type(e).__name__}: {str(e)}"
diff --git a/libs/community/langchain_community/tools/ainetwork/rule.py b/libs/community/langchain_community/tools/ainetwork/rule.py
new file mode 100644
index 00000000000..309010f5ba3
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/rule.py
@@ -0,0 +1,82 @@
+import builtins
+import json
+from typing import Optional, Type
+
+from langchain_core.callbacks import AsyncCallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.ainetwork.base import AINBaseTool, OperationType
+
+
+class RuleSchema(BaseModel):
+ """Schema for owner operations."""
+
+ type: OperationType = Field(...)
+ path: str = Field(..., description="Path on the blockchain where the rule applies")
+ eval: Optional[str] = Field(None, description="eval string to determine permission")
+
+
+class AINRuleOps(AINBaseTool):
+ """Tool for owner operations."""
+
+ name: str = "AINruleOps"
+ description: str = """
+Covers the write `rule` for the AINetwork Blockchain database. The SET type specifies write permissions using the `eval` variable as a JavaScript eval string.
+For an AINvalueOps SET at the path to succeed, the execution result of the `eval` string must be true.
+
+## Path Rules
+1. Allowed characters for directory: `[a-zA-Z_0-9]`
+2. Use `$` for template variables as directory.
+
+## Eval String Special Variables
+- auth.addr: Address of the writer for the path
+- newData: New data for the path
+- data: Current data for the path
+- currentTime: Time in seconds
+- lastBlockNumber: Latest processed block number
+
+## Eval String Functions
+- getValue()
+- getRule()
+- getOwner()
+- getFunction()
+- evalRule(<path>, <value to set>, auth, currentTime)
+- evalOwner(<path>, 'write_owner', auth)
+
+## SET Example
+- type: SET
+- path: /apps/langchain_project_1/$from/$to/$img
+- eval: auth.addr===$from&&!getValue('/apps/image_db/'+$img)
+
+## GET Example
+- type: GET
+- path: /apps/langchain_project_1
+""" # noqa: E501
+ args_schema: Type[BaseModel] = RuleSchema
+
+ async def _arun(
+ self,
+ type: OperationType,
+ path: str,
+ eval: Optional[str] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ from ain.types import ValueOnlyTransactionInput
+
+ try:
+ if type is OperationType.SET:
+ if eval is None:
+ raise ValueError("'eval' is required for SET operation.")
+
+ res = await self.interface.db.ref(path).setRule(
+ transactionInput=ValueOnlyTransactionInput(
+ value={".rule": {"write": eval}}
+ )
+ )
+ elif type is OperationType.GET:
+ res = await self.interface.db.ref(path).getRule()
+ else:
+ raise ValueError(f"Unsupported 'type': {type}.")
+ return json.dumps(res, ensure_ascii=False)
+ except Exception as e:
+ return f"{builtins.type(e).__name__}: {str(e)}"
diff --git a/libs/community/langchain_community/tools/ainetwork/transfer.py b/libs/community/langchain_community/tools/ainetwork/transfer.py
new file mode 100644
index 00000000000..deab34ad9c2
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/transfer.py
@@ -0,0 +1,34 @@
+import json
+from typing import Optional, Type
+
+from langchain_core.callbacks import AsyncCallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.ainetwork.base import AINBaseTool
+
+
+class TransferSchema(BaseModel):
+ """Schema for transfer operations."""
+
+ address: str = Field(..., description="Address to transfer AIN to")
+ amount: int = Field(..., description="Amount of AIN to transfer")
+
+
+class AINTransfer(AINBaseTool):
+ """Tool for transfer operations."""
+
+ name: str = "AINtransfer"
+ description: str = "Transfers AIN to a specified address"
+ args_schema: Type[TransferSchema] = TransferSchema
+
+ async def _arun(
+ self,
+ address: str,
+ amount: int,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ res = await self.interface.wallet.transfer(address, amount, nonce=-1)
+ return json.dumps(res, ensure_ascii=False)
+ except Exception as e:
+ return f"{type(e).__name__}: {str(e)}"
diff --git a/libs/community/langchain_community/tools/ainetwork/utils.py b/libs/community/langchain_community/tools/ainetwork/utils.py
new file mode 100644
index 00000000000..0f8179a6075
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/utils.py
@@ -0,0 +1,62 @@
+"""AINetwork Blockchain tool utils."""
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Literal, Optional
+
+if TYPE_CHECKING:
+ from ain.ain import Ain
+
+
+def authenticate(network: Optional[Literal["mainnet", "testnet"]] = "testnet") -> Ain:
+ """Authenticate using the AIN Blockchain"""
+
+ try:
+ from ain.ain import Ain
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import ain-py related modules. Please install the package with "
+ "`pip install ain-py`."
+ ) from e
+
+ if network == "mainnet":
+ provider_url = "https://mainnet-api.ainetwork.ai/"
+ chain_id = 1
+ if "AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY" in os.environ:
+ private_key = os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"]
+ else:
+ raise EnvironmentError(
+ "Error: The AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY environmental variable "
+ "has not been set."
+ )
+ elif network == "testnet":
+ provider_url = "https://testnet-api.ainetwork.ai/"
+ chain_id = 0
+ if "AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY" in os.environ:
+ private_key = os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"]
+ else:
+ raise EnvironmentError(
+ "Error: The AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY environmental variable "
+ "has not been set."
+ )
+ elif network is None:
+ if (
+ "AIN_BLOCKCHAIN_PROVIDER_URL" in os.environ
+ and "AIN_BLOCKCHAIN_CHAIN_ID" in os.environ
+ and "AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY" in os.environ
+ ):
+ provider_url = os.environ["AIN_BLOCKCHAIN_PROVIDER_URL"]
+ chain_id = int(os.environ["AIN_BLOCKCHAIN_CHAIN_ID"])
+ private_key = os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"]
+ else:
+ raise EnvironmentError(
+ "Error: The AIN_BLOCKCHAIN_PROVIDER_URL and "
+ "AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY and AIN_BLOCKCHAIN_CHAIN_ID "
+ "environmental variable has not been set."
+ )
+ else:
+ raise ValueError(f"Unsupported 'network': {network}")
+
+ ain = Ain(provider_url, chain_id)
+ ain.wallet.addAndSetDefaultAccount(private_key)
+ return ain
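`authenticate` selects the provider URL and chain id from the `network` argument (or, when `network=None`, entirely from the environment) and always needs the account's private key. A minimal setup sketch with placeholder values:

import os

from langchain_community.tools.ainetwork.utils import authenticate

os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"] = "<private-key>"  # placeholder

ain = authenticate(network="testnet")    # https://testnet-api.ainetwork.ai/, chain_id=0
# ain = authenticate(network="mainnet")  # https://mainnet-api.ainetwork.ai/, chain_id=1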
diff --git a/libs/community/langchain_community/tools/ainetwork/value.py b/libs/community/langchain_community/tools/ainetwork/value.py
new file mode 100644
index 00000000000..300e36c573c
--- /dev/null
+++ b/libs/community/langchain_community/tools/ainetwork/value.py
@@ -0,0 +1,85 @@
+import builtins
+import json
+from typing import Optional, Type, Union
+
+from langchain_core.callbacks import AsyncCallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.ainetwork.base import AINBaseTool, OperationType
+
+
+class ValueSchema(BaseModel):
+ """Schema for value operations."""
+
+ type: OperationType = Field(...)
+ path: str = Field(..., description="Blockchain reference path")
+ value: Optional[Union[int, str, float, dict]] = Field(
+ None, description="Value to be set at the path"
+ )
+
+
+class AINValueOps(AINBaseTool):
+ """Tool for value operations."""
+
+ name: str = "AINvalueOps"
+ description: str = """
+Covers the read and write value for the AINetwork Blockchain database.
+
+## SET
+- Set a value at a given path
+
+### Example
+- type: SET
+- path: /apps/langchain_test_1/object
+- value: {1: 2, "34": 56}
+
+## GET
+- Retrieve a value at a given path
+
+### Example
+- type: GET
+- path: /apps/langchain_test_1/DB
+
+## Special paths
+- `/accounts/<address>/balance`: Account balance
+- `/accounts/<address>/nonce`: Account nonce
+- `/apps`: Applications
+- `/consensus`: Consensus
+- `/checkin`: Check-in
+- `/deposit/<service_id>/<address>/<deposit_id>`: Deposit
+- `/deposit_accounts/<service_id>/<address>/<account_id>`: Deposit accounts
+- `/escrow`: Escrow
+- `/payments`: Payment
+- `/sharding`: Sharding
+- `/token/name`: Token name
+- `/token/symbol`: Token symbol
+- `/token/total_supply`: Token total supply
+- `/transfer/<from>/<to>/<key>/value`: Transfer
+- `/withdraw/<service_id>/<address>/<withdraw_id>`: Withdraw
+"""
+ args_schema: Type[BaseModel] = ValueSchema
+
+ async def _arun(
+ self,
+ type: OperationType,
+ path: str,
+ value: Optional[Union[int, str, float, dict]] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ from ain.types import ValueOnlyTransactionInput
+
+ try:
+ if type is OperationType.SET:
+ if value is None:
+ raise ValueError("'value' is required for SET operation.")
+
+ res = await self.interface.db.ref(path).setValue(
+ transactionInput=ValueOnlyTransactionInput(value=value)
+ )
+ elif type is OperationType.GET:
+ res = await self.interface.db.ref(path).getValue()
+ else:
+ raise ValueError(f"Unsupported 'type': {type}.")
+ return json.dumps(res, ensure_ascii=False)
+ except Exception as e:
+ return f"{builtins.type(e).__name__}: {str(e)}"
diff --git a/libs/community/langchain_community/tools/amadeus/__init__.py b/libs/community/langchain_community/tools/amadeus/__init__.py
new file mode 100644
index 00000000000..570958f8098
--- /dev/null
+++ b/libs/community/langchain_community/tools/amadeus/__init__.py
@@ -0,0 +1,9 @@
+"""Amadeus tools."""
+
+from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport
+from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch
+
+__all__ = [
+ "AmadeusClosestAirport",
+ "AmadeusFlightSearch",
+]
diff --git a/libs/community/langchain_community/tools/amadeus/base.py b/libs/community/langchain_community/tools/amadeus/base.py
new file mode 100644
index 00000000000..3bc53bc3f0b
--- /dev/null
+++ b/libs/community/langchain_community/tools/amadeus/base.py
@@ -0,0 +1,18 @@
+"""Base class for Amadeus tools."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.amadeus.utils import authenticate
+
+if TYPE_CHECKING:
+ from amadeus import Client
+
+
+class AmadeusBaseTool(BaseTool):
+ """Base Tool for Amadeus."""
+
+ client: Client = Field(default_factory=authenticate)
diff --git a/libs/community/langchain_community/tools/amadeus/closest_airport.py b/libs/community/langchain_community/tools/amadeus/closest_airport.py
new file mode 100644
index 00000000000..4e8b90a1b2a
--- /dev/null
+++ b/libs/community/langchain_community/tools/amadeus/closest_airport.py
@@ -0,0 +1,50 @@
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.chat_models import ChatOpenAI
+from langchain_community.tools.amadeus.base import AmadeusBaseTool
+
+
+class ClosestAirportSchema(BaseModel):
+ """Schema for the AmadeusClosestAirport tool."""
+
+ location: str = Field(
+ description=(
+ " The location for which you would like to find the nearest airport "
+ " along with optional details such as country, state, region, or "
+ " province, allowing for easy processing and identification of "
+ " the closest airport. Examples of the format are the following:\n"
+ " Cali, Colombia\n "
+ " Lincoln, Nebraska, United States\n"
+ " New York, United States\n"
+ " Sydney, New South Wales, Australia\n"
+ " Rome, Lazio, Italy\n"
+ " Toronto, Ontario, Canada\n"
+ )
+ )
+
+
+class AmadeusClosestAirport(AmadeusBaseTool):
+ """Tool for finding the closest airport to a particular location."""
+
+ name: str = "closest_airport"
+ description: str = (
+ "Use this tool to find the closest airport to a particular location."
+ )
+ args_schema: Type[ClosestAirportSchema] = ClosestAirportSchema
+
+ def _run(
+ self,
+ location: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ content = (
+ f" What is the nearest airport to {location}? Please respond with the "
+ " airport's International Air Transport Association (IATA) Location "
+ ' Identifier in the following JSON format. JSON: "iataCode": "IATA '
+ ' Location Identifier" '
+ )
+
+ return ChatOpenAI(temperature=0).predict(content)
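A sketch for `AmadeusClosestAirport`. The base class authenticates an Amadeus `Client`, and the tool itself delegates the lookup to `ChatOpenAI`, so both `AMADEUS_CLIENT_ID`/`AMADEUS_CLIENT_SECRET` and `OPENAI_API_KEY` are assumed to be set:

from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport

closest = AmadeusClosestAirport()
# Returns the model's answer, e.g. a JSON snippet containing the IATA code.
print(closest.run({"location": "Lincoln, Nebraska, United States"}))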
diff --git a/libs/community/langchain_community/tools/amadeus/flight_search.py b/libs/community/langchain_community/tools/amadeus/flight_search.py
new file mode 100644
index 00000000000..85c173c1198
--- /dev/null
+++ b/libs/community/langchain_community/tools/amadeus/flight_search.py
@@ -0,0 +1,152 @@
+import logging
+from datetime import datetime as dt
+from typing import Dict, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.amadeus.base import AmadeusBaseTool
+
+logger = logging.getLogger(__name__)
+
+
+class FlightSearchSchema(BaseModel):
+ """Schema for the AmadeusFlightSearch tool."""
+
+ originLocationCode: str = Field(
+ description=(
+ " The three letter International Air Transport "
+ " Association (IATA) Location Identifier for the "
+ " search's origin airport. "
+ )
+ )
+ destinationLocationCode: str = Field(
+ description=(
+ " The three letter International Air Transport "
+ " Association (IATA) Location Identifier for the "
+ " search's destination airport. "
+ )
+ )
+ departureDateTimeEarliest: str = Field(
+ description=(
+ " The earliest departure datetime from the origin airport "
+ " for the flight search in the following format: "
+ ' "YYYY-MM-DDTHH:MM", where "T" separates the date and time '
+ ' components. For example: "2023-06-09T10:30:00" represents '
+ " June 9th, 2023, at 10:30 AM. "
+ )
+ )
+ departureDateTimeLatest: str = Field(
+ description=(
+ " The latest departure datetime from the origin airport "
+ " for the flight search in the following format: "
+ ' "YYYY-MM-DDTHH:MM", where "T" separates the date and time '
+ ' components. For example: "2023-06-09T10:30:00" represents '
+ " June 9th, 2023, at 10:30 AM. "
+ )
+ )
+ page_number: int = Field(
+ default=1,
+ description="The specific page number of flight results to retrieve",
+ )
+
+
+class AmadeusFlightSearch(AmadeusBaseTool):
+ """Tool for searching for a single flight between two airports."""
+
+ name: str = "single_flight_search"
+ description: str = (
+ " Use this tool to search for a single flight between the origin and "
+ " destination airports at a departure between an earliest and "
+ " latest datetime. "
+ )
+ args_schema: Type[FlightSearchSchema] = FlightSearchSchema
+
+ def _run(
+ self,
+ originLocationCode: str,
+ destinationLocationCode: str,
+ departureDateTimeEarliest: str,
+ departureDateTimeLatest: str,
+ page_number: int = 1,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> list:
+ try:
+ from amadeus import ResponseError
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import amadeus, please install with `pip install amadeus`."
+ ) from e
+
+ RESULTS_PER_PAGE = 10
+
+ # Authenticate and retrieve a client
+ client = self.client
+
+ # Check that earliest and latest dates are in the same day
+ earliestDeparture = dt.strptime(departureDateTimeEarliest, "%Y-%m-%dT%H:%M:%S")
+ latestDeparture = dt.strptime(departureDateTimeLatest, "%Y-%m-%dT%H:%M:%S")
+
+ if earliestDeparture.date() != latestDeparture.date():
+ logger.error(
+ " Error: Earliest and latest departure dates need to be the "
+ " same date. If you're trying to search for round-trip "
+ " flights, call this function for the outbound flight first, "
+ " and then call again for the return flight. "
+ )
+ return [None]
+
+ # Collect all results from the Amadeus Flight Offers Search API
+ try:
+ response = client.shopping.flight_offers_search.get(
+ originLocationCode=originLocationCode,
+ destinationLocationCode=destinationLocationCode,
+ departureDate=latestDeparture.strftime("%Y-%m-%d"),
+ adults=1,
+ )
+ except ResponseError as error:
+ logger.error(error)
+ return []
+
+ # Generate output dictionary
+ output = []
+
+ for offer in response.data:
+ itinerary: Dict = {}
+ itinerary["price"] = {}
+ itinerary["price"]["total"] = offer["price"]["total"]
+ currency = offer["price"]["currency"]
+ currency = response.result["dictionaries"]["currencies"][currency]
+ itinerary["price"]["currency"] = {}
+ itinerary["price"]["currency"] = currency
+
+ segments = []
+ for segment in offer["itineraries"][0]["segments"]:
+ flight = {}
+ flight["departure"] = segment["departure"]
+ flight["arrival"] = segment["arrival"]
+ flight["flightNumber"] = segment["number"]
+ carrier = segment["carrierCode"]
+ carrier = response.result["dictionaries"]["carriers"][carrier]
+ flight["carrier"] = carrier
+
+ segments.append(flight)
+
+ itinerary["segments"] = []
+ itinerary["segments"] = segments
+
+ output.append(itinerary)
+
+ # Filter out flights departing after the latest allowed departure time.
+ # A list comprehension avoids mutating the list while iterating over it,
+ # which would silently skip entries.
+ output = [
+ offer
+ for offer in output
+ if dt.strptime(offer["segments"][0]["departure"]["at"], "%Y-%m-%dT%H:%M:%S")
+ <= latestDeparture
+ ]
+
+ # Return the paginated results
+ startIndex = (page_number - 1) * RESULTS_PER_PAGE
+ endIndex = startIndex + RESULTS_PER_PAGE
+
+ return output[startIndex:endIndex]
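A sketch for `AmadeusFlightSearch`, assuming the `amadeus` package and the Amadeus credentials above. Both departure datetimes must fall on the same calendar day and include seconds, matching the `strptime` format used in `_run`:

from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch

search = AmadeusFlightSearch()
offers = search.run(
    {
        "originLocationCode": "JFK",
        "destinationLocationCode": "LAX",
        "departureDateTimeEarliest": "2023-06-09T08:00:00",
        "departureDateTimeLatest": "2023-06-09T20:00:00",
        "page_number": 1,
    }
)
print(offers[:1])  # first itinerary of the first results page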
diff --git a/libs/community/langchain_community/tools/amadeus/utils.py b/libs/community/langchain_community/tools/amadeus/utils.py
new file mode 100644
index 00000000000..7c04ec0528c
--- /dev/null
+++ b/libs/community/langchain_community/tools/amadeus/utils.py
@@ -0,0 +1,42 @@
+"""O365 tool utils."""
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from amadeus import Client
+
+logger = logging.getLogger(__name__)
+
+
+def authenticate() -> Client:
+ """Authenticate using the Amadeus API"""
+ try:
+ from amadeus import Client
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import amadeus. Please install the package with "
+ "`pip install amadeus`."
+ ) from e
+
+ if "AMADEUS_CLIENT_ID" in os.environ and "AMADEUS_CLIENT_SECRET" in os.environ:
+ client_id = os.environ["AMADEUS_CLIENT_ID"]
+ client_secret = os.environ["AMADEUS_CLIENT_SECRET"]
+ else:
+ logger.error(
+ "Error: The AMADEUS_CLIENT_ID and AMADEUS_CLIENT_SECRET environmental "
+ "variables have not been set. Visit the following link on how to "
+ "acquire these authorization tokens: "
+ "https://developers.amadeus.com/register"
+ )
+ return None
+
+ hostname = "test" # Default hostname
+ if "AMADEUS_HOSTNAME" in os.environ:
+ hostname = os.environ["AMADEUS_HOSTNAME"]
+
+ client = Client(client_id=client_id, client_secret=client_secret, hostname=hostname)
+
+ return client
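A setup sketch for the Amadeus client; the credentials are placeholders obtained from https://developers.amadeus.com/register, and the hostname defaults to the `test` environment unless overridden:

import os

from langchain_community.tools.amadeus.utils import authenticate

os.environ["AMADEUS_CLIENT_ID"] = "<client-id>"          # placeholder
os.environ["AMADEUS_CLIENT_SECRET"] = "<client-secret>"  # placeholder
# os.environ["AMADEUS_HOSTNAME"] = "production"          # optional; defaults to "test"

client = authenticate()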
diff --git a/libs/community/langchain_community/tools/arxiv/__init__.py b/libs/community/langchain_community/tools/arxiv/__init__.py
new file mode 100644
index 00000000000..2607cb19bb7
--- /dev/null
+++ b/libs/community/langchain_community/tools/arxiv/__init__.py
@@ -0,0 +1 @@
+"""Arxiv API toolkit."""
diff --git a/libs/community/langchain_community/tools/arxiv/tool.py b/libs/community/langchain_community/tools/arxiv/tool.py
new file mode 100644
index 00000000000..023cd8939c3
--- /dev/null
+++ b/libs/community/langchain_community/tools/arxiv/tool.py
@@ -0,0 +1,37 @@
+"""Tool for the Arxiv API."""
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.arxiv import ArxivAPIWrapper
+
+
+class ArxivInput(BaseModel):
+ query: str = Field(description="search query to look up")
+
+
+class ArxivQueryRun(BaseTool):
+ """Tool that searches the Arxiv API."""
+
+ name: str = "arxiv"
+ description: str = (
+ "A wrapper around Arxiv.org "
+ "Useful for when you need to answer questions about Physics, Mathematics, "
+ "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
+ "Electrical Engineering, and Economics "
+ "from scientific articles on arxiv.org. "
+ "Input should be a search query."
+ )
+ api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
+ args_schema: Type[BaseModel] = ArxivInput
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Arxiv tool."""
+ return self.api_wrapper.run(query)
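A sketch for `ArxivQueryRun`; the default `ArxivAPIWrapper` needs the optional `arxiv` package (`pip install arxiv`):

from langchain_community.tools.arxiv.tool import ArxivQueryRun

arxiv_tool = ArxivQueryRun()
print(arxiv_tool.run("attention is all you need"))  # summaries of matching arxiv papers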
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py b/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py
new file mode 100644
index 00000000000..1121e4e89d1
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py
@@ -0,0 +1,25 @@
+"""Azure Cognitive Services Tools."""
+
+from langchain_community.tools.azure_cognitive_services.form_recognizer import (
+ AzureCogsFormRecognizerTool,
+)
+from langchain_community.tools.azure_cognitive_services.image_analysis import (
+ AzureCogsImageAnalysisTool,
+)
+from langchain_community.tools.azure_cognitive_services.speech2text import (
+ AzureCogsSpeech2TextTool,
+)
+from langchain_community.tools.azure_cognitive_services.text2speech import (
+ AzureCogsText2SpeechTool,
+)
+from langchain_community.tools.azure_cognitive_services.text_analytics_health import (
+ AzureCogsTextAnalyticsHealthTool,
+)
+
+__all__ = [
+ "AzureCogsImageAnalysisTool",
+ "AzureCogsFormRecognizerTool",
+ "AzureCogsSpeech2TextTool",
+ "AzureCogsText2SpeechTool",
+ "AzureCogsTextAnalyticsHealthTool",
+]
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/form_recognizer.py b/libs/community/langchain_community/tools/azure_cognitive_services/form_recognizer.py
new file mode 100644
index 00000000000..42d11b4ac4e
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/form_recognizer.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.tools.azure_cognitive_services.utils import (
+ detect_file_src_type,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AzureCogsFormRecognizerTool(BaseTool):
+ """Tool that queries the Azure Cognitive Services Form Recognizer API.
+
+ In order to set this up, follow instructions at:
+ https://learn.microsoft.com/en-us/azure/applied-ai-services/form-recognizer/quickstarts/get-started-sdks-rest-api?view=form-recog-3.0.0&pivots=programming-language-python
+ """
+
+ azure_cogs_key: str = "" #: :meta private:
+ azure_cogs_endpoint: str = "" #: :meta private:
+ doc_analysis_client: Any #: :meta private:
+
+ name: str = "azure_cognitive_services_form_recognizer"
+ description: str = (
+ "A wrapper around Azure Cognitive Services Form Recognizer. "
+ "Useful for when you need to "
+ "extract text, tables, and key-value pairs from documents. "
+ "Input should be a url to a document."
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ azure_cogs_key = get_from_dict_or_env(
+ values, "azure_cogs_key", "AZURE_COGS_KEY"
+ )
+
+ azure_cogs_endpoint = get_from_dict_or_env(
+ values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
+ )
+
+ try:
+ from azure.ai.formrecognizer import DocumentAnalysisClient
+ from azure.core.credentials import AzureKeyCredential
+
+ values["doc_analysis_client"] = DocumentAnalysisClient(
+ endpoint=azure_cogs_endpoint,
+ credential=AzureKeyCredential(azure_cogs_key),
+ )
+
+ except ImportError:
+ raise ImportError(
+ "azure-ai-formrecognizer is not installed. "
+ "Run `pip install azure-ai-formrecognizer` to install."
+ )
+
+ return values
+
+ def _parse_tables(self, tables: List[Any]) -> List[Any]:
+ result = []
+ for table in tables:
+ rc, cc = table.row_count, table.column_count
+ _table = [["" for _ in range(cc)] for _ in range(rc)]
+ for cell in table.cells:
+ _table[cell.row_index][cell.column_index] = cell.content
+ result.append(_table)
+ return result
+
+ def _parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Any]:
+ result = []
+ for kv_pair in kv_pairs:
+ key = kv_pair.key.content if kv_pair.key else ""
+ value = kv_pair.value.content if kv_pair.value else ""
+ result.append((key, value))
+ return result
+
+ def _document_analysis(self, document_path: str) -> Dict:
+ document_src_type = detect_file_src_type(document_path)
+ if document_src_type == "local":
+ with open(document_path, "rb") as document:
+ poller = self.doc_analysis_client.begin_analyze_document(
+ "prebuilt-document", document
+ )
+ elif document_src_type == "remote":
+ poller = self.doc_analysis_client.begin_analyze_document_from_url(
+ "prebuilt-document", document_path
+ )
+ else:
+ raise ValueError(f"Invalid document path: {document_path}")
+
+ result = poller.result()
+ res_dict = {}
+
+ if result.content is not None:
+ res_dict["content"] = result.content
+
+ if result.tables is not None:
+ res_dict["tables"] = self._parse_tables(result.tables)
+
+ if result.key_value_pairs is not None:
+ res_dict["key_value_pairs"] = self._parse_kv_pairs(result.key_value_pairs)
+
+ return res_dict
+
+ def _format_document_analysis_result(self, document_analysis_result: Dict) -> str:
+ formatted_result = []
+ if "content" in document_analysis_result:
+ formatted_result.append(
+ f"Content: {document_analysis_result['content']}".replace("\n", " ")
+ )
+
+ if "tables" in document_analysis_result:
+ for i, table in enumerate(document_analysis_result["tables"]):
+ formatted_result.append(f"Table {i}: {table}".replace("\n", " "))
+
+ if "key_value_pairs" in document_analysis_result:
+ for kv_pair in document_analysis_result["key_value_pairs"]:
+ formatted_result.append(
+ f"{kv_pair[0]}: {kv_pair[1]}".replace("\n", " ")
+ )
+
+ return "\n".join(formatted_result)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ try:
+ document_analysis_result = self._document_analysis(query)
+ if not document_analysis_result:
+ return "No good document analysis result was found"
+
+ return self._format_document_analysis_result(document_analysis_result)
+ except Exception as e:
+ raise RuntimeError(f"Error while running AzureCogsFormRecognizerTool: {e}")
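+
+
+# Example (illustrative sketch): assuming AZURE_COGS_KEY and AZURE_COGS_ENDPOINT
+# are set and `azure-ai-formrecognizer` is installed, the tool can be invoked
+# with a document URL or a local file path. The URL below is a placeholder.
+if __name__ == "__main__":
+    form_tool = AzureCogsFormRecognizerTool()
+    print(form_tool.run("https://example.com/sample-invoice.pdf"))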
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/image_analysis.py b/libs/community/langchain_community/tools/azure_cognitive_services/image_analysis.py
new file mode 100644
index 00000000000..801aa57bd2d
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/image_analysis.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.tools.azure_cognitive_services.utils import (
+ detect_file_src_type,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AzureCogsImageAnalysisTool(BaseTool):
+ """Tool that queries the Azure Cognitive Services Image Analysis API.
+
+ In order to set this up, follow instructions at:
+ https://learn.microsoft.com/en-us/azure/cognitive-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40
+ """
+
+ azure_cogs_key: str = "" #: :meta private:
+ azure_cogs_endpoint: str = "" #: :meta private:
+ vision_service: Any #: :meta private:
+ analysis_options: Any #: :meta private:
+
+ name: str = "azure_cognitive_services_image_analysis"
+ description: str = (
+ "A wrapper around Azure Cognitive Services Image Analysis. "
+ "Useful for when you need to analyze images. "
+ "Input should be a url to an image."
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ azure_cogs_key = get_from_dict_or_env(
+ values, "azure_cogs_key", "AZURE_COGS_KEY"
+ )
+
+ azure_cogs_endpoint = get_from_dict_or_env(
+ values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
+ )
+
+ try:
+ import azure.ai.vision as sdk
+
+ values["vision_service"] = sdk.VisionServiceOptions(
+ endpoint=azure_cogs_endpoint, key=azure_cogs_key
+ )
+
+ values["analysis_options"] = sdk.ImageAnalysisOptions()
+ values["analysis_options"].features = (
+ sdk.ImageAnalysisFeature.CAPTION
+ | sdk.ImageAnalysisFeature.OBJECTS
+ | sdk.ImageAnalysisFeature.TAGS
+ | sdk.ImageAnalysisFeature.TEXT
+ )
+ except ImportError:
+ raise ImportError(
+ "azure-ai-vision is not installed. "
+ "Run `pip install azure-ai-vision` to install."
+ )
+
+ return values
+
+ def _image_analysis(self, image_path: str) -> Dict:
+ try:
+ import azure.ai.vision as sdk
+ except ImportError:
+            raise ImportError(
+                "azure-ai-vision is not installed. "
+                "Run `pip install azure-ai-vision` to install."
+            )
+
+ image_src_type = detect_file_src_type(image_path)
+ if image_src_type == "local":
+ vision_source = sdk.VisionSource(filename=image_path)
+ elif image_src_type == "remote":
+ vision_source = sdk.VisionSource(url=image_path)
+ else:
+ raise ValueError(f"Invalid image path: {image_path}")
+
+ image_analyzer = sdk.ImageAnalyzer(
+ self.vision_service, vision_source, self.analysis_options
+ )
+ result = image_analyzer.analyze()
+
+ res_dict = {}
+ if result.reason == sdk.ImageAnalysisResultReason.ANALYZED:
+ if result.caption is not None:
+ res_dict["caption"] = result.caption.content
+
+ if result.objects is not None:
+ res_dict["objects"] = [obj.name for obj in result.objects]
+
+ if result.tags is not None:
+ res_dict["tags"] = [tag.name for tag in result.tags]
+
+ if result.text is not None:
+ res_dict["text"] = [line.content for line in result.text.lines]
+
+ else:
+ error_details = sdk.ImageAnalysisErrorDetails.from_result(result)
+ raise RuntimeError(
+ f"Image analysis failed.\n"
+ f"Reason: {error_details.reason}\n"
+ f"Details: {error_details.message}"
+ )
+
+ return res_dict
+
+ def _format_image_analysis_result(self, image_analysis_result: Dict) -> str:
+ formatted_result = []
+ if "caption" in image_analysis_result:
+ formatted_result.append("Caption: " + image_analysis_result["caption"])
+
+ if (
+ "objects" in image_analysis_result
+ and len(image_analysis_result["objects"]) > 0
+ ):
+ formatted_result.append(
+ "Objects: " + ", ".join(image_analysis_result["objects"])
+ )
+
+ if "tags" in image_analysis_result and len(image_analysis_result["tags"]) > 0:
+ formatted_result.append("Tags: " + ", ".join(image_analysis_result["tags"]))
+
+ if "text" in image_analysis_result and len(image_analysis_result["text"]) > 0:
+ formatted_result.append("Text: " + ", ".join(image_analysis_result["text"]))
+
+ return "\n".join(formatted_result)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ try:
+ image_analysis_result = self._image_analysis(query)
+ if not image_analysis_result:
+ return "No good image analysis result was found"
+
+ return self._format_image_analysis_result(image_analysis_result)
+ except Exception as e:
+ raise RuntimeError(f"Error while running AzureCogsImageAnalysisTool: {e}")
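+
+
+# Example (illustrative sketch): assuming AZURE_COGS_KEY and AZURE_COGS_ENDPOINT
+# are set and `azure-ai-vision` is installed, the tool accepts an image URL or a
+# local image path. The URL below is a placeholder.
+if __name__ == "__main__":
+    image_tool = AzureCogsImageAnalysisTool()
+    print(image_tool.run("https://example.com/sample-photo.jpg"))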
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/speech2text.py b/libs/community/langchain_community/tools/azure_cognitive_services/speech2text.py
new file mode 100644
index 00000000000..b1aa90cb701
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/speech2text.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import logging
+import time
+from typing import Any, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.tools.azure_cognitive_services.utils import (
+ detect_file_src_type,
+ download_audio_from_url,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AzureCogsSpeech2TextTool(BaseTool):
+ """Tool that queries the Azure Cognitive Services Speech2Text API.
+
+ In order to set this up, follow instructions at:
+ https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/get-started-speech-to-text?pivots=programming-language-python
+ """
+
+ azure_cogs_key: str = "" #: :meta private:
+ azure_cogs_region: str = "" #: :meta private:
+ speech_language: str = "en-US" #: :meta private:
+ speech_config: Any #: :meta private:
+
+ name: str = "azure_cognitive_services_speech2text"
+ description: str = (
+ "A wrapper around Azure Cognitive Services Speech2Text. "
+ "Useful for when you need to transcribe audio to text. "
+ "Input should be a url to an audio file."
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ azure_cogs_key = get_from_dict_or_env(
+ values, "azure_cogs_key", "AZURE_COGS_KEY"
+ )
+
+ azure_cogs_region = get_from_dict_or_env(
+ values, "azure_cogs_region", "AZURE_COGS_REGION"
+ )
+
+ try:
+ import azure.cognitiveservices.speech as speechsdk
+
+ values["speech_config"] = speechsdk.SpeechConfig(
+ subscription=azure_cogs_key, region=azure_cogs_region
+ )
+ except ImportError:
+ raise ImportError(
+ "azure-cognitiveservices-speech is not installed. "
+ "Run `pip install azure-cognitiveservices-speech` to install."
+ )
+
+ return values
+
+ def _continuous_recognize(self, speech_recognizer: Any) -> str:
+ done = False
+ text = ""
+
+ def stop_cb(evt: Any) -> None:
+            """callback that stops continuous recognition"""
+ speech_recognizer.stop_continuous_recognition_async()
+ nonlocal done
+ done = True
+
+ def retrieve_cb(evt: Any) -> None:
+ """callback that retrieves the intermediate recognition results"""
+ nonlocal text
+ text += evt.result.text
+
+ # retrieve text on recognized events
+ speech_recognizer.recognized.connect(retrieve_cb)
+ # stop continuous recognition on either session stopped or canceled events
+ speech_recognizer.session_stopped.connect(stop_cb)
+ speech_recognizer.canceled.connect(stop_cb)
+
+ # Start continuous speech recognition
+ speech_recognizer.start_continuous_recognition_async()
+ while not done:
+ time.sleep(0.5)
+ return text
+
+ def _speech2text(self, audio_path: str, speech_language: str) -> str:
+ try:
+ import azure.cognitiveservices.speech as speechsdk
+ except ImportError:
+            raise ImportError(
+                "azure-cognitiveservices-speech is not installed. "
+                "Run `pip install azure-cognitiveservices-speech` to install."
+            )
+
+ audio_src_type = detect_file_src_type(audio_path)
+ if audio_src_type == "local":
+ audio_config = speechsdk.AudioConfig(filename=audio_path)
+ elif audio_src_type == "remote":
+ tmp_audio_path = download_audio_from_url(audio_path)
+ audio_config = speechsdk.AudioConfig(filename=tmp_audio_path)
+ else:
+ raise ValueError(f"Invalid audio path: {audio_path}")
+
+ self.speech_config.speech_recognition_language = speech_language
+ speech_recognizer = speechsdk.SpeechRecognizer(self.speech_config, audio_config)
+ return self._continuous_recognize(speech_recognizer)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ try:
+ text = self._speech2text(query, self.speech_language)
+ return text
+ except Exception as e:
+ raise RuntimeError(f"Error while running AzureCogsSpeech2TextTool: {e}")
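+
+
+# Example (illustrative sketch): assuming AZURE_COGS_KEY and AZURE_COGS_REGION
+# are set and `azure-cognitiveservices-speech` is installed, the tool transcribes
+# a local or remote audio file. The URL below is a placeholder.
+if __name__ == "__main__":
+    speech_tool = AzureCogsSpeech2TextTool()
+    print(speech_tool.run("https://example.com/sample-speech.wav"))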
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/text2speech.py b/libs/community/langchain_community/tools/azure_cognitive_services/text2speech.py
new file mode 100644
index 00000000000..f65049f13a9
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/text2speech.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import logging
+import tempfile
+from typing import Any, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class AzureCogsText2SpeechTool(BaseTool):
+ """Tool that queries the Azure Cognitive Services Text2Speech API.
+
+ In order to set this up, follow instructions at:
+ https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/get-started-text-to-speech?pivots=programming-language-python
+ """
+
+ azure_cogs_key: str = "" #: :meta private:
+ azure_cogs_region: str = "" #: :meta private:
+ speech_language: str = "en-US" #: :meta private:
+ speech_config: Any #: :meta private:
+
+ name: str = "azure_cognitive_services_text2speech"
+ description: str = (
+ "A wrapper around Azure Cognitive Services Text2Speech. "
+ "Useful for when you need to convert text to speech. "
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ azure_cogs_key = get_from_dict_or_env(
+ values, "azure_cogs_key", "AZURE_COGS_KEY"
+ )
+
+ azure_cogs_region = get_from_dict_or_env(
+ values, "azure_cogs_region", "AZURE_COGS_REGION"
+ )
+
+ try:
+ import azure.cognitiveservices.speech as speechsdk
+
+ values["speech_config"] = speechsdk.SpeechConfig(
+ subscription=azure_cogs_key, region=azure_cogs_region
+ )
+ except ImportError:
+ raise ImportError(
+ "azure-cognitiveservices-speech is not installed. "
+ "Run `pip install azure-cognitiveservices-speech` to install."
+ )
+
+ return values
+
+ def _text2speech(self, text: str, speech_language: str) -> str:
+ try:
+ import azure.cognitiveservices.speech as speechsdk
+ except ImportError:
+            raise ImportError(
+                "azure-cognitiveservices-speech is not installed. "
+                "Run `pip install azure-cognitiveservices-speech` to install."
+            )
+
+ self.speech_config.speech_synthesis_language = speech_language
+ speech_synthesizer = speechsdk.SpeechSynthesizer(
+ speech_config=self.speech_config, audio_config=None
+ )
+ result = speech_synthesizer.speak_text(text)
+
+ if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
+ stream = speechsdk.AudioDataStream(result)
+ with tempfile.NamedTemporaryFile(
+ mode="wb", suffix=".wav", delete=False
+ ) as f:
+ stream.save_to_wav_file(f.name)
+
+ return f.name
+
+ elif result.reason == speechsdk.ResultReason.Canceled:
+ cancellation_details = result.cancellation_details
+ logger.debug(f"Speech synthesis canceled: {cancellation_details.reason}")
+ if cancellation_details.reason == speechsdk.CancellationReason.Error:
+ raise RuntimeError(
+ f"Speech synthesis error: {cancellation_details.error_details}"
+ )
+
+ return "Speech synthesis canceled."
+
+ else:
+ return f"Speech synthesis failed: {result.reason}"
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ try:
+ speech_file = self._text2speech(query, self.speech_language)
+ return speech_file
+ except Exception as e:
+ raise RuntimeError(f"Error while running AzureCogsText2SpeechTool: {e}")
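+
+
+# Example (illustrative sketch): assuming AZURE_COGS_KEY and AZURE_COGS_REGION
+# are set and `azure-cognitiveservices-speech` is installed, the tool returns the
+# path of a temporary .wav file containing the synthesized speech.
+if __name__ == "__main__":
+    tts_tool = AzureCogsText2SpeechTool()
+    print(tts_tool.run("Hello from LangChain"))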
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/text_analytics_health.py b/libs/community/langchain_community/tools/azure_cognitive_services/text_analytics_health.py
new file mode 100644
index 00000000000..00e97dfc55a
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/text_analytics_health.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class AzureCogsTextAnalyticsHealthTool(BaseTool):
+ """Tool that queries the Azure Cognitive Services Text Analytics for Health API.
+
+ In order to set this up, follow instructions at:
+ https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/quickstart?tabs=windows&pivots=programming-language-python
+ """
+
+ azure_cogs_key: str = "" #: :meta private:
+ azure_cogs_endpoint: str = "" #: :meta private:
+ text_analytics_client: Any #: :meta private:
+
+    name: str = "azure_cognitive_services_text_analytics_health"
+ description: str = (
+ "A wrapper around Azure Cognitive Services Text Analytics for Health. "
+ "Useful for when you need to identify entities in healthcare data. "
+ "Input should be text."
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ azure_cogs_key = get_from_dict_or_env(
+ values, "azure_cogs_key", "AZURE_COGS_KEY"
+ )
+
+ azure_cogs_endpoint = get_from_dict_or_env(
+ values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
+ )
+
+ try:
+ import azure.ai.textanalytics as sdk
+ from azure.core.credentials import AzureKeyCredential
+
+ values["text_analytics_client"] = sdk.TextAnalyticsClient(
+ endpoint=azure_cogs_endpoint,
+ credential=AzureKeyCredential(azure_cogs_key),
+ )
+
+ except ImportError:
+ raise ImportError(
+ "azure-ai-textanalytics is not installed. "
+ "Run `pip install azure-ai-textanalytics` to install."
+ )
+
+ return values
+
+ def _text_analysis(self, text: str) -> Dict:
+ poller = self.text_analytics_client.begin_analyze_healthcare_entities(
+ [{"id": "1", "language": "en", "text": text}]
+ )
+
+ result = poller.result()
+
+ res_dict = {}
+
+ docs = [doc for doc in result if not doc.is_error]
+
+        if docs:
+ res_dict["entities"] = [
+ f"{x.text} is a healthcare entity of type {x.category}"
+ for y in docs
+ for x in y.entities
+ ]
+
+ return res_dict
+
+ def _format_text_analysis_result(self, text_analysis_result: Dict) -> str:
+ formatted_result = []
+ if "entities" in text_analysis_result:
+ formatted_result.append(
+ f"""The text contains the following healthcare entities: {
+ ', '.join(text_analysis_result['entities'])
+ }""".replace("\n", " ")
+ )
+
+ return "\n".join(formatted_result)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ try:
+ text_analysis_result = self._text_analysis(query)
+
+ return self._format_text_analysis_result(text_analysis_result)
+ except Exception as e:
+ raise RuntimeError(
+ f"Error while running AzureCogsTextAnalyticsHealthTool: {e}"
+ )
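+
+
+# Example (illustrative sketch): assuming AZURE_COGS_KEY and AZURE_COGS_ENDPOINT
+# are set and `azure-ai-textanalytics` is installed, the tool extracts healthcare
+# entities from free text.
+if __name__ == "__main__":
+    health_tool = AzureCogsTextAnalyticsHealthTool()
+    print(health_tool.run("The patient was prescribed 100mg of ibuprofen."))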
diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/utils.py b/libs/community/langchain_community/tools/azure_cognitive_services/utils.py
new file mode 100644
index 00000000000..9de8f923b72
--- /dev/null
+++ b/libs/community/langchain_community/tools/azure_cognitive_services/utils.py
@@ -0,0 +1,29 @@
+import os
+import tempfile
+from urllib.parse import urlparse
+
+import requests
+
+
+def detect_file_src_type(file_path: str) -> str:
+ """Detect if the file is local or remote."""
+ if os.path.isfile(file_path):
+ return "local"
+
+ parsed_url = urlparse(file_path)
+ if parsed_url.scheme and parsed_url.netloc:
+ return "remote"
+
+ return "invalid"
+
+
+def download_audio_from_url(audio_url: str) -> str:
+ """Download audio from url to local."""
+ ext = audio_url.split(".")[-1]
+ response = requests.get(audio_url, stream=True)
+ response.raise_for_status()
+ with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{ext}", delete=False) as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ return f.name
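+
+
+# Quick illustration of the helpers above (the URL and path are placeholders):
+# an existing local file yields "local", a well-formed URL yields "remote", and
+# anything else yields "invalid".
+if __name__ == "__main__":
+    print(detect_file_src_type("https://example.com/audio.wav"))  # "remote"
+    print(detect_file_src_type("/path/that/does/not/exist.wav"))  # "invalid"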
diff --git a/libs/langchain/tests/integration_tests/retrievers/__init__.py b/libs/community/langchain_community/tools/bearly/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/retrievers/__init__.py
rename to libs/community/langchain_community/tools/bearly/__init__.py
diff --git a/libs/community/langchain_community/tools/bearly/tool.py b/libs/community/langchain_community/tools/bearly/tool.py
new file mode 100644
index 00000000000..8f4e46c0faa
--- /dev/null
+++ b/libs/community/langchain_community/tools/bearly/tool.py
@@ -0,0 +1,162 @@
+import base64
+import itertools
+import json
+import re
+from pathlib import Path
+from typing import Dict, List, Type
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools import Tool
+
+
+def strip_markdown_code(md_string: str) -> str:
+ """Strip markdown code from a string."""
+ stripped_string = re.sub(r"^`{1,3}.*?\n", "", md_string, flags=re.DOTALL)
+ stripped_string = re.sub(r"`{1,3}$", "", stripped_string)
+ return stripped_string
+
+
+def head_file(path: str, n: int) -> List[str]:
+ """Get the first n lines of a file."""
+ try:
+ with open(path, "r") as f:
+ return [str(line) for line in itertools.islice(f, n)]
+ except Exception:
+ return []
+
+
+def file_to_base64(path: str) -> str:
+ """Convert a file to base64."""
+ with open(path, "rb") as f:
+ return base64.b64encode(f.read()).decode()
+
+
+class BearlyInterpreterToolArguments(BaseModel):
+ """Arguments for the BearlyInterpreterTool."""
+
+ python_code: str = Field(
+ ...,
+ example="print('Hello World')",
+ description=(
+ "The pure python script to be evaluated. "
+ "The contents will be in main.py. "
+ "It should not be in markdown format."
+ ),
+ )
+
+
+base_description = """Evaluates python code in a sandbox environment. \
+The environment resets on every execution. \
+You must send the whole script every time and print your outputs. \
+Script should be pure python code that can be evaluated. \
+It should be in python format NOT markdown. \
+The code should NOT be wrapped in backticks. \
+All python packages including requests, matplotlib, scipy, numpy, pandas, \
+etc are available. \
+If you output any files, write them to "output/" relative to the execution \
+path. Output can only be read from that directory, stdout, and stderr. \
+Do not use things like plot.show() as it will not work; \
+instead write files to `output/` and a link to each file will be returned. \
+print() any output and results so you can capture the output."""
+
+
+class FileInfo(BaseModel):
+ """Information about a file to be uploaded."""
+
+ source_path: str
+ description: str
+ target_path: str
+
+
+class BearlyInterpreterTool:
+ """Tool for evaluating python code in a sandbox environment."""
+
+ api_key: str
+ endpoint = "https://exec.bearly.ai/v1/interpreter"
+ name = "bearly_interpreter"
+ args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
+ files: Dict[str, FileInfo] = {}
+
+    def __init__(self, api_key: str):
+        self.api_key = api_key
+        # Use a per-instance dict so files added to one interpreter instance are
+        # not shared with other instances through the mutable class attribute.
+        self.files = {}
+
+ @property
+ def file_description(self) -> str:
+ if len(self.files) == 0:
+ return ""
+        lines = ["The following files are available in the evaluation environment:"]
+ for target_path, file_info in self.files.items():
+ peek_content = head_file(file_info.source_path, 4)
+ lines.append(
+ f"- path: `{target_path}` \n first four lines: {peek_content}"
+ f" \n description: `{file_info.description}`"
+ )
+ return "\n".join(lines)
+
+ @property
+ def description(self) -> str:
+ return (base_description + "\n\n" + self.file_description).strip()
+
+ def make_input_files(self) -> List[dict]:
+ files = []
+ for target_path, file_info in self.files.items():
+ files.append(
+ {
+ "pathname": target_path,
+ "contentsBasesixtyfour": file_to_base64(file_info.source_path),
+ }
+ )
+ return files
+
+ def _run(self, python_code: str) -> dict:
+ script = strip_markdown_code(python_code)
+ resp = requests.post(
+ "https://exec.bearly.ai/v1/interpreter",
+ data=json.dumps(
+ {
+ "fileContents": script,
+ "inputFiles": self.make_input_files(),
+ "outputDir": "output/",
+ "outputAsLinks": True,
+ }
+ ),
+ headers={"Authorization": self.api_key},
+ ).json()
+ return {
+ "stdout": base64.b64decode(resp["stdoutBasesixtyfour"]).decode()
+ if resp["stdoutBasesixtyfour"]
+ else "",
+ "stderr": base64.b64decode(resp["stderrBasesixtyfour"]).decode()
+ if resp["stderrBasesixtyfour"]
+ else "",
+ "fileLinks": resp["fileLinks"],
+ "exitCode": resp["exitCode"],
+ }
+
+ async def _arun(self, query: str) -> str:
+ """Use the tool asynchronously."""
+        raise NotImplementedError("BearlyInterpreterTool does not support async")
+
+ def add_file(self, source_path: str, target_path: str, description: str) -> None:
+ if target_path in self.files:
+ raise ValueError("target_path already exists")
+ if not Path(source_path).exists():
+ raise ValueError("source_path does not exist")
+ self.files[target_path] = FileInfo(
+ target_path=target_path, source_path=source_path, description=description
+ )
+
+ def clear_files(self) -> None:
+ self.files = {}
+
+ # TODO: this is because we can't have a dynamic description
+ # because of the base pydantic class
+ def as_tool(self) -> Tool:
+ return Tool.from_function(
+ func=self._run,
+ name=self.name,
+ description=self.description,
+ args_schema=self.args_schema,
+ )
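+
+
+# Example (illustrative sketch): assuming a valid Bearly API key is exported as
+# BEARLY_API_KEY, the interpreter can be exposed as a regular Tool and invoked
+# with a plain Python script.
+if __name__ == "__main__":
+    import os
+
+    sandbox = BearlyInterpreterTool(api_key=os.environ["BEARLY_API_KEY"]).as_tool()
+    print(sandbox.run("print('hello world')"))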
diff --git a/libs/community/langchain_community/tools/bing_search/__init__.py b/libs/community/langchain_community/tools/bing_search/__init__.py
new file mode 100644
index 00000000000..b5e133a05a0
--- /dev/null
+++ b/libs/community/langchain_community/tools/bing_search/__init__.py
@@ -0,0 +1,5 @@
+"""Bing Search API toolkit."""
+
+from langchain_community.tools.bing_search.tool import BingSearchResults, BingSearchRun
+
+__all__ = ["BingSearchRun", "BingSearchResults"]
diff --git a/libs/community/langchain_community/tools/bing_search/tool.py b/libs/community/langchain_community/tools/bing_search/tool.py
new file mode 100644
index 00000000000..027f2750c94
--- /dev/null
+++ b/libs/community/langchain_community/tools/bing_search/tool.py
@@ -0,0 +1,49 @@
+"""Tool for the Bing search API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.bing_search import BingSearchAPIWrapper
+
+
+class BingSearchRun(BaseTool):
+ """Tool that queries the Bing search API."""
+
+ name: str = "bing_search"
+ description: str = (
+ "A wrapper around Bing Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query."
+ )
+ api_wrapper: BingSearchAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
+
+
+class BingSearchResults(BaseTool):
+ """Tool that queries the Bing Search API and gets back json."""
+
+ name: str = "Bing Search Results JSON"
+ description: str = (
+ "A wrapper around Bing Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query. Output is a JSON array of the query results"
+ )
+ num_results: int = 4
+ api_wrapper: BingSearchAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.results(query, self.num_results))
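+
+
+# Example (illustrative sketch): BingSearchAPIWrapper is assumed to read its
+# subscription key and search URL from the BING_SUBSCRIPTION_KEY and
+# BING_SEARCH_URL environment variables.
+if __name__ == "__main__":
+    bing_tool = BingSearchRun(api_wrapper=BingSearchAPIWrapper())
+    print(bing_tool.run("latest python release"))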
diff --git a/libs/langchain/tests/integration_tests/retrievers/docarray/__init__.py b/libs/community/langchain_community/tools/brave_search/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/retrievers/docarray/__init__.py
rename to libs/community/langchain_community/tools/brave_search/__init__.py
diff --git a/libs/community/langchain_community/tools/brave_search/tool.py b/libs/community/langchain_community/tools/brave_search/tool.py
new file mode 100644
index 00000000000..4ca8d68501e
--- /dev/null
+++ b/libs/community/langchain_community/tools/brave_search/tool.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.brave_search import BraveSearchWrapper
+
+
+class BraveSearch(BaseTool):
+    """Tool that queries the Brave Search API."""
+
+ name: str = "brave_search"
+ description: str = (
+ "a search engine. "
+ "useful for when you need to answer questions about current events."
+ " input should be a search query."
+ )
+ search_wrapper: BraveSearchWrapper
+
+ @classmethod
+ def from_api_key(
+ cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
+ ) -> BraveSearch:
+ """Create a tool from an api key.
+
+ Args:
+ api_key: The api key to use.
+ search_kwargs: Any additional kwargs to pass to the search wrapper.
+ **kwargs: Any additional kwargs to pass to the tool.
+
+ Returns:
+ A tool.
+ """
+ wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
+ return cls(search_wrapper=wrapper, **kwargs)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.search_wrapper.run(query)
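+
+
+# Example (illustrative sketch): the API key below is a placeholder, and the
+# `count` search kwarg is assumed to be forwarded to the Brave Search API.
+if __name__ == "__main__":
+    brave_tool = BraveSearch.from_api_key(
+        api_key="YOUR_BRAVE_API_KEY", search_kwargs={"count": 3}
+    )
+    print(brave_tool.run("current weather in Berlin"))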
diff --git a/libs/langchain/tests/integration_tests/storage/__init__.py b/libs/community/langchain_community/tools/clickup/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/storage/__init__.py
rename to libs/community/langchain_community/tools/clickup/__init__.py
diff --git a/libs/community/langchain_community/tools/clickup/prompt.py b/libs/community/langchain_community/tools/clickup/prompt.py
new file mode 100644
index 00000000000..5f1f51dbf77
--- /dev/null
+++ b/libs/community/langchain_community/tools/clickup/prompt.py
@@ -0,0 +1,131 @@
+# flake8: noqa
+CLICKUP_TASK_CREATE_PROMPT = """
+ This tool is a wrapper around clickup's create_task API, useful when you need to create a CLICKUP task.
+ The input to this tool is a dictionary specifying the fields of the CLICKUP task, and will be passed into clickup's CLICKUP `create_task` function.
+ Only add fields described by the user.
+ Use the following mapping in order to map the user's priority to the clickup priority: {{
+ Urgent = 1,
+ High = 2,
+ Normal = 3,
+ Low = 4,
+    }}. If the user passes in "urgent", set the priority value to 1.
+
+ Here are a few task descriptions and corresponding input examples:
+ Task: create a task called "Daily report"
+ Example Input: {{"name": "Daily report"}}
+ Task: Make an open task called "ClickUp toolkit refactor" with description "Refactor the clickup toolkit to use dataclasses for parsing", with status "open"
+ Example Input: {{"name": "ClickUp toolkit refactor", "description": "Refactor the clickup toolkit to use dataclasses for parsing", "status": "Open"}}
+ Task: create a task with priority 3 called "New Task Name" with description "New Task Description", with status "open"
+ Example Input: {{"name": "New Task Name", "description": "New Task Description", "status": "Open", "priority": 3}}
+ Task: Add a task called "Bob's task" and assign it to Bob (user id: 81928627)
+ Example Input: {{"name": "Bob's task", "description": "Task for Bob", "assignees": [81928627]}}
+ """
+
+CLICKUP_LIST_CREATE_PROMPT = """
+ This tool is a wrapper around clickup's create_list API, useful when you need to create a CLICKUP list.
+ The input to this tool is a dictionary specifying the fields of a clickup list, and will be passed to clickup's create_list function.
+ Only add fields described by the user.
+ Use the following mapping in order to map the user's priority to the clickup priority: {{
+ Urgent = 1,
+ High = 2,
+ Normal = 3,
+ Low = 4,
+    }}. If the user passes in "urgent", set the priority value to 1.
+
+ Here are a few list descriptions and corresponding input examples:
+ Description: make a list with name "General List"
+ Example Input: {{"name": "General List"}}
+ Description: add a new list ("TODOs") with low priority
+    Example Input: {{"name": "TODOs", "priority": 4}}
+ Description: create a list with name "List name", content "List content", priority 2, and status "red"
+ Example Input: {{"name": "List name", "content": "List content", "priority": 2, "status": "red"}}
+"""
+
+CLICKUP_FOLDER_CREATE_PROMPT = """
+ This tool is a wrapper around clickup's create_folder API, useful when you need to create a CLICKUP folder.
+ The input to this tool is a dictionary specifying the fields of a clickup folder, and will be passed to clickup's create_folder function.
+ For example, to create a folder with name "Folder name" you would pass in the following dictionary:
+ {{
+ "name": "Folder name",
+ }}
+"""
+
+CLICKUP_GET_TASK_PROMPT = """
+    This tool is a wrapper around clickup's API,
+    useful when you need to get a specific task for the user.
+    Do NOT use this to get a task-specific attribute; use the get task attribute tool instead.
+    Given the task id you want to create a request similar to the following dictionary:
+    payload = {{"task_id": "86a0t44tq"}}
+    """
+
+CLICKUP_GET_TASK_ATTRIBUTE_PROMPT = """
+ This tool is a wrapper around clickup's API,
+ useful when you need to get a specific attribute from a task. Given the task id and desired attribute create a request similar to the following dictionary:
+ payload = {{"task_id": "", "attribute_name": ""}}
+
+    Here are some example queries and their corresponding payloads:
+ Get the name of task 23jn23kjn -> {{"task_id": "23jn23kjn", "attribute_name": "name"}}
+ What is the priority of task 86a0t44tq? -> {{"task_id": "86a0t44tq", "attribute_name": "priority"}}
+ Output the description of task sdc9ds9jc -> {{"task_id": "sdc9ds9jc", "attribute_name": "description"}}
+ Who is assigned to task bgjfnbfg0 -> {{"task_id": "bgjfnbfg0", "attribute_name": "assignee"}}
+    What is the status of task kjnsdcjc? -> {{"task_id": "kjnsdcjc", "attribute_name": "status"}}
+ How long is the time estimate of task sjncsd999? -> {{"task_id": "sjncsd999", "attribute_name": "time_estimate"}}
+    Is task jnsd98sd archived? -> {{"task_id": "jnsd98sd", "attribute_name": "archive"}}
+ """
+
+CLICKUP_GET_ALL_TEAMS_PROMPT = """
+ This tool is a wrapper around clickup's API, useful when you need to get all teams that the user is a part of.
+    To get a list of all the teams, no request parameters are necessary.
+ """
+
+CLICKUP_GET_LIST_PROMPT = """
+ This tool is a wrapper around clickup's API,
+ useful when you need to get a specific list for the user. Given the list id you want to create a request similar to the following dictionary:
+ payload = {{"list_id": "901300608424"}}
+ """
+
+CLICKUP_GET_FOLDERS_PROMPT = """
+ This tool is a wrapper around clickup's API,
+ useful when you need to get a specific folder for the user. Given the user's workspace id you want to create a request similar to the following dictionary:
+ payload = {{"folder_id": "90130119692"}}
+ """
+
+CLICKUP_GET_SPACES_PROMPT = """
+ This tool is a wrapper around clickup's API,
+ useful when you need to get all the spaces available to a user. Given the user's workspace id you want to create a request similar to the following dictionary:
+ payload = {{"team_id": "90130119692"}}
+ """
+
+CLICKUP_UPDATE_TASK_PROMPT = """
+ This tool is a wrapper around clickup's API,
+ useful when you need to update a specific attribute of a task. Given the task id, desired attribute to change and the new value you want to create a request similar to the following dictionary:
+ payload = {{"task_id": "", "attribute_name": "", "value": ""}}
+
+    Here are some example queries and their corresponding payloads:
+ Change the name of task 23jn23kjn to new task name -> {{"task_id": "23jn23kjn", "attribute_name": "name", "value": "new task name"}}
+ Update the priority of task 86a0t44tq to 1 -> {{"task_id": "86a0t44tq", "attribute_name": "priority", "value": 1}}
+ Re-write the description of task sdc9ds9jc to 'a new task description' -> {{"task_id": "sdc9ds9jc", "attribute_name": "description", "value": "a new task description"}}
+    Set the status of task kjnsdcjc to done -> {{"task_id": "kjnsdcjc", "attribute_name": "status", "value": "done"}}
+ Increase the time estimate of task sjncsd999 to 3h -> {{"task_id": "sjncsd999", "attribute_name": "time_estimate", "value": 8000}}
+ Archive task jnsd98sd -> {{"task_id": "jnsd98sd", "attribute_name": "archive", "value": true}}
+ *IMPORTANT*: Pay attention to the exact syntax above and the correct use of quotes.
+ For changing priority and time estimates, we expect integers (int).
+ For name, description and status we expect strings (str).
+ For archive, we expect a boolean (bool).
+ """
+
+CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT = """
+ This tool is a wrapper around clickup's API,
+    useful when you need to update the assignees of a task. Given the task id, the operation (add or remove (rem)), and the list of user ids, you want to create a request similar to the following dictionary:
+ payload = {{"task_id": "", "operation": "", "users": [, ]}}
+
+    Here are some example queries and their corresponding payloads:
+ Add 81928627 and 3987234 as assignees to task 21hw21jn -> {{"task_id": "21hw21jn", "operation": "add", "users": [81928627, 3987234]}}
+ Remove 67823487 as assignee from task jin34ji4 -> {{"task_id": "jin34ji4", "operation": "rem", "users": [67823487]}}
+ *IMPORTANT*: Users id should always be ints.
+ """
diff --git a/libs/community/langchain_community/tools/clickup/tool.py b/libs/community/langchain_community/tools/clickup/tool.py
new file mode 100644
index 00000000000..93988dd7d59
--- /dev/null
+++ b/libs/community/langchain_community/tools/clickup/tool.py
@@ -0,0 +1,42 @@
+"""
+This tool allows agents to interact with the clickup library
+and operate on a Clickup instance.
+To use this tool, you must first set the following environment variables:
+ client_secret
+ client_id
+ code
+
+Below is a sample script that uses the Clickup tool:
+
+```python
+from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit
+from langchain_community.utilities.clickup import ClickupAPIWrapper
+
+clickup = ClickupAPIWrapper()
+toolkit = ClickupToolkit.from_clickup_api_wrapper(clickup)
+```
+"""
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.clickup import ClickupAPIWrapper
+
+
+class ClickupAction(BaseTool):
+ """Tool that queries the Clickup API."""
+
+ api_wrapper: ClickupAPIWrapper = Field(default_factory=ClickupAPIWrapper)
+ mode: str
+ name: str = ""
+ description: str = ""
+
+ def _run(
+ self,
+ instructions: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Clickup API to run an operation."""
+ return self.api_wrapper.run(self.mode, instructions)
diff --git a/libs/community/langchain_community/tools/convert_to_openai.py b/libs/community/langchain_community/tools/convert_to_openai.py
new file mode 100644
index 00000000000..ff1193fd410
--- /dev/null
+++ b/libs/community/langchain_community/tools/convert_to_openai.py
@@ -0,0 +1,4 @@
+from langchain_community.tools.render import format_tool_to_openai_function
+
+# For backwards compatibility
+__all__ = ["format_tool_to_openai_function"]
diff --git a/libs/community/langchain_community/tools/dataforseo_api_search/__init__.py b/libs/community/langchain_community/tools/dataforseo_api_search/__init__.py
new file mode 100644
index 00000000000..1e2cd9efe9e
--- /dev/null
+++ b/libs/community/langchain_community/tools/dataforseo_api_search/__init__.py
@@ -0,0 +1,9 @@
+"""DataForSeo API Toolkit."""
+
+from langchain_community.tools.dataforseo_api_search.tool import (
+    DataForSeoAPISearchResults,
+    DataForSeoAPISearchRun,
+)
+
+__all__ = ["DataForSeoAPISearchRun", "DataForSeoAPISearchResults"]
diff --git a/libs/community/langchain_community/tools/dataforseo_api_search/tool.py b/libs/community/langchain_community/tools/dataforseo_api_search/tool.py
new file mode 100644
index 00000000000..bb10187f8d5
--- /dev/null
+++ b/libs/community/langchain_community/tools/dataforseo_api_search/tool.py
@@ -0,0 +1,71 @@
+"""Tool for the DataForSeo SERP API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.dataforseo_api_search import DataForSeoAPIWrapper
+
+
+class DataForSeoAPISearchRun(BaseTool):
+ """Tool that queries the DataForSeo Google search API."""
+
+ name: str = "dataforseo_api_search"
+ description: str = (
+        "A robust Google Search API provided by DataForSeo. "
+ "This tool is handy when you need information about trending topics "
+ "or current events."
+ )
+ api_wrapper: DataForSeoAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.run(query))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+        return str(await self.api_wrapper.arun(query))
+
+
+class DataForSeoAPISearchResults(BaseTool):
+ """Tool that queries the DataForSeo Google Search API
+    and gets back json."""
+
+ name: str = "DataForSeo-Results-JSON"
+ description: str = (
+        "A comprehensive Google Search API provided by DataForSeo. "
+        "This tool is useful for obtaining real-time data on current events "
+        "or popular searches. "
+ "The input should be a search query and the output is a JSON object "
+ "of the query results."
+ )
+ api_wrapper: DataForSeoAPIWrapper = Field(default_factory=DataForSeoAPIWrapper)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.results(query))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+        return str(await self.api_wrapper.aresults(query))
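+
+
+# Example (illustrative sketch): DataForSeoAPIWrapper is assumed to pick up its
+# DataForSeo credentials from the environment; see the utilities module for the
+# exact variable names.
+if __name__ == "__main__":
+    seo_tool = DataForSeoAPISearchResults()
+    print(seo_tool.run("trending ai news"))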
diff --git a/libs/community/langchain_community/tools/ddg_search/__init__.py b/libs/community/langchain_community/tools/ddg_search/__init__.py
new file mode 100644
index 00000000000..5b7de286b89
--- /dev/null
+++ b/libs/community/langchain_community/tools/ddg_search/__init__.py
@@ -0,0 +1,5 @@
+"""DuckDuckGo Search API toolkit."""
+
+from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
+
+__all__ = ["DuckDuckGoSearchRun"]
diff --git a/libs/community/langchain_community/tools/ddg_search/tool.py b/libs/community/langchain_community/tools/ddg_search/tool.py
new file mode 100644
index 00000000000..c748422a03f
--- /dev/null
+++ b/libs/community/langchain_community/tools/ddg_search/tool.py
@@ -0,0 +1,83 @@
+"""Tool for the DuckDuckGo search API."""
+
+import warnings
+from typing import Any, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
+
+
+class DDGInput(BaseModel):
+ query: str = Field(description="search query to look up")
+
+
+class DuckDuckGoSearchRun(BaseTool):
+ """Tool that queries the DuckDuckGo search API."""
+
+ name: str = "duckduckgo_search"
+ description: str = (
+ "A wrapper around DuckDuckGo Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query."
+ )
+ api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+ default_factory=DuckDuckGoSearchAPIWrapper
+ )
+ args_schema: Type[BaseModel] = DDGInput
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
+
+
+class DuckDuckGoSearchResults(BaseTool):
+ """Tool that queries the DuckDuckGo search API and gets back json."""
+
+ name: str = "DuckDuckGo Results JSON"
+ description: str = (
+ "A wrapper around Duck Duck Go Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query. Output is a JSON array of the query results"
+ )
+ max_results: int = Field(alias="num_results", default=4)
+ api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+ default_factory=DuckDuckGoSearchAPIWrapper
+ )
+ backend: str = "text"
+ args_schema: Type[BaseModel] = DDGInput
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ res = self.api_wrapper.results(query, self.max_results, source=self.backend)
+ res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
+ return ", ".join([f"[{rs}]" for rs in res_strs])
+
+
+def DuckDuckGoSearchTool(*args: Any, **kwargs: Any) -> DuckDuckGoSearchRun:
+ """
+ Deprecated. Use DuckDuckGoSearchRun instead.
+
+ Args:
+ *args:
+ **kwargs:
+
+ Returns:
+ DuckDuckGoSearchRun
+ """
+ warnings.warn(
+ "DuckDuckGoSearchTool will be deprecated in the future. "
+ "Please use DuckDuckGoSearchRun instead.",
+ DeprecationWarning,
+ )
+ return DuckDuckGoSearchRun(*args, **kwargs)
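+
+
+# Example (minimal sketch): DuckDuckGoSearchAPIWrapper needs no API key, so the
+# tool works out of the box once the `duckduckgo-search` package is installed.
+if __name__ == "__main__":
+    ddg_tool = DuckDuckGoSearchRun()
+    print(ddg_tool.run("LangChain"))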
diff --git a/libs/langchain/tests/integration_tests/tools/edenai/__init__.py b/libs/community/langchain_community/tools/e2b_data_analysis/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/tools/edenai/__init__.py
rename to libs/community/langchain_community/tools/e2b_data_analysis/__init__.py
diff --git a/libs/community/langchain_community/tools/e2b_data_analysis/tool.py b/libs/community/langchain_community/tools/e2b_data_analysis/tool.py
new file mode 100644
index 00000000000..0794b87469e
--- /dev/null
+++ b/libs/community/langchain_community/tools/e2b_data_analysis/tool.py
@@ -0,0 +1,243 @@
+from __future__ import annotations
+
+import ast
+import json
+import os
+from io import StringIO
+from sys import version_info
+from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManager,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
+
+from langchain_community.tools import BaseTool, Tool
+from langchain_community.tools.e2b_data_analysis.unparse import Unparser
+
+if TYPE_CHECKING:
+ from e2b import EnvVars
+ from e2b.templates.data_analysis import Artifact
+
+base_description = """Evaluates python code in a sandbox environment. \
+The environment is long running and exists across multiple executions. \
+You must send the whole script every time and print your outputs. \
+Script should be pure python code that can be evaluated. \
+It should be in python format NOT markdown. \
+The code should NOT be wrapped in backticks. \
+All python packages including requests, matplotlib, scipy, numpy, pandas, \
+etc are available. Create and display charts using `plt.show()`."""
+
+
+def _unparse(tree: ast.AST) -> str:
+ """Unparse the AST."""
+ if version_info.minor < 9:
+ s = StringIO()
+ Unparser(tree, file=s)
+ source_code = s.getvalue()
+ s.close()
+ else:
+ source_code = ast.unparse(tree) # type: ignore[attr-defined]
+ return source_code
+
+
+def add_last_line_print(code: str) -> str:
+ """Add print statement to the last line if it's missing.
+
+ Sometimes, the LLM-generated code doesn't have `print(variable_name)`, instead the
+ LLM tries to print the variable only by writing `variable_name` (as you would in
+ REPL, for example).
+
+    This method checks the AST of the generated Python code and adds the print
+ statement to the last line if it's missing.
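+
+    For example, a script whose last line is the bare expression `x` is
+    rewritten so that the last line becomes `print(x)`.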
+ """
+ tree = ast.parse(code)
+ node = tree.body[-1]
+ if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
+ if isinstance(node.value.func, ast.Name) and node.value.func.id == "print":
+ return _unparse(tree)
+
+ if isinstance(node, ast.Expr):
+ tree.body[-1] = ast.Expr(
+ value=ast.Call(
+ func=ast.Name(id="print", ctx=ast.Load()),
+ args=[node.value],
+ keywords=[],
+ )
+ )
+
+ return _unparse(tree)
+
+
+class UploadedFile(BaseModel):
+ """Description of the uploaded path with its remote path."""
+
+ name: str
+ remote_path: str
+ description: str
+
+
+class E2BDataAnalysisToolArguments(BaseModel):
+ """Arguments for the E2BDataAnalysisTool."""
+
+ python_code: str = Field(
+ ...,
+ example="print('Hello World')",
+ description=(
+ "The python script to be evaluated. "
+ "The contents will be in main.py. "
+ "It should not be in markdown format."
+ ),
+ )
+
+
+class E2BDataAnalysisTool(BaseTool):
+ """Tool for running python code in a sandboxed environment for data analysis."""
+
+ name = "e2b_data_analysis"
+ args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
+ session: Any
+ description: str
+ _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)
+
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ cwd: Optional[str] = None,
+ env_vars: Optional[EnvVars] = None,
+ on_stdout: Optional[Callable[[str], Any]] = None,
+ on_stderr: Optional[Callable[[str], Any]] = None,
+ on_artifact: Optional[Callable[[Artifact], Any]] = None,
+ on_exit: Optional[Callable[[int], Any]] = None,
+ **kwargs: Any,
+ ):
+ try:
+ from e2b import DataAnalysis
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import e2b, please install with `pip install e2b`."
+ ) from e
+
+ # If no API key is provided, E2B will try to read it from the environment
+ # variable E2B_API_KEY
+ super().__init__(description=base_description, **kwargs)
+ self.session = DataAnalysis(
+ api_key=api_key,
+ cwd=cwd,
+ env_vars=env_vars,
+ on_stdout=on_stdout,
+ on_stderr=on_stderr,
+ on_exit=on_exit,
+ on_artifact=on_artifact,
+ )
+
+ def close(self) -> None:
+ """Close the cloud sandbox."""
+ self._uploaded_files = []
+ self.session.close()
+
+ @property
+ def uploaded_files_description(self) -> str:
+ if len(self._uploaded_files) == 0:
+ return ""
+        lines = ["The following files are available in the sandbox:"]
+
+ for f in self._uploaded_files:
+ if f.description == "":
+ lines.append(f"- path: `{f.remote_path}`")
+ else:
+ lines.append(
+ f"- path: `{f.remote_path}` \n description: `{f.description}`"
+ )
+ return "\n".join(lines)
+
+ def _run(
+ self,
+ python_code: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ callbacks: Optional[CallbackManager] = None,
+ ) -> str:
+ python_code = add_last_line_print(python_code)
+
+ if callbacks is not None:
+ on_artifact = getattr(callbacks.metadata, "on_artifact", None)
+ else:
+ on_artifact = None
+
+ stdout, stderr, artifacts = self.session.run_python(
+ python_code, on_artifact=on_artifact
+ )
+
+ out = {
+ "stdout": stdout,
+ "stderr": stderr,
+ "artifacts": list(map(lambda artifact: artifact.name, artifacts)),
+ }
+ return json.dumps(out)
+
+ async def _arun(
+ self,
+ python_code: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ raise NotImplementedError("e2b_data_analysis does not support async")
+
+ def run_command(
+ self,
+ cmd: str,
+ ) -> dict:
+ """Run shell command in the sandbox."""
+ proc = self.session.process.start(cmd)
+ output = proc.wait()
+ return {
+ "stdout": output.stdout,
+ "stderr": output.stderr,
+ "exit_code": output.exit_code,
+ }
+
+ def install_python_packages(self, package_names: str | List[str]) -> None:
+ """Install python packages in the sandbox."""
+ self.session.install_python_packages(package_names)
+
+ def install_system_packages(self, package_names: str | List[str]) -> None:
+ """Install system packages (via apt) in the sandbox."""
+ self.session.install_system_packages(package_names)
+
+ def download_file(self, remote_path: str) -> bytes:
+ """Download file from the sandbox."""
+ return self.session.download_file(remote_path)
+
+ def upload_file(self, file: IO, description: str) -> UploadedFile:
+ """Upload file to the sandbox.
+
+ The file is uploaded to the '/home/user/' path."""
+ remote_path = self.session.upload_file(file)
+
+ f = UploadedFile(
+ name=os.path.basename(file.name),
+ remote_path=remote_path,
+ description=description,
+ )
+ self._uploaded_files.append(f)
+ self.description = self.description + "\n" + self.uploaded_files_description
+ return f
+
+ def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
+ """Remove uploaded file from the sandbox."""
+ self.session.filesystem.remove(uploaded_file.remote_path)
+ self._uploaded_files = [
+ f
+ for f in self._uploaded_files
+ if f.remote_path != uploaded_file.remote_path
+ ]
+ self.description = self.description + "\n" + self.uploaded_files_description
+
+ def as_tool(self) -> Tool:
+ return Tool.from_function(
+ func=self._run,
+ name=self.name,
+ description=self.description,
+ args_schema=self.args_schema,
+ )
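+
+
+# Example (illustrative sketch): assumes the `e2b` package is installed and
+# E2B_API_KEY is set in the environment. The sandbox is long running, so it is
+# closed explicitly at the end.
+if __name__ == "__main__":
+    e2b_tool = E2BDataAnalysisTool()
+    print(e2b_tool.run("import math\nprint(math.sqrt(2))"))
+    e2b_tool.close()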
diff --git a/libs/community/langchain_community/tools/e2b_data_analysis/unparse.py b/libs/community/langchain_community/tools/e2b_data_analysis/unparse.py
new file mode 100644
index 00000000000..68682023957
--- /dev/null
+++ b/libs/community/langchain_community/tools/e2b_data_analysis/unparse.py
@@ -0,0 +1,736 @@
+# mypy: disable-error-code=no-untyped-def
+# Because Python versions before 3.9 don't provide ast.unparse,
+# we copied the unparse functionality from here:
+# https://github.com/python/cpython/blob/3.8/Tools/parser/unparse.py
+"Usage: unparse.py "
+import ast
+import io
+import sys
+import tokenize
+
+# Large float and imaginary literals get turned into infinities in the AST.
+# We unparse those infinities to INFSTR.
+INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
+
+
+def interleave(inter, f, seq):
+ """Call f on each item in seq, calling inter() in between."""
+ seq = iter(seq)
+ try:
+ f(next(seq))
+ except StopIteration:
+ pass
+ else:
+ for x in seq:
+ inter()
+ f(x)
+
+
+class Unparser:
+ """Methods in this class recursively traverse an AST and
+ output source code for the abstract syntax; original formatting
+ is disregarded."""
+
+ def __init__(self, tree, file=sys.stdout):
+ """Unparser(tree, file=sys.stdout) -> None.
+ Print the source for tree to file."""
+ self.f = file
+ self._indent = 0
+ self.dispatch(tree)
+ self.f.flush()
+
+ def fill(self, text=""):
+ "Indent a piece of text, according to the current indentation level"
+ self.f.write("\n" + " " * self._indent + text)
+
+ def write(self, text):
+ "Append a piece of text to the current line."
+ self.f.write(text)
+
+ def enter(self):
+ "Print ':', and increase the indentation."
+ self.write(":")
+ self._indent += 1
+
+ def leave(self):
+ "Decrease the indentation level."
+ self._indent -= 1
+
+ def dispatch(self, tree):
+ "Dispatcher function, dispatching tree type T to method _T."
+ if isinstance(tree, list):
+ for t in tree:
+ self.dispatch(t)
+ return
+ meth = getattr(self, "_" + tree.__class__.__name__)
+ meth(tree)
+
+ ############### Unparsing methods ######################
+ # There should be one method per concrete grammar type #
+ # Constructors should be grouped by sum type. Ideally, #
+ # this would follow the order in the grammar, but #
+ # currently doesn't. #
+ ########################################################
+
+ def _Module(self, tree):
+ for stmt in tree.body:
+ self.dispatch(stmt)
+
+ # stmt
+ def _Expr(self, tree):
+ self.fill()
+ self.dispatch(tree.value)
+
+ def _NamedExpr(self, tree):
+ self.write("(")
+ self.dispatch(tree.target)
+ self.write(" := ")
+ self.dispatch(tree.value)
+ self.write(")")
+
+ def _Import(self, t):
+ self.fill("import ")
+ interleave(lambda: self.write(", "), self.dispatch, t.names)
+
+ def _ImportFrom(self, t):
+ self.fill("from ")
+ self.write("." * t.level)
+ if t.module:
+ self.write(t.module)
+ self.write(" import ")
+ interleave(lambda: self.write(", "), self.dispatch, t.names)
+
+ def _Assign(self, t):
+ self.fill()
+ for target in t.targets:
+ self.dispatch(target)
+ self.write(" = ")
+ self.dispatch(t.value)
+
+ def _AugAssign(self, t):
+ self.fill()
+ self.dispatch(t.target)
+ self.write(" " + self.binop[t.op.__class__.__name__] + "= ")
+ self.dispatch(t.value)
+
+ def _AnnAssign(self, t):
+ self.fill()
+ if not t.simple and isinstance(t.target, ast.Name):
+ self.write("(")
+ self.dispatch(t.target)
+ if not t.simple and isinstance(t.target, ast.Name):
+ self.write(")")
+ self.write(": ")
+ self.dispatch(t.annotation)
+ if t.value:
+ self.write(" = ")
+ self.dispatch(t.value)
+
+ def _Return(self, t):
+ self.fill("return")
+ if t.value:
+ self.write(" ")
+ self.dispatch(t.value)
+
+ def _Pass(self, t):
+ self.fill("pass")
+
+ def _Break(self, t):
+ self.fill("break")
+
+ def _Continue(self, t):
+ self.fill("continue")
+
+ def _Delete(self, t):
+ self.fill("del ")
+ interleave(lambda: self.write(", "), self.dispatch, t.targets)
+
+ def _Assert(self, t):
+ self.fill("assert ")
+ self.dispatch(t.test)
+ if t.msg:
+ self.write(", ")
+ self.dispatch(t.msg)
+
+ def _Global(self, t):
+ self.fill("global ")
+ interleave(lambda: self.write(", "), self.write, t.names)
+
+ def _Nonlocal(self, t):
+ self.fill("nonlocal ")
+ interleave(lambda: self.write(", "), self.write, t.names)
+
+ def _Await(self, t):
+ self.write("(")
+ self.write("await")
+ if t.value:
+ self.write(" ")
+ self.dispatch(t.value)
+ self.write(")")
+
+ def _Yield(self, t):
+ self.write("(")
+ self.write("yield")
+ if t.value:
+ self.write(" ")
+ self.dispatch(t.value)
+ self.write(")")
+
+ def _YieldFrom(self, t):
+ self.write("(")
+ self.write("yield from")
+ if t.value:
+ self.write(" ")
+ self.dispatch(t.value)
+ self.write(")")
+
+ def _Raise(self, t):
+ self.fill("raise")
+ if not t.exc:
+ assert not t.cause
+ return
+ self.write(" ")
+ self.dispatch(t.exc)
+ if t.cause:
+ self.write(" from ")
+ self.dispatch(t.cause)
+
+ def _Try(self, t):
+ self.fill("try")
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+ for ex in t.handlers:
+ self.dispatch(ex)
+ if t.orelse:
+ self.fill("else")
+ self.enter()
+ self.dispatch(t.orelse)
+ self.leave()
+ if t.finalbody:
+ self.fill("finally")
+ self.enter()
+ self.dispatch(t.finalbody)
+ self.leave()
+
+ def _ExceptHandler(self, t):
+ self.fill("except")
+ if t.type:
+ self.write(" ")
+ self.dispatch(t.type)
+ if t.name:
+ self.write(" as ")
+ self.write(t.name)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+
+ def _ClassDef(self, t):
+ self.write("\n")
+ for deco in t.decorator_list:
+ self.fill("@")
+ self.dispatch(deco)
+ self.fill("class " + t.name)
+ self.write("(")
+ comma = False
+ for e in t.bases:
+ if comma:
+ self.write(", ")
+ else:
+ comma = True
+ self.dispatch(e)
+ for e in t.keywords:
+ if comma:
+ self.write(", ")
+ else:
+ comma = True
+ self.dispatch(e)
+ self.write(")")
+
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+
+ def _FunctionDef(self, t):
+ self.__FunctionDef_helper(t, "def")
+
+ def _AsyncFunctionDef(self, t):
+ self.__FunctionDef_helper(t, "async def")
+
+ def __FunctionDef_helper(self, t, fill_suffix):
+ self.write("\n")
+ for deco in t.decorator_list:
+ self.fill("@")
+ self.dispatch(deco)
+ def_str = fill_suffix + " " + t.name + "("
+ self.fill(def_str)
+ self.dispatch(t.args)
+ self.write(")")
+ if t.returns:
+ self.write(" -> ")
+ self.dispatch(t.returns)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+
+ def _For(self, t):
+ self.__For_helper("for ", t)
+
+ def _AsyncFor(self, t):
+ self.__For_helper("async for ", t)
+
+ def __For_helper(self, fill, t):
+ self.fill(fill)
+ self.dispatch(t.target)
+ self.write(" in ")
+ self.dispatch(t.iter)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+ if t.orelse:
+ self.fill("else")
+ self.enter()
+ self.dispatch(t.orelse)
+ self.leave()
+
+ def _If(self, t):
+ self.fill("if ")
+ self.dispatch(t.test)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+ # collapse nested ifs into equivalent elifs.
+ while t.orelse and len(t.orelse) == 1 and isinstance(t.orelse[0], ast.If):
+ t = t.orelse[0]
+ self.fill("elif ")
+ self.dispatch(t.test)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+ # final else
+ if t.orelse:
+ self.fill("else")
+ self.enter()
+ self.dispatch(t.orelse)
+ self.leave()
+
+ def _While(self, t):
+ self.fill("while ")
+ self.dispatch(t.test)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+ if t.orelse:
+ self.fill("else")
+ self.enter()
+ self.dispatch(t.orelse)
+ self.leave()
+
+ def _With(self, t):
+ self.fill("with ")
+ interleave(lambda: self.write(", "), self.dispatch, t.items)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+
+ def _AsyncWith(self, t):
+ self.fill("async with ")
+ interleave(lambda: self.write(", "), self.dispatch, t.items)
+ self.enter()
+ self.dispatch(t.body)
+ self.leave()
+
+ # expr
+ def _JoinedStr(self, t):
+ self.write("f")
+ string = io.StringIO()
+ self._fstring_JoinedStr(t, string.write)
+ self.write(repr(string.getvalue()))
+
+ def _FormattedValue(self, t):
+ self.write("f")
+ string = io.StringIO()
+ self._fstring_FormattedValue(t, string.write)
+ self.write(repr(string.getvalue()))
+
+ def _fstring_JoinedStr(self, t, write):
+ for value in t.values:
+ meth = getattr(self, "_fstring_" + type(value).__name__)
+ meth(value, write)
+
+ def _fstring_Constant(self, t, write):
+ assert isinstance(t.value, str)
+ value = t.value.replace("{", "{{").replace("}", "}}")
+ write(value)
+
+ def _fstring_FormattedValue(self, t, write):
+ write("{")
+ expr = io.StringIO()
+ Unparser(t.value, expr)
+ expr = expr.getvalue().rstrip("\n")
+ if expr.startswith("{"):
+ write(" ") # Separate pair of opening brackets as "{ {"
+ write(expr)
+ if t.conversion != -1:
+ conversion = chr(t.conversion)
+ assert conversion in "sra"
+ write(f"!{conversion}")
+ if t.format_spec:
+ write(":")
+ meth = getattr(self, "_fstring_" + type(t.format_spec).__name__)
+ meth(t.format_spec, write)
+ write("}")
+
+ def _Name(self, t):
+ self.write(t.id)
+
+ def _write_constant(self, value):
+ if isinstance(value, (float, complex)):
+ # Substitute overflowing decimal literal for AST infinities.
+ self.write(repr(value).replace("inf", INFSTR))
+ else:
+ self.write(repr(value))
+
+ def _Constant(self, t):
+ value = t.value
+ if isinstance(value, tuple):
+ self.write("(")
+ if len(value) == 1:
+ self._write_constant(value[0])
+ self.write(",")
+ else:
+ interleave(lambda: self.write(", "), self._write_constant, value)
+ self.write(")")
+ elif value is ...:
+ self.write("...")
+ else:
+ if t.kind == "u":
+ self.write("u")
+ self._write_constant(t.value)
+
+ def _List(self, t):
+ self.write("[")
+ interleave(lambda: self.write(", "), self.dispatch, t.elts)
+ self.write("]")
+
+ def _ListComp(self, t):
+ self.write("[")
+ self.dispatch(t.elt)
+ for gen in t.generators:
+ self.dispatch(gen)
+ self.write("]")
+
+ def _GeneratorExp(self, t):
+ self.write("(")
+ self.dispatch(t.elt)
+ for gen in t.generators:
+ self.dispatch(gen)
+ self.write(")")
+
+ def _SetComp(self, t):
+ self.write("{")
+ self.dispatch(t.elt)
+ for gen in t.generators:
+ self.dispatch(gen)
+ self.write("}")
+
+ def _DictComp(self, t):
+ self.write("{")
+ self.dispatch(t.key)
+ self.write(": ")
+ self.dispatch(t.value)
+ for gen in t.generators:
+ self.dispatch(gen)
+ self.write("}")
+
+ def _comprehension(self, t):
+ if t.is_async:
+ self.write(" async for ")
+ else:
+ self.write(" for ")
+ self.dispatch(t.target)
+ self.write(" in ")
+ self.dispatch(t.iter)
+ for if_clause in t.ifs:
+ self.write(" if ")
+ self.dispatch(if_clause)
+
+ def _IfExp(self, t):
+ self.write("(")
+ self.dispatch(t.body)
+ self.write(" if ")
+ self.dispatch(t.test)
+ self.write(" else ")
+ self.dispatch(t.orelse)
+ self.write(")")
+
+ def _Set(self, t):
+ assert t.elts # should be at least one element
+ self.write("{")
+ interleave(lambda: self.write(", "), self.dispatch, t.elts)
+ self.write("}")
+
+ def _Dict(self, t):
+ self.write("{")
+
+ def write_key_value_pair(k, v):
+ self.dispatch(k)
+ self.write(": ")
+ self.dispatch(v)
+
+ def write_item(item):
+ k, v = item
+ if k is None:
+ # for dictionary unpacking operator in dicts {**{'y': 2}}
+ # see PEP 448 for details
+ self.write("**")
+ self.dispatch(v)
+ else:
+ write_key_value_pair(k, v)
+
+ interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
+ self.write("}")
+
+ def _Tuple(self, t):
+ self.write("(")
+ if len(t.elts) == 1:
+ elt = t.elts[0]
+ self.dispatch(elt)
+ self.write(",")
+ else:
+ interleave(lambda: self.write(", "), self.dispatch, t.elts)
+ self.write(")")
+
+ unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
+
+ def _UnaryOp(self, t):
+ self.write("(")
+ self.write(self.unop[t.op.__class__.__name__])
+ self.write(" ")
+ self.dispatch(t.operand)
+ self.write(")")
+
+ binop = {
+ "Add": "+",
+ "Sub": "-",
+ "Mult": "*",
+ "MatMult": "@",
+ "Div": "/",
+ "Mod": "%",
+ "LShift": "<<",
+ "RShift": ">>",
+ "BitOr": "|",
+ "BitXor": "^",
+ "BitAnd": "&",
+ "FloorDiv": "//",
+ "Pow": "**",
+ }
+
+ def _BinOp(self, t):
+ self.write("(")
+ self.dispatch(t.left)
+ self.write(" " + self.binop[t.op.__class__.__name__] + " ")
+ self.dispatch(t.right)
+ self.write(")")
+
+ cmpops = {
+ "Eq": "==",
+ "NotEq": "!=",
+ "Lt": "<",
+ "LtE": "<=",
+ "Gt": ">",
+ "GtE": ">=",
+ "Is": "is",
+ "IsNot": "is not",
+ "In": "in",
+ "NotIn": "not in",
+ }
+
+ def _Compare(self, t):
+ self.write("(")
+ self.dispatch(t.left)
+ for o, e in zip(t.ops, t.comparators):
+ self.write(" " + self.cmpops[o.__class__.__name__] + " ")
+ self.dispatch(e)
+ self.write(")")
+
+ boolops = {ast.And: "and", ast.Or: "or"}
+
+ def _BoolOp(self, t):
+ self.write("(")
+ s = " %s " % self.boolops[t.op.__class__]
+ interleave(lambda: self.write(s), self.dispatch, t.values)
+ self.write(")")
+
+ def _Attribute(self, t):
+ self.dispatch(t.value)
+ # Special case: 3.__abs__() is a syntax error, so if t.value
+ # is an integer literal then we need to either parenthesize
+ # it or add an extra space to get 3 .__abs__().
+ if isinstance(t.value, ast.Constant) and isinstance(t.value.value, int):
+ self.write(" ")
+ self.write(".")
+ self.write(t.attr)
+
+ def _Call(self, t):
+ self.dispatch(t.func)
+ self.write("(")
+ comma = False
+ for e in t.args:
+ if comma:
+ self.write(", ")
+ else:
+ comma = True
+ self.dispatch(e)
+ for e in t.keywords:
+ if comma:
+ self.write(", ")
+ else:
+ comma = True
+ self.dispatch(e)
+ self.write(")")
+
+ def _Subscript(self, t):
+ self.dispatch(t.value)
+ self.write("[")
+ if (
+ isinstance(t.slice, ast.Index)
+ and isinstance(t.slice.value, ast.Tuple)
+ and t.slice.value.elts
+ ):
+ if len(t.slice.value.elts) == 1:
+ elt = t.slice.value.elts[0]
+ self.dispatch(elt)
+ self.write(",")
+ else:
+ interleave(lambda: self.write(", "), self.dispatch, t.slice.value.elts)
+ else:
+ self.dispatch(t.slice)
+ self.write("]")
+
+ def _Starred(self, t):
+ self.write("*")
+ self.dispatch(t.value)
+
+ # slice
+ def _Ellipsis(self, t):
+ self.write("...")
+
+ def _Index(self, t):
+ self.dispatch(t.value)
+
+ def _Slice(self, t):
+ if t.lower:
+ self.dispatch(t.lower)
+ self.write(":")
+ if t.upper:
+ self.dispatch(t.upper)
+ if t.step:
+ self.write(":")
+ self.dispatch(t.step)
+
+ def _ExtSlice(self, t):
+ if len(t.dims) == 1:
+ elt = t.dims[0]
+ self.dispatch(elt)
+ self.write(",")
+ else:
+ interleave(lambda: self.write(", "), self.dispatch, t.dims)
+
+ # argument
+ def _arg(self, t):
+ self.write(t.arg)
+ if t.annotation:
+ self.write(": ")
+ self.dispatch(t.annotation)
+
+ # others
+ def _arguments(self, t):
+ first = True
+ # normal arguments
+ all_args = t.posonlyargs + t.args
+ defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults
+ for index, elements in enumerate(zip(all_args, defaults), 1):
+ a, d = elements
+ if first:
+ first = False
+ else:
+ self.write(", ")
+ self.dispatch(a)
+ if d:
+ self.write("=")
+ self.dispatch(d)
+ if index == len(t.posonlyargs):
+ self.write(", /")
+
+ # varargs, or bare '*' if no varargs but keyword-only arguments present
+ if t.vararg or t.kwonlyargs:
+ if first:
+ first = False
+ else:
+ self.write(", ")
+ self.write("*")
+ if t.vararg:
+ self.write(t.vararg.arg)
+ if t.vararg.annotation:
+ self.write(": ")
+ self.dispatch(t.vararg.annotation)
+
+ # keyword-only arguments
+ if t.kwonlyargs:
+ for a, d in zip(t.kwonlyargs, t.kw_defaults):
+ if first:
+ first = False
+ else:
+ self.write(", ")
+ self.dispatch(a)
+ if d:
+ self.write("=")
+ self.dispatch(d)
+
+ # kwargs
+ if t.kwarg:
+ if first:
+ first = False
+ else:
+ self.write(", ")
+ self.write("**" + t.kwarg.arg)
+ if t.kwarg.annotation:
+ self.write(": ")
+ self.dispatch(t.kwarg.annotation)
+
+ def _keyword(self, t):
+ if t.arg is None:
+ self.write("**")
+ else:
+ self.write(t.arg)
+ self.write("=")
+ self.dispatch(t.value)
+
+ def _Lambda(self, t):
+ self.write("(")
+ self.write("lambda ")
+ self.dispatch(t.args)
+ self.write(": ")
+ self.dispatch(t.body)
+ self.write(")")
+
+ def _alias(self, t):
+ self.write(t.name)
+ if t.asname:
+ self.write(" as " + t.asname)
+
+ def _withitem(self, t):
+ self.dispatch(t.context_expr)
+ if t.optional_vars:
+ self.write(" as ")
+ self.dispatch(t.optional_vars)
+
+
+def roundtrip(filename, output=sys.stdout):
+ with open(filename, "rb") as pyfile:
+ encoding = tokenize.detect_encoding(pyfile.readline)[0]
+ with open(filename, "r", encoding=encoding) as pyfile:
+ source = pyfile.read()
+ tree = compile(source, filename, "exec", ast.PyCF_ONLY_AST)
+ Unparser(tree, output)
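A minimal round-trip sketch of how the unparser above might be exercised outside of `roundtrip`; it assumes the `Unparser` class defined earlier in this module is in scope, and the sample source string is arbitrary.

import ast
import io

# assumes Unparser from the module above is importable / in scope
source = "def greet(name: str = 'world') -> str:\n    return f'hello {name}'\n"
tree = ast.parse(source)

buffer = io.StringIO()
Unparser(tree, buffer)  # regenerate Python source from the AST
regenerated = buffer.getvalue()

# The regenerated text may differ cosmetically (extra parentheses), but it
# should still be valid Python that parses back into an equivalent tree.
ast.parse(regenerated)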
diff --git a/libs/community/langchain_community/tools/edenai/__init__.py b/libs/community/langchain_community/tools/edenai/__init__.py
new file mode 100644
index 00000000000..e6084d51574
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/__init__.py
@@ -0,0 +1,34 @@
+"""Edenai Tools."""
+from langchain_community.tools.edenai.audio_speech_to_text import (
+ EdenAiSpeechToTextTool,
+)
+from langchain_community.tools.edenai.audio_text_to_speech import (
+ EdenAiTextToSpeechTool,
+)
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+from langchain_community.tools.edenai.image_explicitcontent import (
+ EdenAiExplicitImageTool,
+)
+from langchain_community.tools.edenai.image_objectdetection import (
+ EdenAiObjectDetectionTool,
+)
+from langchain_community.tools.edenai.ocr_identityparser import (
+ EdenAiParsingIDTool,
+)
+from langchain_community.tools.edenai.ocr_invoiceparser import (
+ EdenAiParsingInvoiceTool,
+)
+from langchain_community.tools.edenai.text_moderation import (
+ EdenAiTextModerationTool,
+)
+
+__all__ = [
+ "EdenAiExplicitImageTool",
+ "EdenAiObjectDetectionTool",
+ "EdenAiParsingIDTool",
+ "EdenAiParsingInvoiceTool",
+ "EdenAiTextToSpeechTool",
+ "EdenAiSpeechToTextTool",
+ "EdenAiTextModerationTool",
+ "EdenaiTool",
+]
diff --git a/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py
new file mode 100644
index 00000000000..00f158ed53b
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import json
+import logging
+import time
+from typing import List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import validator
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiSpeechToTextTool(EdenaiTool):
+ """Tool that queries the Eden AI Speech To Text API.
+
+    For API reference, check the Eden AI documentation:
+ https://app.edenai.run/bricks/speech/asynchronous-speech-to-text.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ edenai_api_key: Optional[str] = None
+
+ name = "edenai_speech_to_text"
+ description = (
+ "A wrapper around edenai Services speech to text "
+ "Useful for when you have to convert audio to text."
+ "Input should be a url to an audio file."
+ )
+ is_async = True
+
+ language: Optional[str] = "en"
+ speakers: Optional[int]
+ profanity_filter: bool = False
+ custom_vocabulary: Optional[List[str]]
+
+ feature: str = "audio"
+ subfeature: str = "speech_to_text_async"
+ base_url = "https://api.edenai.run/v2/audio/speech_to_text_async/"
+
+ @validator("providers")
+ def check_only_one_provider_selected(cls, v: List[str]) -> List[str]:
+ """
+ This tool has no feature to combine providers results.
+ Therefore we only allow one provider
+ """
+ if len(v) > 1:
+ raise ValueError(
+ "Please select only one provider. "
+ "The feature to combine providers results is not available "
+ "for this tool."
+ )
+ return v
+
+ def _wait_processing(self, url: str) -> requests.Response:
+ for _ in range(10):
+ time.sleep(1)
+ audio_analysis_result = self._get_edenai(url)
+ temp = audio_analysis_result.json()
+ if temp["status"] == "finished":
+ if temp["results"][self.providers[0]]["error"] is not None:
+ raise Exception(
+ f"""EdenAI returned an unexpected response
+ {temp['results'][self.providers[0]]['error']}"""
+ )
+ else:
+ return audio_analysis_result
+
+ raise Exception("Edenai speech to text job id processing Timed out")
+
+ def _parse_response(self, response: dict) -> str:
+ return response["public_id"]
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ all_params = {
+ "file_url": query,
+ "language": self.language,
+ "speakers": self.speakers,
+ "profanity_filter": self.profanity_filter,
+ "custom_vocabulary": self.custom_vocabulary,
+ }
+
+        # filter out None values so they are not sent to the API
+ query_params = {k: v for k, v in all_params.items() if v is not None}
+
+ job_id = self._call_eden_ai(query_params)
+ url = self.base_url + job_id
+ audio_analysis_result = self._wait_processing(url)
+ result = audio_analysis_result.text
+ formatted_text = json.loads(result)
+ return formatted_text["results"][self.providers[0]]["text"]
diff --git a/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py
new file mode 100644
index 00000000000..575b9a70b52
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import logging
+from typing import Dict, List, Literal, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field, root_validator, validator
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiTextToSpeechTool(EdenaiTool):
+ """Tool that queries the Eden AI Text to speech API.
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/audio_text_to_speech_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_text_to_speech"
+ description = (
+ "A wrapper around edenai Services text to speech."
+ "Useful for when you need to convert text to speech."
+ """the output is a string representing the URL of the audio file,
+ or the path to the downloaded wav file """
+ )
+
+ language: Optional[str] = "en"
+ """
+ language of the text passed to the model.
+ """
+
+ # optional params see api documentation for more info
+ return_type: Literal["url", "wav"] = "url"
+ rate: Optional[int]
+ pitch: Optional[int]
+ volume: Optional[int]
+ audio_format: Optional[str]
+ sampling_rate: Optional[int]
+ voice_models: Dict[str, str] = Field(default_factory=dict)
+
+ voice: Literal["MALE", "FEMALE"]
+ """voice option : 'MALE' or 'FEMALE' """
+
+ feature: str = "audio"
+ subfeature: str = "text_to_speech"
+
+ @validator("providers")
+ def check_only_one_provider_selected(cls, v: List[str]) -> List[str]:
+ """
+ This tool has no feature to combine providers results.
+ Therefore we only allow one provider
+ """
+ if len(v) > 1:
+ raise ValueError(
+ "Please select only one provider. "
+ "The feature to combine providers results is not available "
+ "for this tool."
+ )
+ return v
+
+ @root_validator
+ def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
+ for key in values.get("voice_models", {}).keys():
+ if key not in values.get("providers", []):
+ raise ValueError(
+ "voice_model should be formatted like this "
+ "{: }"
+ )
+ return values
+
+ def _download_wav(self, url: str, save_path: str) -> None:
+ response = requests.get(url)
+ if response.status_code == 200:
+ with open(save_path, "wb") as f:
+ f.write(response.content)
+ else:
+ raise ValueError("Error while downloading wav file")
+
+ def _parse_response(self, response: list) -> str:
+ result = response[0]
+ if self.return_type == "url":
+ return result["audio_resource_url"]
+ else:
+ self._download_wav(result["audio_resource_url"], "audio.wav")
+ return "audio.wav"
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ all_params = {
+ "text": query,
+ "language": self.language,
+ "option": self.voice,
+ "return_type": self.return_type,
+ "rate": self.rate,
+ "pitch": self.pitch,
+ "volume": self.volume,
+ "audio_format": self.audio_format,
+ "sampling_rate": self.sampling_rate,
+ "settings": self.voice_models,
+ }
+
+        # filter out None values so they are not sent to the API
+ query_params = {k: v for k, v in all_params.items() if v is not None}
+
+ return self._call_eden_ai(query_params)
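A similar sketch for text to speech, showing the `voice` and `return_type` options; the provider is again a placeholder and `EDENAI_API_KEY` is assumed to be set.

from langchain_community.tools.edenai import EdenAiTextToSpeechTool

tts_tool = EdenAiTextToSpeechTool(
    providers=["google"],  # illustrative provider
    language="en",
    voice="FEMALE",
    return_type="wav",  # download audio.wav locally instead of returning a URL
)
audio_path = tts_tool.run("Hello from the Eden AI text to speech tool.")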
diff --git a/libs/community/langchain_community/tools/edenai/edenai_base_tool.py b/libs/community/langchain_community/tools/edenai/edenai_base_tool.py
new file mode 100644
index 00000000000..e38f6936cc1
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/edenai_base_tool.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import logging
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class EdenaiTool(BaseTool):
+
+ """
+    Base tool for all Eden AI tools.
+
+    To use, you should have
+    the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+ """
+
+ feature: str
+ subfeature: str
+ edenai_api_key: Optional[str] = None
+ is_async: bool = False
+
+ providers: List[str]
+ """provider to use for the API call."""
+
+ @root_validator(allow_reuse=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["edenai_api_key"] = get_from_dict_or_env(
+ values, "edenai_api_key", "EDENAI_API_KEY"
+ )
+ return values
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain/{__version__}"
+
+ def _call_eden_ai(self, query_params: Dict[str, Any]) -> str:
+ """
+ Make an API call to the EdenAI service with the specified query parameters.
+
+ Args:
+ query_params (dict): The parameters to include in the API call.
+
+ Returns:
+            str: The parsed response from the EdenAI API call.
+
+ """
+
+        # make the API call
+
+ headers = {
+ "Authorization": f"Bearer {self.edenai_api_key}",
+ "User-Agent": self.get_user_agent(),
+ }
+
+ url = f"https://api.edenai.run/v2/{self.feature}/{self.subfeature}"
+
+ payload = {
+ "providers": str(self.providers),
+ "response_as_dict": False,
+ "attributes_as_list": True,
+ "show_original_response": False,
+ }
+
+ payload.update(query_params)
+
+ response = requests.post(url, json=payload, headers=headers)
+
+ self._raise_on_error(response)
+
+ try:
+ return self._parse_response(response.json())
+ except Exception as e:
+ raise RuntimeError(f"An error occurred while running tool: {e}")
+
+ def _raise_on_error(self, response: requests.Response) -> None:
+ if response.status_code >= 500:
+ raise Exception(f"EdenAI Server: Error {response.status_code}")
+ elif response.status_code >= 400:
+ raise ValueError(f"EdenAI received an invalid payload: {response.text}")
+ elif response.status_code != 200:
+ raise Exception(
+ f"EdenAI returned an unexpected response with status "
+ f"{response.status_code}: {response.text}"
+ )
+
+ # case where edenai call succeeded but provider returned an error
+ # (eg: rate limit, server error, etc.)
+ if self.is_async is False:
+ # async call are different and only return a job_id,
+ # not the provider response directly
+ provider_response = response.json()[0]
+ if provider_response.get("status") == "fail":
+ err_msg = provider_response["error"]["message"]
+ raise ValueError(err_msg)
+
+ @abstractmethod
+ def _run(
+ self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ pass
+
+ @abstractmethod
+ def _parse_response(self, response: Any) -> str:
+ """Take a dict response and condense it's data in a human readable string"""
+ pass
+
+ def _get_edenai(self, url: str) -> requests.Response:
+ headers = {
+ "accept": "application/json",
+ "authorization": f"Bearer {self.edenai_api_key}",
+ "User-Agent": self.get_user_agent(),
+ }
+
+ response = requests.get(url, headers=headers)
+
+ self._raise_on_error(response)
+
+ return response
+
+ def _parse_json_multilevel(
+ self, extracted_data: dict, formatted_list: list, level: int = 0
+ ) -> None:
+ for section, subsections in extracted_data.items():
+ indentation = " " * level
+ if isinstance(subsections, str):
+ subsections = subsections.replace("\n", ",")
+ formatted_list.append(f"{indentation}{section} : {subsections}")
+
+ elif isinstance(subsections, list):
+ formatted_list.append(f"{indentation}{section} : ")
+ self._list_handling(subsections, formatted_list, level + 1)
+
+ elif isinstance(subsections, dict):
+ formatted_list.append(f"{indentation}{section} : ")
+ self._parse_json_multilevel(subsections, formatted_list, level + 1)
+
+ def _list_handling(
+ self, subsection_list: list, formatted_list: list, level: int
+ ) -> None:
+ for list_item in subsection_list:
+ if isinstance(list_item, dict):
+ self._parse_json_multilevel(list_item, formatted_list, level)
+
+ elif isinstance(list_item, list):
+ self._list_handling(list_item, formatted_list, level + 1)
+
+ else:
+ formatted_list.append(f"{' ' * level}{list_item}")
diff --git a/libs/community/langchain_community/tools/edenai/image_explicitcontent.py b/libs/community/langchain_community/tools/edenai/image_explicitcontent.py
new file mode 100644
index 00000000000..fbeb4c1ec19
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/image_explicitcontent.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiExplicitImageTool(EdenaiTool):
+
+ """Tool that queries the Eden AI Explicit image detection.
+
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/image_explicit_content_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_image_explicit_content_detection"
+
+ description = (
+ "A wrapper around edenai Services Explicit image detection. "
+ """Useful for when you have to extract Explicit Content from images.
+ it detects adult only content in images,
+ that is generally inappropriate for people under
+ the age of 18 and includes nudity, sexual activity,
+ pornography, violence, gore content, etc."""
+ "Input should be the string url of the image ."
+ )
+
+ combine_available = True
+ feature = "image"
+ subfeature = "explicit_content"
+
+ def _parse_json(self, json_data: dict) -> str:
+ result_str = f"nsfw_likelihood: {json_data['nsfw_likelihood']}\n"
+ for idx, found_obj in enumerate(json_data["items"]):
+ label = found_obj["label"].lower()
+ likelihood = found_obj["likelihood"]
+ result_str += f"{idx}: {label} likelihood {likelihood},\n"
+
+ return result_str[:-2]
+
+ def _parse_response(self, json_data: list) -> str:
+ if len(json_data) == 1:
+ result = self._parse_json(json_data[0])
+ else:
+ for entry in json_data:
+ if entry.get("provider") == "eden-ai":
+ result = self._parse_json(entry)
+
+ return result
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ query_params = {"file_url": query, "attributes_as_list": False}
+ return self._call_eden_ai(query_params)
diff --git a/libs/community/langchain_community/tools/edenai/image_objectdetection.py b/libs/community/langchain_community/tools/edenai/image_objectdetection.py
new file mode 100644
index 00000000000..e40d464d5c7
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/image_objectdetection.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiObjectDetectionTool(EdenaiTool):
+ """Tool that queries the Eden AI Object detection API.
+
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/image_object_detection_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_object_detection"
+
+ description = (
+ "A wrapper around edenai Services Object Detection . "
+ """Useful for when you have to do an to identify and locate
+ (with bounding boxes) objects in an image """
+ "Input should be the string url of the image to identify."
+ )
+
+ show_positions: bool = False
+
+ feature = "image"
+ subfeature = "object_detection"
+
+ def _parse_json(self, json_data: dict) -> str:
+ result = []
+ label_info = []
+
+ for found_obj in json_data["items"]:
+ label_str = f"{found_obj['label']} - Confidence {found_obj['confidence']}"
+ x_min = found_obj.get("x_min")
+ x_max = found_obj.get("x_max")
+ y_min = found_obj.get("y_min")
+ y_max = found_obj.get("y_max")
+ if self.show_positions and all(
+ [x_min, x_max, y_min, y_max]
+ ): # some providers don't return positions
+ label_str += f""",at the position x_min: {x_min}, x_max: {x_max},
+ y_min: {y_min}, y_max: {y_max}"""
+ label_info.append(label_str)
+
+ result.append("\n".join(label_info))
+ return "\n\n".join(result)
+
+ def _parse_response(self, response: list) -> str:
+ if len(response) == 1:
+ result = self._parse_json(response[0])
+ else:
+ for entry in response:
+ if entry.get("provider") == "eden-ai":
+ result = self._parse_json(entry)
+
+ return result
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ query_params = {"file_url": query, "attributes_as_list": False}
+ return self._call_eden_ai(query_params)
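A short usage sketch: `show_positions=True` appends bounding-box coordinates when the provider returns them. The provider and image URL are placeholders, and `EDENAI_API_KEY` must be set.

from langchain_community.tools.edenai import EdenAiObjectDetectionTool

detector = EdenAiObjectDetectionTool(providers=["google"], show_positions=True)
print(detector.run("https://example.com/street-scene.jpg"))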
diff --git a/libs/community/langchain_community/tools/edenai/ocr_identityparser.py b/libs/community/langchain_community/tools/edenai/ocr_identityparser.py
new file mode 100644
index 00000000000..6af27c92fd6
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/ocr_identityparser.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiParsingIDTool(EdenaiTool):
+ """Tool that queries the Eden AI Identity parsing API.
+
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/ocr_identity_parser_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_identity_parsing"
+
+ description = (
+ "A wrapper around edenai Services Identity parsing. "
+ "Useful for when you have to extract information from an ID Document "
+ "Input should be the string url of the document to parse."
+ )
+
+ feature = "ocr"
+ subfeature = "identity_parser"
+
+ language: Optional[str] = None
+ """
+ language of the text passed to the model.
+ """
+
+ def _parse_response(self, response: list) -> str:
+ formatted_list: list = []
+
+ if len(response) == 1:
+ self._parse_json_multilevel(
+ response[0]["extracted_data"][0], formatted_list
+ )
+ else:
+ for entry in response:
+ if entry.get("provider") == "eden-ai":
+ self._parse_json_multilevel(
+ entry["extracted_data"][0], formatted_list
+ )
+
+ return "\n".join(formatted_list)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ query_params = {
+ "file_url": query,
+ "language": self.language,
+ "attributes_as_list": False,
+ }
+
+ return self._call_eden_ai(query_params)
diff --git a/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py b/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py
new file mode 100644
index 00000000000..6b2e7d8befe
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiParsingInvoiceTool(EdenaiTool):
+ """Tool that queries the Eden AI Invoice parsing API.
+
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/ocr_invoice_parser_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_invoice_parsing"
+
+ description = (
+ "A wrapper around edenai Services invoice parsing. "
+ """Useful for when you have to extract information from
+ an image it enables to take invoices
+ in a variety of formats and returns the data in contains
+ (items, prices, addresses, vendor name, etc.)
+ in a structured format to automate the invoice processing """
+ "Input should be the string url of the document to parse."
+ )
+
+ language: Optional[str] = None
+ """
+ language of the image passed to the model.
+ """
+
+ feature = "ocr"
+ subfeature = "invoice_parser"
+
+ def _parse_response(self, response: list) -> str:
+ formatted_list: list = []
+
+ if len(response) == 1:
+ self._parse_json_multilevel(
+ response[0]["extracted_data"][0], formatted_list
+ )
+ else:
+ for entry in response:
+ if entry.get("provider") == "eden-ai":
+ self._parse_json_multilevel(
+ entry["extracted_data"][0], formatted_list
+ )
+
+ return "\n".join(formatted_list)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ query_params = {
+ "file_url": query,
+ "language": self.language,
+ "attributes_as_list": False,
+ }
+
+ return self._call_eden_ai(query_params)
diff --git a/libs/community/langchain_community/tools/edenai/text_moderation.py b/libs/community/langchain_community/tools/edenai/text_moderation.py
new file mode 100644
index 00000000000..44d1308117f
--- /dev/null
+++ b/libs/community/langchain_community/tools/edenai/text_moderation.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
+
+logger = logging.getLogger(__name__)
+
+
+class EdenAiTextModerationTool(EdenaiTool):
+ """Tool that queries the Eden AI Explicit text detection.
+
+    For API reference, check the Eden AI documentation:
+ https://docs.edenai.co/reference/image_explicit_content_create.
+
+ To use, you should have
+ the environment variable ``EDENAI_API_KEY`` set with your API token.
+ You can find your token here: https://app.edenai.run/admin/account/settings
+
+ """
+
+ name = "edenai_explicit_content_detection_text"
+
+ description = (
+ "A wrapper around edenai Services explicit content detection for text. "
+ """Useful for when you have to scan text for offensive,
+ sexually explicit or suggestive content,
+ it checks also if there is any content of self-harm,
+ violence, racist or hate speech."""
+ """the structure of the output is :
+ 'the type of the explicit content : the likelihood of it being explicit'
+ the likelihood is a number
+ between 1 and 5, 1 being the lowest and 5 the highest.
+ something is explicit if the likelihood is equal or higher than 3.
+ for example :
+ nsfw_likelihood: 1
+ this is not explicit.
+ for example :
+ nsfw_likelihood: 3
+ this is explicit.
+ """
+ "Input should be a string."
+ )
+
+ language: str
+
+ feature: str = "text"
+ subfeature: str = "moderation"
+
+ def _parse_response(self, response: list) -> str:
+ formatted_result = []
+ for result in response:
+ if "nsfw_likelihood" in result.keys():
+ formatted_result.append(
+ "nsfw_likelihood: " + str(result["nsfw_likelihood"])
+ )
+
+ for label, likelihood in zip(result["label"], result["likelihood"]):
+ formatted_result.append(f'"{label}": {str(likelihood)}')
+
+ return "\n".join(formatted_result)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ query_params = {"text": query, "language": self.language}
+ return self._call_eden_ai(query_params)
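A usage sketch for the text moderation tool; `language` is required here, the provider name is a placeholder, and `EDENAI_API_KEY` must be set.

from langchain_community.tools.edenai import EdenAiTextModerationTool

moderation_tool = EdenAiTextModerationTool(providers=["openai"], language="en")
report = moderation_tool.run("some text to screen")
# report is a newline-separated summary such as "nsfw_likelihood: 1" plus per-label scores
print(report)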
diff --git a/libs/community/langchain_community/tools/eleven_labs/__init__.py b/libs/community/langchain_community/tools/eleven_labs/__init__.py
new file mode 100644
index 00000000000..3cb16a41603
--- /dev/null
+++ b/libs/community/langchain_community/tools/eleven_labs/__init__.py
@@ -0,0 +1,5 @@
+"""Eleven Labs Services Tools."""
+
+from langchain_community.tools.eleven_labs.text2speech import ElevenLabsText2SpeechTool
+
+__all__ = ["ElevenLabsText2SpeechTool"]
diff --git a/libs/community/langchain_community/tools/eleven_labs/models.py b/libs/community/langchain_community/tools/eleven_labs/models.py
new file mode 100644
index 00000000000..c977b2972f7
--- /dev/null
+++ b/libs/community/langchain_community/tools/eleven_labs/models.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class ElevenLabsModel(str, Enum):
+ """Models available for Eleven Labs Text2Speech."""
+
+ MULTI_LINGUAL = "eleven_multilingual_v1"
+ MONO_LINGUAL = "eleven_monolingual_v1"
diff --git a/libs/community/langchain_community/tools/eleven_labs/text2speech.py b/libs/community/langchain_community/tools/eleven_labs/text2speech.py
new file mode 100644
index 00000000000..d1a68ba7a5c
--- /dev/null
+++ b/libs/community/langchain_community/tools/eleven_labs/text2speech.py
@@ -0,0 +1,80 @@
+import tempfile
+from enum import Enum
+from typing import Any, Dict, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+
+
+def _import_elevenlabs() -> Any:
+ try:
+ import elevenlabs
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import elevenlabs, please install `pip install elevenlabs`."
+ ) from e
+ return elevenlabs
+
+
+class ElevenLabsModel(str, Enum):
+ """Models available for Eleven Labs Text2Speech."""
+
+ MULTI_LINGUAL = "eleven_multilingual_v1"
+ MONO_LINGUAL = "eleven_monolingual_v1"
+
+
+class ElevenLabsText2SpeechTool(BaseTool):
+ """Tool that queries the Eleven Labs Text2Speech API.
+
+ In order to set this up, follow instructions at:
+ https://docs.elevenlabs.io/welcome/introduction
+ """
+
+ model: Union[ElevenLabsModel, str] = ElevenLabsModel.MULTI_LINGUAL
+
+ name: str = "eleven_labs_text2speech"
+ description: str = (
+ "A wrapper around Eleven Labs Text2Speech. "
+ "Useful for when you need to convert text to speech. "
+ "It supports multiple languages, including English, German, Polish, "
+ "Spanish, Italian, French, Portuguese, and Hindi. "
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ _ = get_from_dict_or_env(values, "eleven_api_key", "ELEVEN_API_KEY")
+
+ return values
+
+ def _run(
+ self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Use the tool."""
+ elevenlabs = _import_elevenlabs()
+ try:
+ speech = elevenlabs.generate(text=query, model=self.model)
+ with tempfile.NamedTemporaryFile(
+ mode="bx", suffix=".wav", delete=False
+ ) as f:
+ f.write(speech)
+ return f.name
+ except Exception as e:
+ raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")
+
+ def play(self, speech_file: str) -> None:
+ """Play the text as speech."""
+ elevenlabs = _import_elevenlabs()
+ with open(speech_file, mode="rb") as f:
+ speech = f.read()
+
+ elevenlabs.play(speech)
+
+ def stream_speech(self, query: str) -> None:
+ """Stream the text as speech as it is generated.
+ Play the text in your speakers."""
+ elevenlabs = _import_elevenlabs()
+ speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
+ elevenlabs.stream(speech_stream)
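A usage sketch for the Eleven Labs tool above, assuming the `elevenlabs` package is installed and `ELEVEN_API_KEY` is set in the environment.

from langchain_community.tools.eleven_labs import ElevenLabsText2SpeechTool

tts = ElevenLabsText2SpeechTool()
speech_file = tts.run("Hello from LangChain!")  # path to a temporary .wav file
tts.play(speech_file)

# Or stream the audio without writing it to disk first:
tts.stream_speech("Streaming this sentence as it is generated.")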
diff --git a/libs/community/langchain_community/tools/file_management/__init__.py b/libs/community/langchain_community/tools/file_management/__init__.py
new file mode 100644
index 00000000000..395f5d5ea60
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/__init__.py
@@ -0,0 +1,19 @@
+"""File Management Tools."""
+
+from langchain_community.tools.file_management.copy import CopyFileTool
+from langchain_community.tools.file_management.delete import DeleteFileTool
+from langchain_community.tools.file_management.file_search import FileSearchTool
+from langchain_community.tools.file_management.list_dir import ListDirectoryTool
+from langchain_community.tools.file_management.move import MoveFileTool
+from langchain_community.tools.file_management.read import ReadFileTool
+from langchain_community.tools.file_management.write import WriteFileTool
+
+__all__ = [
+ "CopyFileTool",
+ "DeleteFileTool",
+ "FileSearchTool",
+ "MoveFileTool",
+ "ReadFileTool",
+ "WriteFileTool",
+ "ListDirectoryTool",
+]
diff --git a/libs/community/langchain_community/tools/file_management/copy.py b/libs/community/langchain_community/tools/file_management/copy.py
new file mode 100644
index 00000000000..f91081003d1
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/copy.py
@@ -0,0 +1,53 @@
+import shutil
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class FileCopyInput(BaseModel):
+ """Input for CopyFileTool."""
+
+ source_path: str = Field(..., description="Path of the file to copy")
+ destination_path: str = Field(..., description="Path to save the copied file")
+
+
+class CopyFileTool(BaseFileToolMixin, BaseTool):
+ """Tool that copies a file."""
+
+ name: str = "copy_file"
+ args_schema: Type[BaseModel] = FileCopyInput
+ description: str = "Create a copy of a file in a specified location"
+
+ def _run(
+ self,
+ source_path: str,
+ destination_path: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ source_path_ = self.get_relative_path(source_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(
+ arg_name="source_path", value=source_path
+ )
+ try:
+ destination_path_ = self.get_relative_path(destination_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(
+ arg_name="destination_path", value=destination_path
+ )
+ try:
+ shutil.copy2(source_path_, destination_path_, follow_symlinks=False)
+ return f"File copied successfully from {source_path} to {destination_path}."
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/delete.py b/libs/community/langchain_community/tools/file_management/delete.py
new file mode 100644
index 00000000000..c2694762aed
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/delete.py
@@ -0,0 +1,45 @@
+import os
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class FileDeleteInput(BaseModel):
+ """Input for DeleteFileTool."""
+
+ file_path: str = Field(..., description="Path of the file to delete")
+
+
+class DeleteFileTool(BaseFileToolMixin, BaseTool):
+ """Tool that deletes a file."""
+
+ name: str = "file_delete"
+ args_schema: Type[BaseModel] = FileDeleteInput
+ description: str = "Delete a file"
+
+ def _run(
+ self,
+ file_path: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ file_path_ = self.get_relative_path(file_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path)
+ if not file_path_.exists():
+ return f"Error: no such file or directory: {file_path}"
+ try:
+ os.remove(file_path_)
+ return f"File deleted successfully: {file_path}."
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/file_search.py b/libs/community/langchain_community/tools/file_management/file_search.py
new file mode 100644
index 00000000000..c77abd0bef9
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/file_search.py
@@ -0,0 +1,62 @@
+import fnmatch
+import os
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class FileSearchInput(BaseModel):
+ """Input for FileSearchTool."""
+
+ dir_path: str = Field(
+ default=".",
+ description="Subdirectory to search in.",
+ )
+ pattern: str = Field(
+ ...,
+ description="Unix shell regex, where * matches everything.",
+ )
+
+
+class FileSearchTool(BaseFileToolMixin, BaseTool):
+ """Tool that searches for files in a subdirectory that match a regex pattern."""
+
+ name: str = "file_search"
+ args_schema: Type[BaseModel] = FileSearchInput
+ description: str = (
+ "Recursively search for files in a subdirectory that match the regex pattern"
+ )
+
+ def _run(
+ self,
+ pattern: str,
+ dir_path: str = ".",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ dir_path_ = self.get_relative_path(dir_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value=dir_path)
+ matches = []
+ try:
+ for root, _, filenames in os.walk(dir_path_):
+ for filename in fnmatch.filter(filenames, pattern):
+ absolute_path = os.path.join(root, filename)
+ relative_path = os.path.relpath(absolute_path, dir_path_)
+ matches.append(relative_path)
+ if matches:
+ return "\n".join(matches)
+ else:
+ return f"No files found for pattern {pattern} in directory {dir_path}"
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/list_dir.py b/libs/community/langchain_community/tools/file_management/list_dir.py
new file mode 100644
index 00000000000..d8b700134ae
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/list_dir.py
@@ -0,0 +1,46 @@
+import os
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class DirectoryListingInput(BaseModel):
+ """Input for ListDirectoryTool."""
+
+ dir_path: str = Field(default=".", description="Subdirectory to list.")
+
+
+class ListDirectoryTool(BaseFileToolMixin, BaseTool):
+ """Tool that lists files and directories in a specified folder."""
+
+ name: str = "list_directory"
+ args_schema: Type[BaseModel] = DirectoryListingInput
+ description: str = "List files and directories in a specified folder"
+
+ def _run(
+ self,
+ dir_path: str = ".",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ dir_path_ = self.get_relative_path(dir_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value=dir_path)
+ try:
+ entries = os.listdir(dir_path_)
+ if entries:
+ return "\n".join(entries)
+ else:
+ return f"No files found in directory {dir_path}"
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/move.py b/libs/community/langchain_community/tools/file_management/move.py
new file mode 100644
index 00000000000..fc3bb778d8e
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/move.py
@@ -0,0 +1,56 @@
+import shutil
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class FileMoveInput(BaseModel):
+ """Input for MoveFileTool."""
+
+ source_path: str = Field(..., description="Path of the file to move")
+ destination_path: str = Field(..., description="New path for the moved file")
+
+
+class MoveFileTool(BaseFileToolMixin, BaseTool):
+ """Tool that moves a file."""
+
+ name: str = "move_file"
+ args_schema: Type[BaseModel] = FileMoveInput
+ description: str = "Move or rename a file from one location to another"
+
+ def _run(
+ self,
+ source_path: str,
+ destination_path: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ source_path_ = self.get_relative_path(source_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(
+ arg_name="source_path", value=source_path
+ )
+ try:
+ destination_path_ = self.get_relative_path(destination_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(
+ arg_name="destination_path_", value=destination_path_
+ )
+ if not source_path_.exists():
+ return f"Error: no such file or directory {source_path}"
+ try:
+ # shutil.move expects str args in 3.8
+ shutil.move(str(source_path_), destination_path_)
+ return f"File moved successfully from {source_path} to {destination_path}."
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/read.py b/libs/community/langchain_community/tools/file_management/read.py
new file mode 100644
index 00000000000..42cc1515e7e
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/read.py
@@ -0,0 +1,45 @@
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class ReadFileInput(BaseModel):
+ """Input for ReadFileTool."""
+
+ file_path: str = Field(..., description="name of file")
+
+
+class ReadFileTool(BaseFileToolMixin, BaseTool):
+ """Tool that reads a file."""
+
+ name: str = "read_file"
+ args_schema: Type[BaseModel] = ReadFileInput
+ description: str = "Read file from disk"
+
+ def _run(
+ self,
+ file_path: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ read_path = self.get_relative_path(file_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path)
+ if not read_path.exists():
+ return f"Error: no such file or directory: {file_path}"
+ try:
+ with read_path.open("r", encoding="utf-8") as f:
+ content = f.read()
+ return content
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
diff --git a/libs/community/langchain_community/tools/file_management/utils.py b/libs/community/langchain_community/tools/file_management/utils.py
new file mode 100644
index 00000000000..b2a3632ecaa
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/utils.py
@@ -0,0 +1,54 @@
+import sys
+from pathlib import Path
+from typing import Optional
+
+from langchain_core.pydantic_v1 import BaseModel
+
+
+def is_relative_to(path: Path, root: Path) -> bool:
+ """Check if path is relative to root."""
+ if sys.version_info >= (3, 9):
+        # Path.is_relative_to is available in Python 3.9+, so no fallback is needed.
+ return path.is_relative_to(root)
+ try:
+ path.relative_to(root)
+ return True
+ except ValueError:
+ return False
+
+
+INVALID_PATH_TEMPLATE = (
+ "Error: Access denied to {arg_name}: {value}."
+ " Permission granted exclusively to the current working directory"
+)
+
+
+class FileValidationError(ValueError):
+ """Error for paths outside the root directory."""
+
+
+class BaseFileToolMixin(BaseModel):
+ """Mixin for file system tools."""
+
+ root_dir: Optional[str] = None
+ """The final path will be chosen relative to root_dir if specified."""
+
+ def get_relative_path(self, file_path: str) -> Path:
+ """Get the relative path, returning an error if unsupported."""
+ if self.root_dir is None:
+ return Path(file_path)
+ return get_validated_relative_path(Path(self.root_dir), file_path)
+
+
+def get_validated_relative_path(root: Path, user_path: str) -> Path:
+ """Resolve a relative path, raising an error if not within the root directory."""
+ # Note, this still permits symlinks from outside that point within the root.
+ # Further validation would be needed if those are to be disallowed.
+ root = root.resolve()
+ full_path = (root / user_path).resolve()
+
+ if not is_relative_to(full_path, root):
+ raise FileValidationError(
+ f"Path {user_path} is outside of the allowed directory {root}"
+ )
+ return full_path
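A worked example of the path validation above: paths that stay inside the root resolve normally, while traversal attempts raise `FileValidationError`. The sandbox directory is an assumption made for the sketch.

from pathlib import Path

from langchain_community.tools.file_management.utils import (
    FileValidationError,
    get_validated_relative_path,
)

root = Path("/tmp/sandbox")
print(get_validated_relative_path(root, "notes/todo.txt"))  # e.g. /tmp/sandbox/notes/todo.txt

try:
    get_validated_relative_path(root, "../../etc/passwd")  # escapes the root
except FileValidationError as err:
    print(err)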
diff --git a/libs/community/langchain_community/tools/file_management/write.py b/libs/community/langchain_community/tools/file_management/write.py
new file mode 100644
index 00000000000..b42b70256b7
--- /dev/null
+++ b/libs/community/langchain_community/tools/file_management/write.py
@@ -0,0 +1,51 @@
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.file_management.utils import (
+ INVALID_PATH_TEMPLATE,
+ BaseFileToolMixin,
+ FileValidationError,
+)
+
+
+class WriteFileInput(BaseModel):
+ """Input for WriteFileTool."""
+
+ file_path: str = Field(..., description="name of file")
+ text: str = Field(..., description="text to write to file")
+ append: bool = Field(
+ default=False, description="Whether to append to an existing file."
+ )
+
+
+class WriteFileTool(BaseFileToolMixin, BaseTool):
+ """Tool that writes a file to disk."""
+
+ name: str = "write_file"
+ args_schema: Type[BaseModel] = WriteFileInput
+ description: str = "Write file to disk"
+
+ def _run(
+ self,
+ file_path: str,
+ text: str,
+ append: bool = False,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ write_path = self.get_relative_path(file_path)
+ except FileValidationError:
+ return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path)
+ try:
+ write_path.parent.mkdir(exist_ok=True, parents=False)
+ mode = "a" if append else "w"
+ with write_path.open(mode, encoding="utf-8") as f:
+ f.write(text)
+ return f"File written successfully to {file_path}."
+ except Exception as e:
+ return "Error: " + str(e)
+
+ # TODO: Add aiofiles method
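A sketch combining the write and read tools scoped to a `root_dir` sandbox (the path is illustrative); tool inputs are passed as dicts matching each tool's `args_schema`.

from langchain_community.tools.file_management import ReadFileTool, WriteFileTool

write_tool = WriteFileTool(root_dir="/tmp/sandbox")
read_tool = ReadFileTool(root_dir="/tmp/sandbox")

print(write_tool.run({"file_path": "hello.txt", "text": "hello world"}))
print(read_tool.run({"file_path": "hello.txt"}))  # -> "hello world"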
diff --git a/libs/community/langchain_community/tools/github/__init__.py b/libs/community/langchain_community/tools/github/__init__.py
new file mode 100644
index 00000000000..e737ac26ba0
--- /dev/null
+++ b/libs/community/langchain_community/tools/github/__init__.py
@@ -0,0 +1 @@
+""" GitHub Tool """
diff --git a/libs/community/langchain_community/tools/github/prompt.py b/libs/community/langchain_community/tools/github/prompt.py
new file mode 100644
index 00000000000..3d66713e02b
--- /dev/null
+++ b/libs/community/langchain_community/tools/github/prompt.py
@@ -0,0 +1,100 @@
+# flake8: noqa
+GET_ISSUES_PROMPT = """
+This tool will fetch a list of the repository's issues. It will return the title and issue number of 5 issues. It takes no input.
+
+GET_ISSUE_PROMPT = """
+This tool will fetch the title, body, and comment thread of a specific issue. **VERY IMPORTANT**: You must specify the issue number as an integer."""
+
+COMMENT_ON_ISSUE_PROMPT = """
+This tool is useful when you need to comment on a GitHub issue. Simply pass in the issue number and the comment you would like to make. Please use this sparingly as we don't want to clutter the comment threads. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify the issue number as an integer
+- Then you must place two newlines
+- Then you must specify your comment"""
+
+CREATE_PULL_REQUEST_PROMPT = """
+This tool is useful when you need to create a new pull request in a GitHub repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify the title of the pull request
+- Then you must place two newlines
+- Then you must write the body or description of the pull request
+
+When appropriate, always reference relevant issues in the body by using the syntax `closes #<issue_number>`."""
+
+UPDATE_FILE_PROMPT = """
+This tool is a wrapper for the GitHub API, useful when you need to update the contents of a file in a GitHub repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify which file to modify by passing a full file path (**IMPORTANT**: the path must not start with a slash)
+- Then you must specify the old contents which you would like to replace wrapped in OLD <<<< and >>>> OLD
+- Then you must specify the new contents which you would like to replace the old contents with wrapped in NEW <<<< and >>>> NEW
+
+For example, if you would like to replace the contents of the file /test/test.txt from "old contents" to "new contents", you would pass in the following string:
+
+test/test.txt
+
+This is text that will not be changed
+OLD <<<<
+old contents
+>>>> OLD
+NEW <<<<
+new contents
+>>>> NEW"""
+
+DELETE_FILE_PROMPT = """
+This tool is a wrapper for the GitHub API, useful when you need to delete a file in a GitHub repository. Simply pass in the full file path of the file you would like to delete. **IMPORTANT**: the path must not start with a slash"""
+
+GET_PR_PROMPT = """
+This tool will fetch the title, body, comment thread and commit history of a specific Pull Request (by PR number). **VERY IMPORTANT**: You must specify the PR number as an integer."""
+
+LIST_PRS_PROMPT = """
+This tool will fetch a list of the repository's Pull Requests (PRs). It will return the title and PR number of 5 PRs. It takes no input."""
+
+LIST_PULL_REQUEST_FILES = """
+This tool will fetch the full text of all files in a pull request (PR) given the PR number as an input. This is useful for understanding the code changes in a PR or contributing to it. **VERY IMPORTANT**: You must specify the PR number as an integer input parameter."""
+
+OVERVIEW_EXISTING_FILES_IN_MAIN = """
+This tool will provide an overview of all existing files in the main branch of the repository. It will list the file names, their respective paths, and a brief summary of their contents. This can be useful for understanding the structure and content of the repository, especially when navigating through large codebases. No input parameters are required."""
+
+OVERVIEW_EXISTING_FILES_BOT_BRANCH = """
+This tool will provide an overview of all files in your current working branch where you should implement changes. This is great for getting a high level overview of the structure of your code. No input parameters are required."""
+
+SEARCH_ISSUES_AND_PRS_PROMPT = """
+This tool will search for issues and pull requests in the repository. **VERY IMPORTANT**: You must specify the search query as a string input parameter."""
+
+SEARCH_CODE_PROMPT = """
+This tool will search for code in the repository. **VERY IMPORTANT**: You must specify the search query as a string input parameter."""
+
+CREATE_REVIEW_REQUEST_PROMPT = """
+This tool will create a review request on the open pull request that matches the current active branch. **VERY IMPORTANT**: You must specify the username of the person who is being requested as a string input parameter."""
+
+LIST_BRANCHES_IN_REPO_PROMPT = """
+This tool will fetch a list of all branches in the repository. It will return the name of each branch. No input parameters are required."""
+
+SET_ACTIVE_BRANCH_PROMPT = """
+This tool will set the active branch in the repository, similar to `git checkout <branch_name>` and `git switch -c <branch_name>`. **VERY IMPORTANT**: You must specify the name of the branch as a string input parameter."""
+
+CREATE_BRANCH_PROMPT = """
+This tool will create a new branch in the repository. **VERY IMPORTANT**: You must specify the name of the new branch as a string input parameter."""
+
+GET_FILES_FROM_DIRECTORY_PROMPT = """
+This tool will fetch a list of all files in a specified directory. **VERY IMPORTANT**: You must specify the path of the directory as a string input parameter."""
diff --git a/libs/community/langchain_community/tools/github/tool.py b/libs/community/langchain_community/tools/github/tool.py
new file mode 100644
index 00000000000..89253183632
--- /dev/null
+++ b/libs/community/langchain_community/tools/github/tool.py
@@ -0,0 +1,37 @@
+"""
+This tool allows agents to interact with the pygithub library
+and operate on a GitHub repository.
+
+To use this tool, you must first set as environment variables:
+ GITHUB_API_TOKEN
+ GITHUB_REPOSITORY -> format: {owner}/{repo}
+
+"""
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.github import GitHubAPIWrapper
+
+
+class GitHubAction(BaseTool):
+ """Tool for interacting with the GitHub API."""
+
+ api_wrapper: GitHubAPIWrapper = Field(default_factory=GitHubAPIWrapper)
+ mode: str
+ name: str = ""
+ description: str = ""
+ args_schema: Optional[Type[BaseModel]] = None
+
+ def _run(
+ self,
+ instructions: Optional[str] = "",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the GitHub API to run an operation."""
+ if not instructions or instructions == "{}":
+ # Catch other forms of empty input that GPT-4 likes to send.
+ instructions = ""
+ return self.api_wrapper.run(self.mode, instructions)
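+
+# Illustrative usage sketch (not part of the original module): a GitHubAction
+# is parameterized by a mode string that GitHubAPIWrapper.run dispatches on
+# ("get_issues" here, matching GET_ISSUES_PROMPT) plus a prompt describing the
+# expected input. Assumes the wrapper reads its configuration from the
+# environment variables listed in the module docstring.
+#
+#   from langchain_community.tools.github.prompt import GET_ISSUES_PROMPT
+#   from langchain_community.utilities.github import GitHubAPIWrapper
+#
+#   tool = GitHubAction(
+#       api_wrapper=GitHubAPIWrapper(),
+#       mode="get_issues",
+#       name="Get Issues",
+#       description=GET_ISSUES_PROMPT,
+#   )
+#   print(tool.run(""))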
diff --git a/libs/community/langchain_community/tools/gitlab/__init__.py b/libs/community/langchain_community/tools/gitlab/__init__.py
new file mode 100644
index 00000000000..4b6d6367663
--- /dev/null
+++ b/libs/community/langchain_community/tools/gitlab/__init__.py
@@ -0,0 +1 @@
+""" GitLab Tool """
diff --git a/libs/community/langchain_community/tools/gitlab/prompt.py b/libs/community/langchain_community/tools/gitlab/prompt.py
new file mode 100644
index 00000000000..3f303155cd4
--- /dev/null
+++ b/libs/community/langchain_community/tools/gitlab/prompt.py
@@ -0,0 +1,70 @@
+# flake8: noqa
+GET_ISSUES_PROMPT = """
+This tool will fetch a list of the repository's issues. It will return the title and issue number of 5 issues. It takes no input.
+"""
+
+GET_ISSUE_PROMPT = """
+This tool will fetch the title, body, and comment thread of a specific issue. **VERY IMPORTANT**: You must specify the issue number as an integer.
+"""
+
+COMMENT_ON_ISSUE_PROMPT = """
+This tool is useful when you need to comment on a GitLab issue. Simply pass in the issue number and the comment you would like to make. Please use this sparingly as we don't want to clutter the comment threads. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify the issue number as an integer
+- Then you must place two newlines
+- Then you must specify your comment
+"""
+CREATE_PULL_REQUEST_PROMPT = """
+This tool is useful when you need to create a new pull request in a GitLab repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify the title of the pull request
+- Then you must place two newlines
+- Then you must write the body or description of the pull request
+
+To reference an issue in the body, put its issue number directly after a #.
+For example, if you would like to create a pull request called "README updates" with contents "added contributors' names, closes issue #3", you would pass in the following string:
+
+README updates
+
+added contributors' names, closes issue #3
+"""
+CREATE_FILE_PROMPT = """
+This tool is a wrapper for the GitLab API, useful when you need to create a file in a GitLab repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify which file to create by passing a full file path (**IMPORTANT**: the path must not start with a slash)
+- Then you must specify the contents of the file
+
+For example, if you would like to create a file called /test/test.txt with contents "test contents", you would pass in the following string:
+
+test/test.txt
+
+test contents
+"""
+
+READ_FILE_PROMPT = """
+This tool is a wrapper for the GitLab API, useful when you need to read the contents of a file in a GitLab repository. Simply pass in the full file path of the file you would like to read. **IMPORTANT**: the path must not start with a slash
+"""
+
+UPDATE_FILE_PROMPT = """
+This tool is a wrapper for the GitLab API, useful when you need to update the contents of a file in a GitLab repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules:
+
+- First you must specify which file to modify by passing a full file path (**IMPORTANT**: the path must not start with a slash)
+- Then you must specify the old contents which you would like to replace wrapped in OLD <<<< and >>>> OLD
+- Then you must specify the new contents which you would like to replace the old contents with wrapped in NEW <<<< and >>>> NEW
+
+For example, if you would like to replace the contents of the file /test/test.txt from "old contents" to "new contents", you would pass in the following string:
+
+test/test.txt
+
+This is text that will not be changed
+OLD <<<<
+old contents
+>>>> OLD
+NEW <<<<
+new contents
+>>>> NEW
+"""
+
+DELETE_FILE_PROMPT = """
+This tool is a wrapper for the GitLab API, useful when you need to delete a file in a GitLab repository. Simply pass in the full file path of the file you would like to delete. **IMPORTANT**: the path must not start with a slash
+"""
diff --git a/libs/community/langchain_community/tools/gitlab/tool.py b/libs/community/langchain_community/tools/gitlab/tool.py
new file mode 100644
index 00000000000..92ea8b98bf9
--- /dev/null
+++ b/libs/community/langchain_community/tools/gitlab/tool.py
@@ -0,0 +1,33 @@
+"""
+This tool allows agents to interact with the python-gitlab library
+and operate on a GitLab repository.
+
+To use this tool, you must first set as environment variables:
+ GITLAB_PRIVATE_ACCESS_TOKEN
+ GITLAB_REPOSITORY -> format: {owner}/{repo}
+
+"""
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.gitlab import GitLabAPIWrapper
+
+
+class GitLabAction(BaseTool):
+ """Tool for interacting with the GitLab API."""
+
+ api_wrapper: GitLabAPIWrapper = Field(default_factory=GitLabAPIWrapper)
+ mode: str
+ name: str = ""
+ description: str = ""
+
+ def _run(
+ self,
+ instructions: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the GitLab API to run an operation."""
+ return self.api_wrapper.run(self.mode, instructions)
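+
+# Illustrative usage sketch (not part of the original module): mirrors the
+# GitHub tool above, with the mode string dispatched by GitLabAPIWrapper.run.
+# Assumes the wrapper reads its configuration from the environment variables
+# listed in the module docstring.
+#
+#   from langchain_community.tools.gitlab.prompt import GET_ISSUES_PROMPT
+#   from langchain_community.utilities.gitlab import GitLabAPIWrapper
+#
+#   tool = GitLabAction(
+#       api_wrapper=GitLabAPIWrapper(),
+#       mode="get_issues",
+#       name="Get Issues",
+#       description=GET_ISSUES_PROMPT,
+#   )
+#   print(tool.run(""))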
diff --git a/libs/community/langchain_community/tools/gmail/__init__.py b/libs/community/langchain_community/tools/gmail/__init__.py
new file mode 100644
index 00000000000..7ef66e21dc7
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/__init__.py
@@ -0,0 +1,17 @@
+"""Gmail tools."""
+
+from langchain_community.tools.gmail.create_draft import GmailCreateDraft
+from langchain_community.tools.gmail.get_message import GmailGetMessage
+from langchain_community.tools.gmail.get_thread import GmailGetThread
+from langchain_community.tools.gmail.search import GmailSearch
+from langchain_community.tools.gmail.send_message import GmailSendMessage
+from langchain_community.tools.gmail.utils import get_gmail_credentials
+
+__all__ = [
+ "GmailCreateDraft",
+ "GmailSendMessage",
+ "GmailSearch",
+ "GmailGetMessage",
+ "GmailGetThread",
+ "get_gmail_credentials",
+]
diff --git a/libs/community/langchain_community/tools/gmail/base.py b/libs/community/langchain_community/tools/gmail/base.py
new file mode 100644
index 00000000000..b96e16117fb
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/base.py
@@ -0,0 +1,37 @@
+"""Base class for Gmail tools."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.gmail.utils import build_resource_service
+
+if TYPE_CHECKING:
+ # This is for linting and IDE typehints
+ from googleapiclient.discovery import Resource
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ from googleapiclient.discovery import Resource
+ except ImportError:
+ pass
+
+
+class GmailBaseTool(BaseTool):
+ """Base class for Gmail tools."""
+
+ api_resource: Resource = Field(default_factory=build_resource_service)
+
+ @classmethod
+ def from_api_resource(cls, api_resource: Resource) -> "GmailBaseTool":
+ """Create a tool from an api resource.
+
+ Args:
+ api_resource: The api resource to use.
+
+ Returns:
+ A tool.
+ """
+        return cls(api_resource=api_resource)
diff --git a/libs/community/langchain_community/tools/gmail/create_draft.py b/libs/community/langchain_community/tools/gmail/create_draft.py
new file mode 100644
index 00000000000..b8e6ac93c61
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/create_draft.py
@@ -0,0 +1,87 @@
+import base64
+from email.message import EmailMessage
+from typing import List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.gmail.base import GmailBaseTool
+
+
+class CreateDraftSchema(BaseModel):
+ """Input for CreateDraftTool."""
+
+ message: str = Field(
+ ...,
+ description="The message to include in the draft.",
+ )
+ to: List[str] = Field(
+ ...,
+ description="The list of recipients.",
+ )
+ subject: str = Field(
+ ...,
+ description="The subject of the message.",
+ )
+ cc: Optional[List[str]] = Field(
+ None,
+ description="The list of CC recipients.",
+ )
+ bcc: Optional[List[str]] = Field(
+ None,
+ description="The list of BCC recipients.",
+ )
+
+
+class GmailCreateDraft(GmailBaseTool):
+ """Tool that creates a draft email for Gmail."""
+
+ name: str = "create_gmail_draft"
+ description: str = (
+ "Use this tool to create a draft email with the provided message fields."
+ )
+ args_schema: Type[CreateDraftSchema] = CreateDraftSchema
+
+ def _prepare_draft_message(
+ self,
+ message: str,
+ to: List[str],
+ subject: str,
+ cc: Optional[List[str]] = None,
+ bcc: Optional[List[str]] = None,
+ ) -> dict:
+ draft_message = EmailMessage()
+ draft_message.set_content(message)
+
+ draft_message["To"] = ", ".join(to)
+ draft_message["Subject"] = subject
+ if cc is not None:
+ draft_message["Cc"] = ", ".join(cc)
+
+ if bcc is not None:
+ draft_message["Bcc"] = ", ".join(bcc)
+
+ encoded_message = base64.urlsafe_b64encode(draft_message.as_bytes()).decode()
+ return {"message": {"raw": encoded_message}}
+
+ def _run(
+ self,
+ message: str,
+ to: List[str],
+ subject: str,
+ cc: Optional[List[str]] = None,
+ bcc: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ create_message = self._prepare_draft_message(message, to, subject, cc, bcc)
+ draft = (
+ self.api_resource.users()
+ .drafts()
+ .create(userId="me", body=create_message)
+ .execute()
+ )
+ output = f'Draft created. Draft Id: {draft["id"]}'
+ return output
+ except Exception as e:
+ raise Exception(f"An error occurred: {e}")
diff --git a/libs/community/langchain_community/tools/gmail/get_message.py b/libs/community/langchain_community/tools/gmail/get_message.py
new file mode 100644
index 00000000000..79b963453e8
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/get_message.py
@@ -0,0 +1,70 @@
+import base64
+import email
+from typing import Dict, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.gmail.base import GmailBaseTool
+from langchain_community.tools.gmail.utils import clean_email_body
+
+
+class SearchArgsSchema(BaseModel):
+ """Input for GetMessageTool."""
+
+ message_id: str = Field(
+ ...,
+ description="The unique ID of the email message, retrieved from a search.",
+ )
+
+
+class GmailGetMessage(GmailBaseTool):
+ """Tool that gets a message by ID from Gmail."""
+
+ name: str = "get_gmail_message"
+ description: str = (
+ "Use this tool to fetch an email by message ID."
+ " Returns the thread ID, snippet, body, subject, and sender."
+ )
+ args_schema: Type[SearchArgsSchema] = SearchArgsSchema
+
+ def _run(
+ self,
+ message_id: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> Dict:
+ """Run the tool."""
+ query = (
+ self.api_resource.users()
+ .messages()
+ .get(userId="me", format="raw", id=message_id)
+ )
+ message_data = query.execute()
+ raw_message = base64.urlsafe_b64decode(message_data["raw"])
+
+ email_msg = email.message_from_bytes(raw_message)
+
+ subject = email_msg["Subject"]
+ sender = email_msg["From"]
+
+ message_body = ""
+ if email_msg.is_multipart():
+ for part in email_msg.walk():
+ ctype = part.get_content_type()
+ cdispo = str(part.get("Content-Disposition"))
+ if ctype == "text/plain" and "attachment" not in cdispo:
+ message_body = part.get_payload(decode=True).decode("utf-8")
+ break
+ else:
+ message_body = email_msg.get_payload(decode=True).decode("utf-8")
+
+ body = clean_email_body(message_body)
+
+ return {
+ "id": message_id,
+ "threadId": message_data["threadId"],
+ "snippet": message_data["snippet"],
+ "body": body,
+ "subject": subject,
+ "sender": sender,
+ }
diff --git a/libs/community/langchain_community/tools/gmail/get_thread.py b/libs/community/langchain_community/tools/gmail/get_thread.py
new file mode 100644
index 00000000000..c42c90924b6
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/get_thread.py
@@ -0,0 +1,48 @@
+from typing import Dict, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.gmail.base import GmailBaseTool
+
+
+class GetThreadSchema(BaseModel):
+ """Input for GetMessageTool."""
+
+ # From https://support.google.com/mail/answer/7190?hl=en
+ thread_id: str = Field(
+ ...,
+ description="The thread ID.",
+ )
+
+
+class GmailGetThread(GmailBaseTool):
+ """Tool that gets a thread by ID from Gmail."""
+
+ name: str = "get_gmail_thread"
+ description: str = (
+ "Use this tool to search for email messages."
+ " The input must be a valid Gmail query."
+ " The output is a JSON list of messages."
+ )
+ args_schema: Type[GetThreadSchema] = GetThreadSchema
+
+ def _run(
+ self,
+ thread_id: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> Dict:
+ """Run the tool."""
+ query = self.api_resource.users().threads().get(userId="me", id=thread_id)
+ thread_data = query.execute()
+ if not isinstance(thread_data, dict):
+            raise ValueError("The output of the query must be a dict.")
+ messages = thread_data["messages"]
+ thread_data["messages"] = []
+        keys_to_keep = ["id", "snippet"]
+ # TODO: Parse body.
+ for message in messages:
+ thread_data["messages"].append(
+ {k: message[k] for k in keys_to_keep if k in message}
+ )
+ return thread_data
diff --git a/libs/community/langchain_community/tools/gmail/search.py b/libs/community/langchain_community/tools/gmail/search.py
new file mode 100644
index 00000000000..bcd16e09311
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/search.py
@@ -0,0 +1,140 @@
+import base64
+import email
+from enum import Enum
+from typing import Any, Dict, List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.gmail.base import GmailBaseTool
+from langchain_community.tools.gmail.utils import clean_email_body
+
+
+class Resource(str, Enum):
+ """Enumerator of Resources to search."""
+
+ THREADS = "threads"
+ MESSAGES = "messages"
+
+
+class SearchArgsSchema(BaseModel):
+ """Input for SearchGmailTool."""
+
+ # From https://support.google.com/mail/answer/7190?hl=en
+ query: str = Field(
+ ...,
+ description="The Gmail query. Example filters include from:sender,"
+ " to:recipient, subject:subject, -filtered_term,"
+ " in:folder, is:important|read|starred, after:year/mo/date, "
+ "before:year/mo/date, label:label_name"
+ ' "exact phrase".'
+ " Search newer/older than using d (day), m (month), and y (year): "
+ "newer_than:2d, older_than:1y."
+ " Attachments with extension example: filename:pdf. Multiple term"
+ " matching example: from:amy OR from:david.",
+ )
+ resource: Resource = Field(
+ default=Resource.MESSAGES,
+ description="Whether to search for threads or messages.",
+ )
+ max_results: int = Field(
+ default=10,
+ description="The maximum number of results to return.",
+ )
+
+
+class GmailSearch(GmailBaseTool):
+ """Tool that searches for messages or threads in Gmail."""
+
+ name: str = "search_gmail"
+ description: str = (
+ "Use this tool to search for email messages or threads."
+ " The input must be a valid Gmail query."
+ " The output is a JSON list of the requested resource."
+ )
+ args_schema: Type[SearchArgsSchema] = SearchArgsSchema
+
+ def _parse_threads(self, threads: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ # Add the thread message snippets to the thread results
+ results = []
+ for thread in threads:
+ thread_id = thread["id"]
+ thread_data = (
+ self.api_resource.users()
+ .threads()
+ .get(userId="me", id=thread_id)
+ .execute()
+ )
+ messages = thread_data["messages"]
+ thread["messages"] = []
+ for message in messages:
+ snippet = message["snippet"]
+ thread["messages"].append({"snippet": snippet, "id": message["id"]})
+ results.append(thread)
+
+ return results
+
+ def _parse_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ results = []
+ for message in messages:
+ message_id = message["id"]
+ message_data = (
+ self.api_resource.users()
+ .messages()
+ .get(userId="me", format="raw", id=message_id)
+ .execute()
+ )
+
+ raw_message = base64.urlsafe_b64decode(message_data["raw"])
+
+ email_msg = email.message_from_bytes(raw_message)
+
+ subject = email_msg["Subject"]
+ sender = email_msg["From"]
+
+ message_body = ""
+ if email_msg.is_multipart():
+ for part in email_msg.walk():
+ ctype = part.get_content_type()
+ cdispo = str(part.get("Content-Disposition"))
+ if ctype == "text/plain" and "attachment" not in cdispo:
+ message_body = part.get_payload(decode=True).decode("utf-8")
+ break
+ else:
+ message_body = email_msg.get_payload(decode=True).decode("utf-8")
+
+ body = clean_email_body(message_body)
+
+ results.append(
+ {
+ "id": message["id"],
+ "threadId": message_data["threadId"],
+ "snippet": message_data["snippet"],
+ "body": body,
+ "subject": subject,
+ "sender": sender,
+ }
+ )
+ return results
+
+ def _run(
+ self,
+ query: str,
+ resource: Resource = Resource.MESSAGES,
+ max_results: int = 10,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> List[Dict[str, Any]]:
+ """Run the tool."""
+ results = (
+ self.api_resource.users()
+ .messages()
+ .list(userId="me", q=query, maxResults=max_results)
+ .execute()
+ .get(resource.value, [])
+ )
+ if resource == Resource.THREADS:
+ return self._parse_threads(results)
+ elif resource == Resource.MESSAGES:
+ return self._parse_messages(results)
+ else:
+ raise NotImplementedError(f"Resource of type {resource} not implemented.")
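+
+# Illustrative usage sketch (not part of the original module): `resource`
+# selects whether raw messages or whole threads are returned. Assumes valid
+# Gmail credentials discoverable by build_resource_service(); the queries
+# below are hypothetical examples.
+#
+#   from langchain_community.tools.gmail import GmailSearch
+#   from langchain_community.tools.gmail.utils import build_resource_service
+#
+#   search = GmailSearch(api_resource=build_resource_service())
+#   messages = search.run({"query": "from:me newer_than:7d", "max_results": 5})
+#   threads = search.run({"query": "subject:invoice", "resource": "threads"})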
diff --git a/libs/community/langchain_community/tools/gmail/send_message.py b/libs/community/langchain_community/tools/gmail/send_message.py
new file mode 100644
index 00000000000..52a2b97d29c
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/send_message.py
@@ -0,0 +1,89 @@
+"""Send Gmail messages."""
+import base64
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import Any, Dict, List, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.gmail.base import GmailBaseTool
+
+
+class SendMessageSchema(BaseModel):
+ """Input for SendMessageTool."""
+
+ message: str = Field(
+ ...,
+ description="The message to send.",
+ )
+ to: Union[str, List[str]] = Field(
+ ...,
+ description="The list of recipients.",
+ )
+ subject: str = Field(
+ ...,
+ description="The subject of the message.",
+ )
+ cc: Optional[Union[str, List[str]]] = Field(
+ None,
+ description="The list of CC recipients.",
+ )
+ bcc: Optional[Union[str, List[str]]] = Field(
+ None,
+ description="The list of BCC recipients.",
+ )
+
+
+class GmailSendMessage(GmailBaseTool):
+ """Tool that sends a message to Gmail."""
+
+ name: str = "send_gmail_message"
+ description: str = (
+ "Use this tool to send email messages." " The input is the message, recipients"
+ )
+
+ def _prepare_message(
+ self,
+ message: str,
+ to: Union[str, List[str]],
+ subject: str,
+ cc: Optional[Union[str, List[str]]] = None,
+ bcc: Optional[Union[str, List[str]]] = None,
+ ) -> Dict[str, Any]:
+ """Create a message for an email."""
+ mime_message = MIMEMultipart()
+ mime_message.attach(MIMEText(message, "html"))
+
+ mime_message["To"] = ", ".join(to if isinstance(to, list) else [to])
+ mime_message["Subject"] = subject
+ if cc is not None:
+ mime_message["Cc"] = ", ".join(cc if isinstance(cc, list) else [cc])
+
+ if bcc is not None:
+ mime_message["Bcc"] = ", ".join(bcc if isinstance(bcc, list) else [bcc])
+
+ encoded_message = base64.urlsafe_b64encode(mime_message.as_bytes()).decode()
+ return {"raw": encoded_message}
+
+ def _run(
+ self,
+ message: str,
+ to: Union[str, List[str]],
+ subject: str,
+ cc: Optional[Union[str, List[str]]] = None,
+ bcc: Optional[Union[str, List[str]]] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool."""
+ try:
+ create_message = self._prepare_message(message, to, subject, cc=cc, bcc=bcc)
+ send_message = (
+ self.api_resource.users()
+ .messages()
+ .send(userId="me", body=create_message)
+ )
+ sent_message = send_message.execute()
+ return f'Message sent. Message Id: {sent_message["id"]}'
+ except Exception as error:
+ raise Exception(f"An error occurred: {error}")
diff --git a/libs/community/langchain_community/tools/gmail/utils.py b/libs/community/langchain_community/tools/gmail/utils.py
new file mode 100644
index 00000000000..ee562174c2f
--- /dev/null
+++ b/libs/community/langchain_community/tools/gmail/utils.py
@@ -0,0 +1,132 @@
+"""Gmail tool utils."""
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+if TYPE_CHECKING:
+ from google.auth.transport.requests import Request
+ from google.oauth2.credentials import Credentials
+ from google_auth_oauthlib.flow import InstalledAppFlow
+ from googleapiclient.discovery import Resource
+ from googleapiclient.discovery import build as build_resource
+
+logger = logging.getLogger(__name__)
+
+
+def import_google() -> Tuple[Request, Credentials]:
+ """Import google libraries.
+
+ Returns:
+ Tuple[Request, Credentials]: Request and Credentials classes.
+ """
+ # google-auth-httplib2
+ try:
+ from google.auth.transport.requests import Request # noqa: F401
+ from google.oauth2.credentials import Credentials # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "You need to install google-auth-httplib2 to use this toolkit. "
+ "Try running pip install --upgrade google-auth-httplib2"
+ )
+ return Request, Credentials
+
+
+def import_installed_app_flow() -> InstalledAppFlow:
+ """Import InstalledAppFlow class.
+
+ Returns:
+ InstalledAppFlow: InstalledAppFlow class.
+ """
+ try:
+ from google_auth_oauthlib.flow import InstalledAppFlow
+ except ImportError:
+ raise ImportError(
+ "You need to install google-auth-oauthlib to use this toolkit. "
+ "Try running pip install --upgrade google-auth-oauthlib"
+ )
+ return InstalledAppFlow
+
+
+def import_googleapiclient_resource_builder() -> build_resource:
+ """Import googleapiclient.discovery.build function.
+
+ Returns:
+ build_resource: googleapiclient.discovery.build function.
+ """
+ try:
+ from googleapiclient.discovery import build
+ except ImportError:
+ raise ImportError(
+ "You need to install googleapiclient to use this toolkit. "
+ "Try running pip install --upgrade google-api-python-client"
+ )
+ return build
+
+
+DEFAULT_SCOPES = ["https://mail.google.com/"]
+DEFAULT_CREDS_TOKEN_FILE = "token.json"
+DEFAULT_CLIENT_SECRETS_FILE = "credentials.json"
+
+
+def get_gmail_credentials(
+ token_file: Optional[str] = None,
+ client_secrets_file: Optional[str] = None,
+ scopes: Optional[List[str]] = None,
+) -> Credentials:
+ """Get credentials."""
+ # From https://developers.google.com/gmail/api/quickstart/python
+ Request, Credentials = import_google()
+ InstalledAppFlow = import_installed_app_flow()
+ creds = None
+ scopes = scopes or DEFAULT_SCOPES
+ token_file = token_file or DEFAULT_CREDS_TOKEN_FILE
+ client_secrets_file = client_secrets_file or DEFAULT_CLIENT_SECRETS_FILE
+ # The file token.json stores the user's access and refresh tokens, and is
+ # created automatically when the authorization flow completes for the first
+ # time.
+ if os.path.exists(token_file):
+ creds = Credentials.from_authorized_user_file(token_file, scopes)
+ # If there are no (valid) credentials available, let the user log in.
+ if not creds or not creds.valid:
+ if creds and creds.expired and creds.refresh_token:
+ creds.refresh(Request())
+ else:
+ # https://developers.google.com/gmail/api/quickstart/python#authorize_credentials_for_a_desktop_application # noqa
+ flow = InstalledAppFlow.from_client_secrets_file(
+ client_secrets_file, scopes
+ )
+ creds = flow.run_local_server(port=0)
+ # Save the credentials for the next run
+ with open(token_file, "w") as token:
+ token.write(creds.to_json())
+ return creds
+
+
+def build_resource_service(
+ credentials: Optional[Credentials] = None,
+ service_name: str = "gmail",
+ service_version: str = "v1",
+) -> Resource:
+ """Build a Gmail service."""
+ credentials = credentials or get_gmail_credentials()
+ builder = import_googleapiclient_resource_builder()
+ return builder(service_name, service_version, credentials=credentials)
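+
+# Illustrative usage sketch (not part of the original module): the helpers
+# above cache the OAuth token on first run and refresh it afterwards. The
+# file names and scope below mirror the defaults and are only examples.
+#
+#   creds = get_gmail_credentials(
+#       token_file="token.json",
+#       client_secrets_file="credentials.json",
+#       scopes=["https://mail.google.com/"],
+#   )
+#   api_resource = build_resource_service(credentials=creds)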
+
+
+def clean_email_body(body: str) -> str:
+ """Clean email body."""
+ try:
+ from bs4 import BeautifulSoup
+
+ try:
+ soup = BeautifulSoup(str(body), "html.parser")
+ body = soup.get_text()
+ return str(body)
+ except Exception as e:
+ logger.error(e)
+ return str(body)
+ except ImportError:
+ logger.warning("BeautifulSoup not installed. Skipping cleaning.")
+ return str(body)
diff --git a/libs/community/langchain_community/tools/golden_query/__init__.py b/libs/community/langchain_community/tools/golden_query/__init__.py
new file mode 100644
index 00000000000..f2f2bc5339b
--- /dev/null
+++ b/libs/community/langchain_community/tools/golden_query/__init__.py
@@ -0,0 +1,8 @@
+"""Golden API toolkit."""
+
+
+from langchain_community.tools.golden_query.tool import GoldenQueryRun
+
+__all__ = [
+ "GoldenQueryRun",
+]
diff --git a/libs/community/langchain_community/tools/golden_query/tool.py b/libs/community/langchain_community/tools/golden_query/tool.py
new file mode 100644
index 00000000000..78e08cd7cf9
--- /dev/null
+++ b/libs/community/langchain_community/tools/golden_query/tool.py
@@ -0,0 +1,34 @@
+"""Tool for the Golden API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.golden_query import GoldenQueryAPIWrapper
+
+
+class GoldenQueryRun(BaseTool):
+ """Tool that adds the capability to query using the Golden API and get back JSON."""
+
+ name: str = "Golden-Query"
+ description: str = (
+ "A wrapper around Golden Query API."
+ " Useful for getting entities that match"
+ " a natural language query from Golden's Knowledge Base."
+ "\nExample queries:"
+ "\n- companies in nanotech"
+ "\n- list of cloud providers starting in 2019"
+ "\nInput should be the natural language query."
+ "\nOutput is a paginated list of results or an error object"
+ " in JSON format."
+ )
+ api_wrapper: GoldenQueryAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Golden tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_cloud/__init__.py b/libs/community/langchain_community/tools/google_cloud/__init__.py
new file mode 100644
index 00000000000..ec7deb89515
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_cloud/__init__.py
@@ -0,0 +1,7 @@
+"""Google Cloud Tools."""
+
+from langchain_community.tools.google_cloud.texttospeech import (
+ GoogleCloudTextToSpeechTool,
+)
+
+__all__ = ["GoogleCloudTextToSpeechTool"]
diff --git a/libs/community/langchain_community/tools/google_cloud/texttospeech.py b/libs/community/langchain_community/tools/google_cloud/texttospeech.py
new file mode 100644
index 00000000000..fc1e5852efb
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_cloud/texttospeech.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import tempfile
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.cloud import texttospeech
+
+
+def _import_google_cloud_texttospeech() -> Any:
+ try:
+ from google.cloud import texttospeech
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import google.cloud.texttospeech, please install "
+ "`pip install google-cloud-texttospeech`."
+ ) from e
+ return texttospeech
+
+
+def _encoding_file_extension_map(encoding: texttospeech.AudioEncoding) -> Optional[str]:
+ texttospeech = _import_google_cloud_texttospeech()
+
+ ENCODING_FILE_EXTENSION_MAP = {
+ texttospeech.AudioEncoding.LINEAR16: ".wav",
+ texttospeech.AudioEncoding.MP3: ".mp3",
+ texttospeech.AudioEncoding.OGG_OPUS: ".ogg",
+ texttospeech.AudioEncoding.MULAW: ".wav",
+ texttospeech.AudioEncoding.ALAW: ".wav",
+ }
+ return ENCODING_FILE_EXTENSION_MAP.get(encoding)
+
+
+class GoogleCloudTextToSpeechTool(BaseTool):
+ """Tool that queries the Google Cloud Text to Speech API.
+
+ In order to set this up, follow instructions at:
+ https://cloud.google.com/text-to-speech/docs/before-you-begin
+ """
+
+ name: str = "google_cloud_texttospeech"
+ description: str = (
+ "A wrapper around Google Cloud Text-to-Speech. "
+ "Useful for when you need to synthesize audio from text. "
+ "It supports multiple languages, including English, German, Polish, "
+ "Spanish, Italian, French, Portuguese, and Hindi. "
+ )
+
+ _client: Any
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initializes private fields."""
+ texttospeech = _import_google_cloud_texttospeech()
+
+ super().__init__(**kwargs)
+
+ self._client = texttospeech.TextToSpeechClient(
+ client_info=get_client_info(module="text-to-speech")
+ )
+
+ def _run(
+ self,
+ input_text: str,
+ language_code: str = "en-US",
+ ssml_gender: Optional[texttospeech.SsmlVoiceGender] = None,
+ audio_encoding: Optional[texttospeech.AudioEncoding] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ texttospeech = _import_google_cloud_texttospeech()
+ ssml_gender = ssml_gender or texttospeech.SsmlVoiceGender.NEUTRAL
+ audio_encoding = audio_encoding or texttospeech.AudioEncoding.MP3
+
+ response = self._client.synthesize_speech(
+ input=texttospeech.SynthesisInput(text=input_text),
+ voice=texttospeech.VoiceSelectionParams(
+ language_code=language_code, ssml_gender=ssml_gender
+ ),
+ audio_config=texttospeech.AudioConfig(audio_encoding=audio_encoding),
+ )
+
+ suffix = _encoding_file_extension_map(audio_encoding)
+
+ with tempfile.NamedTemporaryFile(mode="bx", suffix=suffix, delete=False) as f:
+ f.write(response.audio_content)
+ return f.name
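+
+# Illustrative usage sketch (not part of the original module): the tool writes
+# the synthesized audio to a temporary file and returns its path, with the
+# extension derived from the requested AudioEncoding. Assumes Google Cloud
+# credentials are already configured (see the class docstring).
+#
+#   tts = GoogleCloudTextToSpeechTool()
+#   audio_path = tts.run("Hello from LangChain!")
+#   print(audio_path)  # e.g. /tmp/tmpabc123.mp3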
diff --git a/libs/community/langchain_community/tools/google_finance/__init__.py b/libs/community/langchain_community/tools/google_finance/__init__.py
new file mode 100644
index 00000000000..bc06ae46d56
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_finance/__init__.py
@@ -0,0 +1,5 @@
+"""Google Finance API Toolkit."""
+
+from langchain_community.tools.google_finance.tool import GoogleFinanceQueryRun
+
+__all__ = ["GoogleFinanceQueryRun"]
diff --git a/libs/community/langchain_community/tools/google_finance/tool.py b/libs/community/langchain_community/tools/google_finance/tool.py
new file mode 100644
index 00000000000..82eb82de318
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_finance/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Google Finance"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_finance import GoogleFinanceAPIWrapper
+
+
+class GoogleFinanceQueryRun(BaseTool):
+ """Tool that queries the Google Finance API."""
+
+ name: str = "google_finance"
+ description: str = (
+ "A wrapper around Google Finance Search. "
+ "Useful for when you need to get information about"
+ "google search Finance from Google Finance"
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleFinanceAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_jobs/__init__.py b/libs/community/langchain_community/tools/google_jobs/__init__.py
new file mode 100644
index 00000000000..f23e0eecffb
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_jobs/__init__.py
@@ -0,0 +1,5 @@
+"""Google Jobs API Toolkit."""
+
+from langchain_community.tools.google_jobs.tool import GoogleJobsQueryRun
+
+__all__ = ["GoogleJobsQueryRun"]
diff --git a/libs/community/langchain_community/tools/google_jobs/tool.py b/libs/community/langchain_community/tools/google_jobs/tool.py
new file mode 100644
index 00000000000..6a83b3043d9
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_jobs/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Google Trends"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
+
+
+class GoogleJobsQueryRun(BaseTool):
+ """Tool that queries the Google Jobs API."""
+
+ name: str = "google_jobs"
+ description: str = (
+ "A wrapper around Google Jobs Search. "
+ "Useful for when you need to get information about"
+ "google search Jobs from Google Jobs"
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleJobsAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_lens/__init__.py b/libs/community/langchain_community/tools/google_lens/__init__.py
new file mode 100644
index 00000000000..15a0c179379
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_lens/__init__.py
@@ -0,0 +1,5 @@
+"""Google Lens API Toolkit."""
+
+from langchain_community.tools.google_lens.tool import GoogleLensQueryRun
+
+__all__ = ["GoogleLensQueryRun"]
diff --git a/libs/community/langchain_community/tools/google_lens/tool.py b/libs/community/langchain_community/tools/google_lens/tool.py
new file mode 100644
index 00000000000..0a69739e352
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_lens/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Google Lens"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_lens import GoogleLensAPIWrapper
+
+
+class GoogleLensQueryRun(BaseTool):
+ """Tool that queries the Google Lens API."""
+
+ name: str = "google_Lens"
+ description: str = (
+ "A wrapper around Google Lens Search. "
+ "Useful for when you need to get information related"
+ "to an image from Google Lens"
+ "Input should be a url to an image."
+ )
+ api_wrapper: GoogleLensAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_places/__init__.py b/libs/community/langchain_community/tools/google_places/__init__.py
new file mode 100644
index 00000000000..6d3b948ea58
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_places/__init__.py
@@ -0,0 +1,5 @@
+"""Google Places API Toolkit."""
+
+from langchain_community.tools.google_places.tool import GooglePlacesTool
+
+__all__ = ["GooglePlacesTool"]
diff --git a/libs/community/langchain_community/tools/google_places/tool.py b/libs/community/langchain_community/tools/google_places/tool.py
new file mode 100644
index 00000000000..d198350126e
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_places/tool.py
@@ -0,0 +1,37 @@
+"""Tool for the Google search API."""
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_places_api import GooglePlacesAPIWrapper
+
+
+class GooglePlacesSchema(BaseModel):
+ """Input for GooglePlacesTool."""
+
+ query: str = Field(..., description="Query for google maps")
+
+
+class GooglePlacesTool(BaseTool):
+ """Tool that queries the Google places API."""
+
+ name: str = "google_places"
+ description: str = (
+ "A wrapper around Google Places. "
+ "Useful for when you need to validate or "
+ "discover addressed from ambiguous text. "
+ "Input should be a search query."
+ )
+ api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper)
+ args_schema: Type[BaseModel] = GooglePlacesSchema
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_scholar/__init__.py b/libs/community/langchain_community/tools/google_scholar/__init__.py
new file mode 100644
index 00000000000..b83e5dfc1ef
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_scholar/__init__.py
@@ -0,0 +1,5 @@
+"""Google Scholar API Toolkit."""
+
+from langchain_community.tools.google_scholar.tool import GoogleScholarQueryRun
+
+__all__ = ["GoogleScholarQueryRun"]
diff --git a/libs/community/langchain_community/tools/google_scholar/tool.py b/libs/community/langchain_community/tools/google_scholar/tool.py
new file mode 100644
index 00000000000..49f8769696f
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_scholar/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Google Scholar"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_scholar import GoogleScholarAPIWrapper
+
+
+class GoogleScholarQueryRun(BaseTool):
+ """Tool that queries the Google search API."""
+
+ name: str = "google_scholar"
+ description: str = (
+ "A wrapper around Google Scholar Search. "
+ "Useful for when you need to get information about"
+ "research papers from Google Scholar"
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleScholarAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/google_search/__init__.py b/libs/community/langchain_community/tools/google_search/__init__.py
new file mode 100644
index 00000000000..08eccf0a318
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_search/__init__.py
@@ -0,0 +1,8 @@
+"""Google Search API Toolkit."""
+
+from langchain_community.tools.google_search.tool import (
+ GoogleSearchResults,
+ GoogleSearchRun,
+)
+
+__all__ = ["GoogleSearchRun", "GoogleSearchResults"]
diff --git a/libs/community/langchain_community/tools/google_search/tool.py b/libs/community/langchain_community/tools/google_search/tool.py
new file mode 100644
index 00000000000..4b6ea4685b4
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_search/tool.py
@@ -0,0 +1,49 @@
+"""Tool for the Google search API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
+
+
+class GoogleSearchRun(BaseTool):
+ """Tool that queries the Google search API."""
+
+ name: str = "google_search"
+ description: str = (
+ "A wrapper around Google Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleSearchAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
+
+
+class GoogleSearchResults(BaseTool):
+ """Tool that queries the Google Search API and gets back json."""
+
+ name: str = "Google Search Results JSON"
+ description: str = (
+ "A wrapper around Google Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query. Output is a JSON array of the query results"
+ )
+ num_results: int = 4
+ api_wrapper: GoogleSearchAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.results(query, self.num_results))
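+
+# Illustrative usage sketch (not part of the original module): both tools share
+# a GoogleSearchAPIWrapper; the Results variant returns raw result metadata.
+# Assumes GOOGLE_API_KEY and GOOGLE_CSE_ID are set in the environment.
+#
+#   from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
+#
+#   wrapper = GoogleSearchAPIWrapper()
+#   print(GoogleSearchRun(api_wrapper=wrapper).run("latest LangChain release"))
+#   print(GoogleSearchResults(api_wrapper=wrapper, num_results=2).run("LangChain"))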
diff --git a/libs/community/langchain_community/tools/google_serper/__init__.py b/libs/community/langchain_community/tools/google_serper/__init__.py
new file mode 100644
index 00000000000..413481a645b
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_serper/__init__.py
@@ -0,0 +1,9 @@
+"""Google Serper API Toolkit, for the Serper.dev Google Search API."""
+
+from langchain_community.tools.google_serper.tool import (
+    GoogleSerperResults,
+    GoogleSerperRun,
+)
+
+__all__ = ["GoogleSerperRun", "GoogleSerperResults"]
diff --git a/libs/community/langchain_community/tools/google_serper/tool.py b/libs/community/langchain_community/tools/google_serper/tool.py
new file mode 100644
index 00000000000..b94771a894f
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_serper/tool.py
@@ -0,0 +1,70 @@
+"""Tool for the Serper.dev Google Search API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper
+
+
+class GoogleSerperRun(BaseTool):
+ """Tool that queries the Serper.dev Google search API."""
+
+ name: str = "google_serper"
+ description: str = (
+ "A low-cost Google Search API."
+ "Useful for when you need to answer questions about current events."
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleSerperAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.run(query))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+        return str(await self.api_wrapper.arun(query))
+
+
+class GoogleSerperResults(BaseTool):
+ """Tool that queries the Serper.dev Google Search API
+ and get back json."""
+
+ name: str = "google_serper_results_json"
+ description: str = (
+ "A low-cost Google Search API."
+ "Useful for when you need to answer questions about current events."
+ "Input should be a search query. Output is a JSON object of the query results"
+ )
+ api_wrapper: GoogleSerperAPIWrapper = Field(default_factory=GoogleSerperAPIWrapper)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.results(query))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+
+        return str(await self.api_wrapper.aresults(query))
diff --git a/libs/community/langchain_community/tools/google_trends/__init__.py b/libs/community/langchain_community/tools/google_trends/__init__.py
new file mode 100644
index 00000000000..ca3d58fc595
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_trends/__init__.py
@@ -0,0 +1,5 @@
+"""Google Trends API Toolkit."""
+
+from langchain_community.tools.google_trends.tool import GoogleTrendsQueryRun
+
+__all__ = ["GoogleTrendsQueryRun"]
diff --git a/libs/community/langchain_community/tools/google_trends/tool.py b/libs/community/langchain_community/tools/google_trends/tool.py
new file mode 100644
index 00000000000..8b2b5dd8bfb
--- /dev/null
+++ b/libs/community/langchain_community/tools/google_trends/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Google Trends"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
+
+
+class GoogleTrendsQueryRun(BaseTool):
+ """Tool that queries the Google trends API."""
+
+ name: str = "google_trends"
+ description: str = (
+ "A wrapper around Google Trends Search. "
+ "Useful for when you need to get information about"
+ "google search trends from Google Trends"
+ "Input should be a search query."
+ )
+ api_wrapper: GoogleTrendsAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/graphql/__init__.py b/libs/community/langchain_community/tools/graphql/__init__.py
new file mode 100644
index 00000000000..7e9a84c3772
--- /dev/null
+++ b/libs/community/langchain_community/tools/graphql/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with a GraphQL API"""
diff --git a/libs/community/langchain_community/tools/graphql/tool.py b/libs/community/langchain_community/tools/graphql/tool.py
new file mode 100644
index 00000000000..a794334cbdb
--- /dev/null
+++ b/libs/community/langchain_community/tools/graphql/tool.py
@@ -0,0 +1,36 @@
+import json
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.graphql import GraphQLAPIWrapper
+
+
+class BaseGraphQLTool(BaseTool):
+ """Base tool for querying a GraphQL API."""
+
+ graphql_wrapper: GraphQLAPIWrapper
+
+ name: str = "query_graphql"
+ description: str = """\
+ Input to this tool is a detailed and correct GraphQL query, output is a result from the API.
+ If the query is not correct, an error message will be returned.
+ If an error is returned with 'Bad request' in it, rewrite the query and try again.
+ If an error is returned with 'Unauthorized' in it, do not try again, but tell the user to change their authentication.
+
+ Example Input: query {{ allUsers {{ id, name, email }} }}\
+ """ # noqa: E501
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ result = self.graphql_wrapper.run(tool_input)
+ return json.dumps(result, indent=2)
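+
+# Illustrative usage sketch (not part of the original module): the wrapper
+# holds the endpoint and the tool returns query results as indented JSON.
+# The endpoint URL is a hypothetical example, and `graphql_endpoint` is the
+# wrapper's assumed field name.
+#
+#   from langchain_community.utilities.graphql import GraphQLAPIWrapper
+#
+#   wrapper = GraphQLAPIWrapper(graphql_endpoint="https://example.com/graphql")
+#   tool = BaseGraphQLTool(graphql_wrapper=wrapper)
+#   print(tool.run("query { allUsers { id, name, email } }"))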
diff --git a/libs/community/langchain_community/tools/human/__init__.py b/libs/community/langchain_community/tools/human/__init__.py
new file mode 100644
index 00000000000..084487d0f9b
--- /dev/null
+++ b/libs/community/langchain_community/tools/human/__init__.py
@@ -0,0 +1,5 @@
+"""Tool for asking for human input."""
+
+from langchain_community.tools.human.tool import HumanInputRun
+
+__all__ = ["HumanInputRun"]
diff --git a/libs/community/langchain_community/tools/human/tool.py b/libs/community/langchain_community/tools/human/tool.py
new file mode 100644
index 00000000000..7b7911577ce
--- /dev/null
+++ b/libs/community/langchain_community/tools/human/tool.py
@@ -0,0 +1,34 @@
+"""Tool for asking human input."""
+
+from typing import Callable, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+
+def _print_func(text: str) -> None:
+ print("\n")
+ print(text)
+
+
+class HumanInputRun(BaseTool):
+ """Tool that asks user for input."""
+
+ name: str = "human"
+ description: str = (
+ "You can ask a human for guidance when you think you "
+ "got stuck or you are not sure what to do next. "
+ "The input should be a question for the human."
+ )
+ prompt_func: Callable[[str], None] = Field(default_factory=lambda: _print_func)
+ input_func: Callable = Field(default_factory=lambda: input)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Human input tool."""
+ self.prompt_func(query)
+ return self.input_func()
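+
+# Illustrative usage sketch (not part of the original module): both callables
+# can be swapped out, e.g. to capture questions and return canned answers in
+# tests instead of reading from stdin.
+#
+#   tool = HumanInputRun(
+#       prompt_func=lambda text: print(f"[agent asks] {text}"),
+#       input_func=lambda: "use the staging database",
+#   )
+#   print(tool.run("Which database should I use?"))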
diff --git a/libs/community/langchain_community/tools/ifttt.py b/libs/community/langchain_community/tools/ifttt.py
new file mode 100644
index 00000000000..5df673061dc
--- /dev/null
+++ b/libs/community/langchain_community/tools/ifttt.py
@@ -0,0 +1,60 @@
+"""From https://github.com/SidU/teams-langchain-js/wiki/Connecting-IFTTT-Services.
+
+# Creating a webhook
+- Go to https://ifttt.com/create
+
+# Configuring the "If This"
+- Click on the "If This" button in the IFTTT interface.
+- Search for "Webhooks" in the search bar.
+- Choose the first option for "Receive a web request with a JSON payload."
+- Choose an Event Name that is specific to the service you plan to connect to.
+This will make it easier for you to manage the webhook URL.
+For example, if you're connecting to Spotify, you could use "Spotify" as your
+Event Name.
+- Click the "Create Trigger" button to save your settings and create your webhook.
+
+# Configuring the "Then That"
+- Tap on the "Then That" button in the IFTTT interface.
+- Search for the service you want to connect, such as Spotify.
+- Choose an action from the service, such as "Add track to a playlist".
+- Configure the action by specifying the necessary details, such as the playlist name,
+e.g., "Songs from AI".
+- Reference the JSON Payload received by the Webhook in your action. For the Spotify
+scenario, choose "{{JsonPayload}}" as your search query.
+- Tap the "Create Action" button to save your action settings.
+- Once you have finished configuring your action, click the "Finish" button to
+complete the setup.
+- Congratulations! You have successfully connected the Webhook to the desired
+service, and you're ready to start receiving data and triggering actions π
+
+# Finishing up
+- To get your webhook URL go to https://ifttt.com/maker_webhooks/settings
+- Copy the IFTTT key value from there. The URL is of the form
+https://maker.ifttt.com/use/YOUR_IFTTT_KEY. Grab the YOUR_IFTTT_KEY value.
+"""
+from typing import Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+
+class IFTTTWebhook(BaseTool):
+ """IFTTT Webhook.
+
+ Args:
+ name: name of the tool
+ description: description of the tool
+ url: url to hit with the json event.
+ """
+
+ url: str
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ body = {"this": tool_input}
+ response = requests.post(self.url, data=body)
+ return response.text
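+
+# Illustrative usage sketch (not part of the original module): once the webhook
+# is configured as described above, the tool simply POSTs its input to the
+# trigger URL. The event name and key below are placeholders.
+#
+#   spotify_tool = IFTTTWebhook(
+#       name="Spotify",
+#       description="Add a track to a Spotify playlist.",
+#       url="https://maker.ifttt.com/trigger/Spotify/json/with/key/YOUR_IFTTT_KEY",
+#   )
+#   print(spotify_tool.run("taylor swift"))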
diff --git a/libs/community/langchain_community/tools/interaction/__init__.py b/libs/community/langchain_community/tools/interaction/__init__.py
new file mode 100644
index 00000000000..be3393362d8
--- /dev/null
+++ b/libs/community/langchain_community/tools/interaction/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with the user."""
diff --git a/libs/community/langchain_community/tools/interaction/tool.py b/libs/community/langchain_community/tools/interaction/tool.py
new file mode 100644
index 00000000000..5ec6ceef501
--- /dev/null
+++ b/libs/community/langchain_community/tools/interaction/tool.py
@@ -0,0 +1,17 @@
+"""Tools for interacting with the user."""
+
+
+import warnings
+from typing import Any
+
+from langchain_community.tools.human.tool import HumanInputRun
+
+
+def StdInInquireTool(*args: Any, **kwargs: Any) -> HumanInputRun:
+ """Tool for asking the user for input."""
+ warnings.warn(
+ "StdInInquireTool will be deprecated in the future. "
+ "Please use HumanInputRun instead.",
+ DeprecationWarning,
+ )
+ return HumanInputRun(*args, **kwargs)
diff --git a/libs/community/langchain_community/tools/jira/__init__.py b/libs/community/langchain_community/tools/jira/__init__.py
new file mode 100644
index 00000000000..06cd8cbcd9e
--- /dev/null
+++ b/libs/community/langchain_community/tools/jira/__init__.py
@@ -0,0 +1 @@
+"""Jira Tool."""
diff --git a/libs/community/langchain_community/tools/jira/prompt.py b/libs/community/langchain_community/tools/jira/prompt.py
new file mode 100644
index 00000000000..08f06023040
--- /dev/null
+++ b/libs/community/langchain_community/tools/jira/prompt.py
@@ -0,0 +1,42 @@
+# flake8: noqa
+JIRA_ISSUE_CREATE_PROMPT = """
+ This tool is a wrapper around atlassian-python-api's Jira issue_create API, useful when you need to create a Jira issue.
+ The input to this tool is a dictionary specifying the fields of the Jira issue, and will be passed into atlassian-python-api's Jira `issue_create` function.
+ For example, to create a low priority task called "test issue" with description "test description", you would pass in the following dictionary:
+ {{"summary": "test issue", "description": "test description", "issuetype": {{"name": "Task"}}, "priority": {{"name": "Low"}}}}
+ """
+
+JIRA_GET_ALL_PROJECTS_PROMPT = """
+ This tool is a wrapper around atlassian-python-api's Jira project API,
+ useful when you need to fetch all the projects the user has access to, find out how many projects there are, or as an intermediary step that involves searching by projects.
+ There is no input to this tool.
+ """
+
+JIRA_JQL_PROMPT = """
+ This tool is a wrapper around atlassian-python-api's Jira jql API, useful when you need to search for Jira issues.
+ The input to this tool is a JQL query string, and will be passed into atlassian-python-api's Jira `jql` function.
+ For example, to find all the issues in project "Test" assigned to me, you would pass in the following string:
+ project = Test AND assignee = currentUser()
+ or to find issues with summaries that contain the word "test", you would pass in the following string:
+ summary ~ 'test'
+ """
+
+JIRA_CATCH_ALL_PROMPT = """
+ This tool is a wrapper around atlassian-python-api's Jira API.
+ There are other dedicated tools for fetching all projects, and for creating and searching for issues;
+ use this tool if you need to perform any other actions allowed by the atlassian-python-api Jira API.
+ The input to this tool is a dictionary specifying a function from atlassian-python-api's Jira API,
+ as well as a list of arguments and dictionary of keyword arguments to pass into the function.
+ For example, to get all the users in a group, while increasing the max number of results to 100, you would
+ pass in the following dictionary: {{"function": "get_all_users_from_group", "args": ["group"], "kwargs": {{"limit":100}} }}
+ or to find out how many projects are in the Jira instance, you would pass in the following dictionary:
+ {{"function": "projects"}}
+ For more information on the Jira API, refer to https://atlassian-python-api.readthedocs.io/jira.html
+ """
+
+JIRA_CONFLUENCE_PAGE_CREATE_PROMPT = """This tool is a wrapper around atlassian-python-api's Confluence
+API, useful when you need to create a Confluence page. The input to this tool is a dictionary
+specifying the fields of the Confluence page, and will be passed into atlassian-python-api's Confluence `create_page`
+function. For example, to create a page in the DEMO space titled "This is the title" with body "This is the body. You can use
+HTML tags!", you would pass in the following dictionary: {{"space": "DEMO", "title":"This is the
+title","body":"This is the body. You can use HTML tags!"}} """
diff --git a/libs/community/langchain_community/tools/jira/tool.py b/libs/community/langchain_community/tools/jira/tool.py
new file mode 100644
index 00000000000..dc57b13dc20
--- /dev/null
+++ b/libs/community/langchain_community/tools/jira/tool.py
@@ -0,0 +1,44 @@
+"""
+This tool allows agents to interact with the atlassian-python-api library
+and operate on a Jira instance. For more information on the
+atlassian-python-api library, see https://atlassian-python-api.readthedocs.io/jira.html
+
+To use this tool, you must first set as environment variables:
+ JIRA_API_TOKEN
+ JIRA_USERNAME
+ JIRA_INSTANCE_URL
+
+Below is a sample script that uses the Jira tool:
+
+```python
+from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
+from langchain_community.utilities.jira import JiraAPIWrapper
+
+jira = JiraAPIWrapper()
+toolkit = JiraToolkit.from_jira_api_wrapper(jira)
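+
+# A hedged addition (not part of the original sample): toolkits expose their
+# tools via `get_tools()`, and each returned tool can be run directly or
+# handed to an agent.
+tools = toolkit.get_tools()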
+```
+"""
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.jira import JiraAPIWrapper
+
+
+class JiraAction(BaseTool):
+ """Tool that queries the Atlassian Jira API."""
+
+ api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper)
+ mode: str
+ name: str = ""
+ description: str = ""
+
+ def _run(
+ self,
+ instructions: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Atlassian Jira API to run an operation."""
+ return self.api_wrapper.run(self.mode, instructions)
diff --git a/libs/community/langchain_community/tools/json/__init__.py b/libs/community/langchain_community/tools/json/__init__.py
new file mode 100644
index 00000000000..d13302f008a
--- /dev/null
+++ b/libs/community/langchain_community/tools/json/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with a JSON file."""
diff --git a/libs/community/langchain_community/tools/json/tool.py b/libs/community/langchain_community/tools/json/tool.py
new file mode 100644
index 00000000000..af392c96cef
--- /dev/null
+++ b/libs/community/langchain_community/tools/json/tool.py
@@ -0,0 +1,133 @@
+# flake8: noqa
+"""Tools for working with JSON specs."""
+from __future__ import annotations
+
+import json
+import re
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+
+
+def _parse_input(text: str) -> List[Union[str, int]]:
+ """Parse input of the form data["key1"][0]["key2"] into a list of keys."""
+ _res = re.findall(r"\[.*?]", text)
+ # strip the brackets and quotes, convert to int if possible
+ res = [i[1:-1].replace('"', "").replace("'", "") for i in _res]
+ res = [int(i) if i.isdigit() else i for i in res]
+ return res
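+
+
+# A hedged illustration of the helper above (not part of the original code):
+#   _parse_input('data["key1"][0]["key2"]')  ->  ["key1", 0, "key2"]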
+
+
+class JsonSpec(BaseModel):
+ """Base class for JSON spec."""
+
+ dict_: Dict
+ max_value_length: int = 200
+
+ @classmethod
+ def from_file(cls, path: Path) -> JsonSpec:
+ """Create a JsonSpec from a file."""
+ if not path.exists():
+ raise FileNotFoundError(f"File not found: {path}")
+ dict_ = json.loads(path.read_text())
+ return cls(dict_=dict_)
+
+ def keys(self, text: str) -> str:
+ """Return the keys of the dict at the given path.
+
+ Args:
+ text: Python representation of the path to the dict (e.g. data["key1"][0]["key2"]).
+ """
+ try:
+ items = _parse_input(text)
+ val = self.dict_
+ for i in items:
+ if i:
+ val = val[i]
+ if not isinstance(val, dict):
+ raise ValueError(
+ f"Value at path `{text}` is not a dict, get the value directly."
+ )
+ return str(list(val.keys()))
+ except Exception as e:
+ return repr(e)
+
+ def value(self, text: str) -> str:
+ """Return the value of the dict at the given path.
+
+ Args:
+ text: Python representation of the path to the dict (e.g. data["key1"][0]["key2"]).
+ """
+ try:
+ items = _parse_input(text)
+ val = self.dict_
+ for i in items:
+ val = val[i]
+
+ if isinstance(val, dict) and len(str(val)) > self.max_value_length:
+ return "Value is a large dictionary, should explore its keys directly"
+ str_val = str(val)
+ if len(str_val) > self.max_value_length:
+ str_val = str_val[: self.max_value_length] + "..."
+ return str_val
+ except Exception as e:
+ return repr(e)
+
+
+class JsonListKeysTool(BaseTool):
+ """Tool for listing keys in a JSON spec."""
+
+ name: str = "json_spec_list_keys"
+ description: str = """
+ Can be used to list all keys at a given path.
+ Before calling this you should be SURE that the path to this exists.
+ The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).
+ """
+ spec: JsonSpec
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ return self.spec.keys(tool_input)
+
+ async def _arun(
+ self,
+ tool_input: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ return self._run(tool_input)
+
+
+class JsonGetValueTool(BaseTool):
+ """Tool for getting a value in a JSON spec."""
+
+ name: str = "json_spec_get_value"
+ description: str = """
+ Can be used to see value in string format at a given path.
+ Before calling this you should be SURE that the path to this exists.
+ The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).
+ """
+ spec: JsonSpec
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ return self.spec.value(tool_input)
+
+ async def _arun(
+ self,
+ tool_input: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ return self._run(tool_input)
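+
+
+# A hedged usage sketch (the file path and keys below are illustrative):
+#
+#   spec = JsonSpec.from_file(Path("openapi.json"))
+#   JsonListKeysTool(spec=spec).run('data["paths"]')
+#   JsonGetValueTool(spec=spec).run('data["info"]["title"]')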
diff --git a/libs/community/langchain_community/tools/memorize/__init__.py b/libs/community/langchain_community/tools/memorize/__init__.py
new file mode 100644
index 00000000000..76a84406ace
--- /dev/null
+++ b/libs/community/langchain_community/tools/memorize/__init__.py
@@ -0,0 +1,5 @@
+"""Unsupervised learning based memorization."""
+
+from langchain_community.tools.memorize.tool import Memorize
+
+__all__ = ["Memorize"]
diff --git a/libs/community/langchain_community/tools/memorize/tool.py b/libs/community/langchain_community/tools/memorize/tool.py
new file mode 100644
index 00000000000..2813c870dcc
--- /dev/null
+++ b/libs/community/langchain_community/tools/memorize/tool.py
@@ -0,0 +1,58 @@
+from abc import abstractmethod
+from typing import Any, Optional, Protocol, Sequence, runtime_checkable
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.llms.gradient_ai import TrainResult
+
+
+@runtime_checkable
+class TrainableLLM(Protocol):
+ @abstractmethod
+ def train_unsupervised(
+ self,
+ inputs: Sequence[str],
+ **kwargs: Any,
+ ) -> TrainResult:
+ ...
+
+ @abstractmethod
+ async def atrain_unsupervised(
+ self,
+ inputs: Sequence[str],
+ **kwargs: Any,
+ ) -> TrainResult:
+ ...
+
+
+class Memorize(BaseTool):
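+ """Tool that fine-tunes the wrapped LLM on observed information.
+
+ A hedged usage sketch (GradientLLM is one implementation of the TrainableLLM
+ protocol; the constructor arguments are illustrative):
+
+ from langchain_community.llms import GradientLLM
+ tool = Memorize(llm=GradientLLM(model_id="my-adapter-id"))
+ tool.run("The capital of France is Paris.")
+ """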
+ name: str = "Memorize"
+ description: str = (
+ "Useful whenever you observe novel information "
+ "from previous conversation history, "
+ "i.e., another tool's action outputs or human comments. "
+ "The action input should include observed information in detail, "
+ "then the tool will fine-tune you to remember it."
+ )
+ llm: TrainableLLM = Field()
+
+ def _run(
+ self,
+ information_to_learn: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ train_result = self.llm.train_unsupervised((information_to_learn,))
+ return f"Train complete. Loss: {train_result['loss']}"
+
+ async def _arun(
+ self,
+ information_to_learn: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ train_result = await self.llm.atrain_unsupervised((information_to_learn,))
+ return f"Train complete. Loss: {train_result['loss']}"
diff --git a/libs/community/langchain_community/tools/merriam_webster/__init__.py b/libs/community/langchain_community/tools/merriam_webster/__init__.py
new file mode 100644
index 00000000000..73390d54980
--- /dev/null
+++ b/libs/community/langchain_community/tools/merriam_webster/__init__.py
@@ -0,0 +1 @@
+"""Merriam-Webster API toolkit."""
diff --git a/libs/community/langchain_community/tools/merriam_webster/tool.py b/libs/community/langchain_community/tools/merriam_webster/tool.py
new file mode 100644
index 00000000000..a3d1fa881bb
--- /dev/null
+++ b/libs/community/langchain_community/tools/merriam_webster/tool.py
@@ -0,0 +1,28 @@
+"""Tool for the Merriam-Webster API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.merriam_webster import MerriamWebsterAPIWrapper
+
+
+class MerriamWebsterQueryRun(BaseTool):
+ """Tool that searches the Merriam-Webster API."""
+
+ name: str = "MerriamWebster"
+ description: str = (
+ "A wrapper around Merriam-Webster. "
+ "Useful for when you need to get the definition of a word."
+ "Useful for when you need to get the definition of a word. "
+ )
+ api_wrapper: MerriamWebsterAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Merriam-Webster tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/metaphor_search/__init__.py b/libs/community/langchain_community/tools/metaphor_search/__init__.py
new file mode 100644
index 00000000000..246f25a1291
--- /dev/null
+++ b/libs/community/langchain_community/tools/metaphor_search/__init__.py
@@ -0,0 +1,5 @@
+"""Metaphor Search API toolkit."""
+
+from langchain_community.tools.metaphor_search.tool import MetaphorSearchResults
+
+__all__ = ["MetaphorSearchResults"]
diff --git a/libs/community/langchain_community/tools/metaphor_search/tool.py b/libs/community/langchain_community/tools/metaphor_search/tool.py
new file mode 100644
index 00000000000..d81ea0a0ba3
--- /dev/null
+++ b/libs/community/langchain_community/tools/metaphor_search/tool.py
@@ -0,0 +1,81 @@
+"""Tool for the Metaphor search API."""
+
+from typing import Dict, List, Optional, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.metaphor_search import MetaphorSearchAPIWrapper
+
+
+class MetaphorSearchResults(BaseTool):
+ """Tool that queries the Metaphor Search API and gets back json."""
+
+ name: str = "metaphor_search_results_json"
+ description: str = (
+ "A wrapper around Metaphor Search. "
+ "Input should be a Metaphor-optimized query. "
+ "Output is a JSON array of the query results"
+ )
+ api_wrapper: MetaphorSearchAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ num_results: int,
+ include_domains: Optional[List[str]] = None,
+ exclude_domains: Optional[List[str]] = None,
+ start_crawl_date: Optional[str] = None,
+ end_crawl_date: Optional[str] = None,
+ start_published_date: Optional[str] = None,
+ end_published_date: Optional[str] = None,
+ use_autoprompt: Optional[bool] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool."""
+ try:
+ return self.api_wrapper.results(
+ query,
+ num_results,
+ include_domains,
+ exclude_domains,
+ start_crawl_date,
+ end_crawl_date,
+ start_published_date,
+ end_published_date,
+ use_autoprompt,
+ )
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ query: str,
+ num_results: int,
+ include_domains: Optional[List[str]] = None,
+ exclude_domains: Optional[List[str]] = None,
+ start_crawl_date: Optional[str] = None,
+ end_crawl_date: Optional[str] = None,
+ start_published_date: Optional[str] = None,
+ end_published_date: Optional[str] = None,
+ use_autoprompt: Optional[bool] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool asynchronously."""
+ try:
+ return await self.api_wrapper.results_async(
+ query,
+ num_results,
+ include_domains,
+ exclude_domains,
+ start_crawl_date,
+ end_crawl_date,
+ start_published_date,
+ end_published_date,
+ use_autoprompt,
+ )
+ except Exception as e:
+ return repr(e)
diff --git a/libs/community/langchain_community/tools/multion/__init__.py b/libs/community/langchain_community/tools/multion/__init__.py
new file mode 100644
index 00000000000..693ddcffc90
--- /dev/null
+++ b/libs/community/langchain_community/tools/multion/__init__.py
@@ -0,0 +1,6 @@
+"""MutliOn Client API tools."""
+from langchain_community.tools.multion.close_session import MultionCloseSession
+from langchain_community.tools.multion.create_session import MultionCreateSession
+from langchain_community.tools.multion.update_session import MultionUpdateSession
+
+__all__ = ["MultionCreateSession", "MultionUpdateSession", "MultionCloseSession"]
diff --git a/libs/community/langchain_community/tools/multion/close_session.py b/libs/community/langchain_community/tools/multion/close_session.py
new file mode 100644
index 00000000000..7aaead7fa0c
--- /dev/null
+++ b/libs/community/langchain_community/tools/multion/close_session.py
@@ -0,0 +1,67 @@
+import asyncio
+from typing import TYPE_CHECKING, Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+if TYPE_CHECKING:
+ # This is for linting and IDE typehints
+ import multion
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ import multion
+ except ImportError:
+ pass
+
+
+class CloseSessionSchema(BaseModel):
+ """Input for CloseSessionTool."""
+
+ sessionId: str = Field(
+ ...,
+ description="""The sessionId, received from one of the createSessions
+ or updateSessions run before""",
+ )
+
+
+class MultionCloseSession(BaseTool):
+ """Tool that closes an existing Multion Browser Window with provided fields.
+
+ Attributes:
+ name: The name of the tool. Default: "close_multion_session"
+ description: The description of the tool.
+ args_schema: The schema for the tool's arguments. Default: CloseSessionSchema
+ """
+
+ name: str = "close_multion_session"
+ description: str = """Use this tool to close \
+an existing corresponding Multion Browser Window with provided fields. \
+Note: SessionId must be received from previous Browser window creation."""
+ args_schema: Type[CloseSessionSchema] = CloseSessionSchema
+ sessionId: str = ""
+
+ def _run(
+ self,
+ sessionId: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> None:
+ try:
+ try:
+ multion.close_session(sessionId)
+ except Exception as e:
+ print(f"{e}, retrying...")
+ except Exception as e:
+ raise Exception(f"An error occurred: {e}")
+
+ async def _arun(
+ self,
+ sessionId: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> None:
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(None, self._run, sessionId)
diff --git a/libs/community/langchain_community/tools/multion/create_session.py b/libs/community/langchain_community/tools/multion/create_session.py
new file mode 100644
index 00000000000..9f93676ee18
--- /dev/null
+++ b/libs/community/langchain_community/tools/multion/create_session.py
@@ -0,0 +1,80 @@
+import asyncio
+from typing import TYPE_CHECKING, Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+if TYPE_CHECKING:
+ # This is for linting and IDE typehints
+ import multion
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ import multion
+ except ImportError:
+ pass
+
+
+class CreateSessionSchema(BaseModel):
+ """Input for CreateSessionTool."""
+
+ query: str = Field(
+ ...,
+ description="The query to run in multion agent.",
+ )
+ url: str = Field(
+ "https://www.google.com/",
+ description="""The Url to run the agent at. Note: accepts only secure \
+ links having https://""",
+ )
+
+
+class MultionCreateSession(BaseTool):
+ """Tool that creates a new Multion Browser Window with provided fields.
+
+ Attributes:
+ name: The name of the tool. Default: "create_multion_session"
+ description: The description of the tool.
+ args_schema: The schema for the tool's arguments.
+ """
+
+ name: str = "create_multion_session"
+ description: str = """
+ Create a new web browsing session based on a user's command or request. \
+ The command should include the full info required for the session. \
+ Also include a url (defaults to google.com if no better option) \
+ to start the session. \
+ Use this tool to create a new Browser Window with provided fields. \
+ This is always the first step for any activity that can be done using a browser.
+ """
+ args_schema: Type[CreateSessionSchema] = CreateSessionSchema
+
+ def _run(
+ self,
+ query: str,
+ url: Optional[str] = "https://www.google.com/",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> dict:
+ try:
+ response = multion.new_session({"input": query, "url": url})
+ return {
+ "sessionId": response["session_id"],
+ "Response": response["message"],
+ }
+ except Exception as e:
+ raise Exception(f"An error occurred: {e}")
+
+ async def _arun(
+ self,
+ query: str,
+ url: Optional[str] = "https://www.google.com/",
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> dict:
+ loop = asyncio.get_running_loop()
+ result = await loop.run_in_executor(None, self._run, query, url)
+
+ return result
diff --git a/libs/community/langchain_community/tools/multion/update_session.py b/libs/community/langchain_community/tools/multion/update_session.py
new file mode 100644
index 00000000000..97a8f1ff4a3
--- /dev/null
+++ b/libs/community/langchain_community/tools/multion/update_session.py
@@ -0,0 +1,88 @@
+import asyncio
+from typing import TYPE_CHECKING, Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+if TYPE_CHECKING:
+ # This is for linting and IDE typehints
+ import multion
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ import multion
+ except ImportError:
+ pass
+
+
+class UpdateSessionSchema(BaseModel):
+ """Input for UpdateSessionTool."""
+
+ sessionId: str = Field(
+ ...,
+ description="""The sessionID,
+ received from one of the createSessions run before""",
+ )
+ query: str = Field(
+ ...,
+ description="The query to run in multion agent.",
+ )
+ url: str = Field(
+ "https://www.google.com/",
+ description="""The Url to run the agent at. \
+ Note: accepts only secure links having https://""",
+ )
+
+
+class MultionUpdateSession(BaseTool):
+ """Tool that updates an existing Multion Browser Window with provided fields.
+
+ Attributes:
+ name: The name of the tool. Default: "update_multion_session"
+ description: The description of the tool.
+ args_schema: The schema for the tool's arguments. Default: UpdateSessionSchema
+ """
+
+ name: str = "update_multion_session"
+ description: str = """Use this tool to update \
+an existing corresponding Multion Browser Window with provided fields. \
+Note: sessionId must be received from previous Browser window creation."""
+ args_schema: Type[UpdateSessionSchema] = UpdateSessionSchema
+ sessionId: str = ""
+
+ def _run(
+ self,
+ sessionId: str,
+ query: str,
+ url: Optional[str] = "https://www.google.com/",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> dict:
+ try:
+ try:
+ response = multion.update_session(
+ sessionId, {"input": query, "url": url}
+ )
+ content = {"sessionId": sessionId, "Response": response["message"]}
+ self.sessionId = sessionId
+ return content
+ except Exception as e:
+ print(f"{e}, retrying...")
+ return {"error": f"{e}", "Response": "retrying..."}
+ except Exception as e:
+ raise Exception(f"An error occurred: {e}")
+
+ async def _arun(
+ self,
+ sessionId: str,
+ query: str,
+ url: Optional[str] = "https://www.google.com/",
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> dict:
+ loop = asyncio.get_running_loop()
+ result = await loop.run_in_executor(None, self._run, sessionId, query, url)
+
+ return result
diff --git a/libs/langchain/tests/integration_tests/tools/nuclia/__init__.py b/libs/community/langchain_community/tools/nasa/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/tools/nuclia/__init__.py
rename to libs/community/langchain_community/tools/nasa/__init__.py
diff --git a/libs/community/langchain_community/tools/nasa/prompt.py b/libs/community/langchain_community/tools/nasa/prompt.py
new file mode 100644
index 00000000000..4c7a3846a7e
--- /dev/null
+++ b/libs/community/langchain_community/tools/nasa/prompt.py
@@ -0,0 +1,82 @@
+# flake8: noqa
+NASA_SEARCH_PROMPT = """
+ This tool is a wrapper around NASA's search API, useful when you need to search through NASA's Image and Video Library.
+ The input to this tool is a query specified by the user, and will be passed into NASA's `search` function.
+
+ At least one parameter must be provided.
+
+ There are optional parameters that can be passed by the user based on their query
+ specifications. Each item in this list contains pound sign (#) separated values: the first value is the parameter name,
+ the second value is the datatype, and the third value is the description: {{
+
+ - q#string#Free text search terms to compare to all indexed metadata.
+ - center#string#NASA center which published the media.
+ - description#string#Terms to search for in "Description" fields.
+ - description_508#string#Terms to search for in "508 Description" fields.
+ - keywords #string#Terms to search for in "Keywords" fields. Separate multiple values with commas.
+ - location #string#Terms to search for in "Location" fields.
+ - media_type#string#Media types to restrict the search to. Available types: ["image", "video", "audio"]. Separate multiple values with commas.
+ - nasa_id #string#The media asset's NASA ID.
+ - page#integer#Page number, starting at 1, of results to get.
+ - page_size#integer#Number of results per page. Default: 100.
+ - photographer#string#The primary photographer's name.
+ - secondary_creator#string#A secondary photographer/videographer's name.
+ - title #string#Terms to search for in "Title" fields.
+ - year_start#string#The start year for results. Format: YYYY.
+ - year_end #string#The end year for results. Format: YYYY.
+
+ }}
+
+ Below are several task descriptions along with their respective input examples.
+ Task: get the 2nd page of image and video content starting from the year 2002 to 2010
+ Example Input: {{"year_start": "2002", "year_end": "2010", "page": 2}}
+
+ Task: get the image and video content of saturn photographed by John Appleseed
+ Example Input: {{"q": "saturn", "photographer": "John Appleseed"}}
+
+ Task: search for Meteor Showers with description "Search Description" with media type image
+ Example Input: {{"q": "Meteor Shower", "description": "Search Description", "media_type": "image"}}
+
+ Task: get the image and video content from year 2008 to 2010 from Kennedy Center
+ Example Input: {{"year_start": "2008", "year_end": "2010", "location": "Kennedy Center"}}
+ """
+
+
+NASA_MANIFEST_PROMPT = """
+ This tool is a wrapper around NASA's media asset manifest API, useful when you need to retrieve a media
+ asset's manifest. The input to this tool should include a string representing a NASA ID for a media asset that the user is trying to get the media asset manifest data for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function.
+
+ The following are some examples of NASA IDs for a media asset that you can use to better extract the NASA ID from the input string to the tool.
+ - GSFC_20171102_Archive_e000579
+ - Launch-Sound_Delta-PAM-Random-Commentary
+ - iss066m260341519_Expedition_66_Education_Inflight_with_Random_Lake_School_District_220203
+ - 6973610
+ - GRC-2020-CM-0167.4
+ - Expedition_55_Inflight_Japan_VIP_Event_May_31_2018_659970
+ - NASA 60th_SEAL_SLIVER_150DPI
+"""
+
+NASA_METADATA_PROMPT = """
+ This tool is a wrapper around NASA's media asset metadata location API, useful when you need to retrieve the media asset's metadata. The input to this tool should include a string representing a NASA ID for a media asset that the user is trying to get the media asset metadata location for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function.
+
+ The following are some examples of NASA IDs for a media asset that you can use to better extract the NASA ID from the input string to the tool.
+ - GSFC_20171102_Archive_e000579
+ - Launch-Sound_Delta-PAM-Random-Commentary
+ - iss066m260341519_Expedition_66_Education_Inflight_with_Random_Lake_School_District_220203
+ - 6973610
+ - GRC-2020-CM-0167.4
+ - Expedition_55_Inflight_Japan_VIP_Event_May_31_2018_659970
+ - NASA 60th_SEAL_SLIVER_150DPI
+"""
+
+NASA_CAPTIONS_PROMPT = """
+ This tool is a wrapper around NASA's video asset caption location API, useful when you need
+ to retrieve the location of the captions of a specific video. The input to this tool should include a string representing a NASA ID for a video media asset that the user is trying to get the location of the captions for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function.
+
+ The following are some examples of NASA IDs for a video asset that you can use to better extract the NASA ID from the input string to the tool.
+ - 2017-08-09 - Video File RS-25 Engine Test
+ - 20180415-TESS_Social_Briefing
+ - 201_TakingWildOutOfWildfire
+ - 2022-H1_V_EuropaClipper-4
+ - 2022_0429_Recientemente
+"""
diff --git a/libs/community/langchain_community/tools/nasa/tool.py b/libs/community/langchain_community/tools/nasa/tool.py
new file mode 100644
index 00000000000..f5644661536
--- /dev/null
+++ b/libs/community/langchain_community/tools/nasa/tool.py
@@ -0,0 +1,29 @@
+"""
+This tool allows agents to interact with the NASA API, specifically
+the NASA Image & Video Library and Exoplanet APIs.
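+
+A hedged usage sketch (the NasaToolkit import path and `from_nasa_api_wrapper`
+constructor are assumptions, following the pattern of the Jira tool above):
+
+```python
+from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit
+from langchain_community.utilities.nasa import NasaAPIWrapper
+
+nasa = NasaAPIWrapper()
+toolkit = NasaToolkit.from_nasa_api_wrapper(nasa)
+```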
+"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.nasa import NasaAPIWrapper
+
+
+class NasaAction(BaseTool):
+ """Tool that queries the NASA API."""
+
+ api_wrapper: NasaAPIWrapper = Field(default_factory=NasaAPIWrapper)
+ mode: str
+ name: str = ""
+ description: str = ""
+
+ def _run(
+ self,
+ instructions: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the NASA API to run an operation."""
+ return self.api_wrapper.run(self.mode, instructions)
diff --git a/libs/community/langchain_community/tools/nuclia/__init__.py b/libs/community/langchain_community/tools/nuclia/__init__.py
new file mode 100644
index 00000000000..ea2f3dc651c
--- /dev/null
+++ b/libs/community/langchain_community/tools/nuclia/__init__.py
@@ -0,0 +1,3 @@
+from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
+
+__all__ = ["NucliaUnderstandingAPI"]
diff --git a/libs/community/langchain_community/tools/nuclia/tool.py b/libs/community/langchain_community/tools/nuclia/tool.py
new file mode 100644
index 00000000000..2f8bae73ff3
--- /dev/null
+++ b/libs/community/langchain_community/tools/nuclia/tool.py
@@ -0,0 +1,237 @@
+"""Tool for the Nuclia Understanding API.
+
+Installation:
+
+```bash
+ pip install --upgrade protobuf
+ pip install nucliadb-protos
+```
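+
+Usage (a minimal sketch; assumes the NUCLIA_NUA_KEY environment variable is set
+and the file path is illustrative):
+
+```python
+ nua = NucliaUnderstandingAPI(enable_ml=False)
+ nua.run({"action": "push", "id": "1", "path": "./report.pdf", "text": None})
+ nua.run({"action": "pull", "id": "1", "path": None, "text": None})
+```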
+"""
+
+import asyncio
+import base64
+import logging
+import mimetypes
+import os
+from typing import Any, Dict, Optional, Type, Union
+
+import requests
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+logger = logging.getLogger(__name__)
+
+
+class NUASchema(BaseModel):
+ """Input for Nuclia Understanding API.
+
+ Attributes:
+ action: Action to perform. Either `push` or `pull`.
+ id: ID of the file to push or pull.
+ path: Path to the file to push (needed only for `push` action).
+ text: Text content to process (needed only for `push` action).
+ """
+
+ action: str = Field(
+ ...,
+ description="Action to perform. Either `push` or `pull`.",
+ )
+ id: str = Field(
+ ...,
+ description="ID of the file to push or pull.",
+ )
+ path: Optional[str] = Field(
+ ...,
+ description="Path to the file to push (needed only for `push` action).",
+ )
+ text: Optional[str] = Field(
+ ...,
+ description="Text content to process (needed only for `push` action).",
+ )
+
+
+class NucliaUnderstandingAPI(BaseTool):
+ """Tool to process files with the Nuclia Understanding API."""
+
+ name: str = "nuclia_understanding_api"
+ description: str = (
+ "A wrapper around Nuclia Understanding API endpoints. "
+ "Useful for when you need to extract text from any kind of file. "
+ )
+ args_schema: Type[BaseModel] = NUASchema
+ _results: Dict[str, Any] = {}
+ _config: Dict[str, Any] = {}
+
+ def __init__(self, enable_ml: bool = False) -> None:
+ zone = os.environ.get("NUCLIA_ZONE", "europe-1")
+ self._config["BACKEND"] = f"https://{zone}.nuclia.cloud/api/v1"
+ key = os.environ.get("NUCLIA_NUA_KEY")
+ if not key:
+ raise ValueError("NUCLIA_NUA_KEY environment variable not set")
+ else:
+ self._config["NUA_KEY"] = key
+ self._config["enable_ml"] = enable_ml
+ super().__init__()
+
+ def _run(
+ self,
+ action: str,
+ id: str,
+ path: Optional[str],
+ text: Optional[str],
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if action == "push":
+ self._check_params(path, text)
+ if path:
+ return self._pushFile(id, path)
+ if text:
+ return self._pushText(id, text)
+ elif action == "pull":
+ return self._pull(id)
+ return ""
+
+ async def _arun(
+ self,
+ action: str,
+ id: str,
+ path: Optional[str] = None,
+ text: Optional[str] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ self._check_params(path, text)
+ if path:
+ self._pushFile(id, path)
+ if text:
+ self._pushText(id, text)
+ data = None
+ while True:
+ data = self._pull(id)
+ if data:
+ break
+ await asyncio.sleep(15)
+ return data
+
+ def _pushText(self, id: str, text: str) -> str:
+ field = {
+ "textfield": {"text": {"body": text, "format": 0}},
+ "processing_options": {"ml_text": self._config["enable_ml"]},
+ }
+ return self._pushField(id, field)
+
+ def _pushFile(self, id: str, content_path: str) -> str:
+ with open(content_path, "rb") as source_file:
+ response = requests.post(
+ self._config["BACKEND"] + "/processing/upload",
+ headers={
+ "content-type": mimetypes.guess_type(content_path)[0]
+ or "application/octet-stream",
+ "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
+ },
+ data=source_file.read(),
+ )
+ if response.status_code != 200:
+ logger.info(
+ f"Error uploading {content_path}: "
+ f"{response.status_code} {response.text}"
+ )
+ return ""
+ else:
+ field = {
+ "filefield": {"file": f"{response.text}"},
+ "processing_options": {"ml_text": self._config["enable_ml"]},
+ }
+ return self._pushField(id, field)
+
+ def _pushField(self, id: str, field: Any) -> str:
+ logger.info(f"Pushing {id} in queue")
+ response = requests.post(
+ self._config["BACKEND"] + "/processing/push",
+ headers={
+ "content-type": "application/json",
+ "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
+ },
+ json=field,
+ )
+ if response.status_code != 200:
+ logger.info(
+ f"Error pushing field {id}:" f"{response.status_code} {response.text}"
+ )
+ raise ValueError("Error pushing field")
+ else:
+ uuid = response.json()["uuid"]
+ logger.info(f"Field {id} pushed in queue, uuid: {uuid}")
+ self._results[id] = {"uuid": uuid, "status": "pending"}
+ return uuid
+
+ def _pull(self, id: str) -> str:
+ self._pull_queue()
+ result = self._results.get(id, None)
+ if not result:
+ logger.info(f"{id} not in queue")
+ return ""
+ elif result["status"] == "pending":
+ logger.info(f'Waiting for {result["uuid"]} to be processed')
+ return ""
+ else:
+ return result["data"]
+
+ def _pull_queue(self) -> None:
+ try:
+ from nucliadb_protos.writer_pb2 import BrokerMessage
+ except ImportError as e:
+ raise ImportError(
+ "nucliadb-protos is not installed. "
+ "Run `pip install nucliadb-protos` to install."
+ ) from e
+ try:
+ from google.protobuf.json_format import MessageToJson
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import google.protobuf, please install with "
+ "`pip install protobuf`."
+ ) from e
+
+ res = requests.get(
+ self._config["BACKEND"] + "/processing/pull",
+ headers={
+ "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
+ },
+ ).json()
+ if res["status"] == "empty":
+ logger.info("Queue empty")
+ elif res["status"] == "ok":
+ payload = res["payload"]
+ pb = BrokerMessage()
+ pb.ParseFromString(base64.b64decode(payload))
+ uuid = pb.uuid
+ logger.info(f"Pulled {uuid} from queue")
+ matching_id = self._find_matching_id(uuid)
+ if not matching_id:
+ logger.info(f"No matching id for {uuid}")
+ else:
+ self._results[matching_id]["status"] = "done"
+ data = MessageToJson(
+ pb,
+ preserving_proto_field_name=True,
+ including_default_value_fields=True,
+ )
+ self._results[matching_id]["data"] = data
+
+ def _find_matching_id(self, uuid: str) -> Union[str, None]:
+ for id, result in self._results.items():
+ if result["uuid"] == uuid:
+ return id
+ return None
+
+ def _check_params(self, path: Optional[str], text: Optional[str]) -> None:
+ if not path and not text:
+ raise ValueError("File path or text is required")
+ if path and text:
+ raise ValueError("Cannot process both file and text on a single run")
diff --git a/libs/community/langchain_community/tools/office365/__init__.py b/libs/community/langchain_community/tools/office365/__init__.py
new file mode 100644
index 00000000000..10ff2206f71
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/__init__.py
@@ -0,0 +1,19 @@
+"""O365 tools."""
+
+from langchain_community.tools.office365.create_draft_message import (
+ O365CreateDraftMessage,
+)
+from langchain_community.tools.office365.events_search import O365SearchEvents
+from langchain_community.tools.office365.messages_search import O365SearchEmails
+from langchain_community.tools.office365.send_event import O365SendEvent
+from langchain_community.tools.office365.send_message import O365SendMessage
+from langchain_community.tools.office365.utils import authenticate
+
+__all__ = [
+ "O365SearchEmails",
+ "O365SearchEvents",
+ "O365CreateDraftMessage",
+ "O365SendMessage",
+ "O365SendEvent",
+ "authenticate",
+]
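+
+# A hedged usage sketch (CLIENT_ID and CLIENT_SECRET must be set in the
+# environment; the O365Toolkit import path is an assumption):
+#
+#   from langchain_community.agent_toolkits import O365Toolkit
+#
+#   tools = O365Toolkit().get_tools()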
diff --git a/libs/community/langchain_community/tools/office365/base.py b/libs/community/langchain_community/tools/office365/base.py
new file mode 100644
index 00000000000..230e2aa125c
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/base.py
@@ -0,0 +1,19 @@
+"""Base class for Office 365 tools."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.office365.utils import authenticate
+
+if TYPE_CHECKING:
+ from O365 import Account
+
+
+class O365BaseTool(BaseTool):
+ """Base class for the Office 365 tools."""
+
+ account: Account = Field(default_factory=authenticate)
+ """The account object for the Office 365 account."""
diff --git a/libs/community/langchain_community/tools/office365/create_draft_message.py b/libs/community/langchain_community/tools/office365/create_draft_message.py
new file mode 100644
index 00000000000..88c94578a83
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/create_draft_message.py
@@ -0,0 +1,68 @@
+from typing import List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.office365.base import O365BaseTool
+
+
+class CreateDraftMessageSchema(BaseModel):
+ """Input for CreateDraftMessage Tool."""
+
+ body: str = Field(
+ ...,
+ description="The message body to include in the draft.",
+ )
+ to: List[str] = Field(
+ ...,
+ description="The list of recipients.",
+ )
+ subject: str = Field(
+ ...,
+ description="The subject of the message.",
+ )
+ cc: Optional[List[str]] = Field(
+ None,
+ description="The list of CC recipients.",
+ )
+ bcc: Optional[List[str]] = Field(
+ None,
+ description="The list of BCC recipients.",
+ )
+
+
+class O365CreateDraftMessage(O365BaseTool):
+ """Tool for creating a draft email in Office 365."""
+
+ name: str = "create_email_draft"
+ description: str = (
+ "Use this tool to create a draft email with the provided message fields."
+ )
+ args_schema: Type[CreateDraftMessageSchema] = CreateDraftMessageSchema
+
+ def _run(
+ self,
+ body: str,
+ to: List[str],
+ subject: str,
+ cc: Optional[List[str]] = None,
+ bcc: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ # Get mailbox object
+ mailbox = self.account.mailbox()
+ message = mailbox.new_message()
+
+ # Assign message values
+ message.body = body
+ message.subject = subject
+ message.to.add(to)
+ if cc is not None:
+ message.cc.add(cc)
+ if bcc is not None:
+ message.bcc.add(bcc)
+
+ message.save_draft()
+
+ output = "Draft created: " + str(message)
+ return output
diff --git a/libs/community/langchain_community/tools/office365/events_search.py b/libs/community/langchain_community/tools/office365/events_search.py
new file mode 100644
index 00000000000..8cb16f7a299
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/events_search.py
@@ -0,0 +1,128 @@
+"""Util that Searches calendar events in Office 365.
+
+Free, but setup is required. See link below.
+https://learn.microsoft.com/en-us/graph/auth/
+"""
+
+from datetime import datetime as dt
+from typing import Any, Dict, List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field
+
+from langchain_community.tools.office365.base import O365BaseTool
+from langchain_community.tools.office365.utils import UTC_FORMAT, clean_body
+
+
+class SearchEventsInput(BaseModel):
+ """Input for SearchEvents Tool.
+
+ From https://learn.microsoft.com/en-us/graph/search-query-parameter"""
+
+ start_datetime: str = Field(
+ description=(
+ " The start datetime for the search query in the following format: "
+ ' YYYY-MM-DDTHH:MM:SSΒ±hh:mm, where "T" separates the date and time '
+ " components, and the time zone offset is specified as Β±hh:mm. "
+ ' For example: "2023-06-09T10:30:00+03:00" represents June 9th, '
+ " 2023, at 10:30 AM in a time zone with a positive offset of 3 "
+ " hours from Coordinated Universal Time (UTC)."
+ )
+ )
+ end_datetime: str = Field(
+ description=(
+ " The end datetime for the search query in the following format: "
+ ' YYYY-MM-DDTHH:MM:SSΒ±hh:mm, where "T" separates the date and time '
+ " components, and the time zone offset is specified as Β±hh:mm. "
+ ' For example: "2023-06-09T10:30:00+03:00" represents June 9th, '
+ " 2023, at 10:30 AM in a time zone with a positive offset of 3 "
+ " hours from Coordinated Universal Time (UTC)."
+ )
+ )
+ max_results: int = Field(
+ default=10,
+ description="The maximum number of results to return.",
+ )
+ truncate: bool = Field(
+ default=True,
+ description=(
+ "Whether the event's body is truncated to meet token number limits. Set to "
+ "False for searches that will retrieve small events, otherwise, set to "
+ "True."
+ ),
+ )
+
+
+class O365SearchEvents(O365BaseTool):
+ """Class for searching calendar events in Office 365
+
+ Free, but setup is required
+ """
+
+ name: str = "events_search"
+ args_schema: Type[BaseModel] = SearchEventsInput
+ description: str = (
+ " Use this tool to search for the user's calendar events."
+ " The input must be the start and end datetimes for the search query."
+ " The output is a JSON list of all the events in the user's calendar"
+ " between the start and end times. You can assume that the user cannot "
+ "schedule any meeting over existing meetings, and that the user "
+ "is busy during meetings. Any times without events are free for the user. "
+ )
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _run(
+ self,
+ start_datetime: str,
+ end_datetime: str,
+ max_results: int = 10,
+ truncate: bool = True,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ truncate_limit: int = 150,
+ ) -> List[Dict[str, Any]]:
+ # Get calendar object
+ schedule = self.account.schedule()
+ calendar = schedule.get_default_calendar()
+
+ # Process the date range parameters
+ start_datetime_query = dt.strptime(start_datetime, UTC_FORMAT)
+ end_datetime_query = dt.strptime(end_datetime, UTC_FORMAT)
+
+ # Run the query
+ q = calendar.new_query("start").greater_equal(start_datetime_query)
+ q.chain("and").on_attribute("end").less_equal(end_datetime_query)
+ events = calendar.get_events(query=q, include_recurring=True, limit=max_results)
+
+ # Generate output dict
+ output_events = []
+ for event in events:
+ output_event = {}
+ output_event["organizer"] = event.organizer
+
+ output_event["subject"] = event.subject
+
+ if truncate:
+ output_event["body"] = clean_body(event.body)[:truncate_limit]
+ else:
+ output_event["body"] = clean_body(event.body)
+
+ # Get the time zone from the search parameters
+ time_zone = start_datetime_query.tzinfo
+ # Assign the datetimes in the search time zone
+ output_event["start_datetime"] = event.start.astimezone(time_zone).strftime(
+ UTC_FORMAT
+ )
+ output_event["end_datetime"] = event.end.astimezone(time_zone).strftime(
+ UTC_FORMAT
+ )
+ output_event["modified_date"] = event.modified.astimezone(
+ time_zone
+ ).strftime(UTC_FORMAT)
+
+ output_events.append(output_event)
+
+ return output_events
diff --git a/libs/community/langchain_community/tools/office365/messages_search.py b/libs/community/langchain_community/tools/office365/messages_search.py
new file mode 100644
index 00000000000..ad26e42de4e
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/messages_search.py
@@ -0,0 +1,123 @@
+"""Util that Searches email messages in Office 365.
+
+Free, but setup is required. See link below.
+https://learn.microsoft.com/en-us/graph/auth/
+"""
+
+from typing import Any, Dict, List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field
+
+from langchain_community.tools.office365.base import O365BaseTool
+from langchain_community.tools.office365.utils import UTC_FORMAT, clean_body
+
+
+class SearchEmailsInput(BaseModel):
+ """Input for SearchEmails Tool.
+
+ From https://learn.microsoft.com/en-us/graph/search-query-parameter"""
+
+ folder: str = Field(
+ default=None,
+ description=(
+ " If the user wants to search in only one folder, the name of the folder. "
+ 'Default folders are "inbox", "drafts", "sent items", "deleted items", but '
+ "users can search custom folders as well."
+ ),
+ )
+ query: str = Field(
+ description=(
+ "The Microsoft Graph v1.0 $search query. Example filters include "
+ "from:sender, to:recipient, subject:subject, "
+ "recipients:list_of_recipients, body:excitement, importance:high, "
+ "received>2022-12-01, received<2021-12-01, sent>2022-12-01, "
+ "sent<2021-12-01, hasAttachments:true attachment:api-catalog.md, "
+ "cc:samanthab@contoso.com, bcc:samanthab@contoso.com, body:excitement date "
+ "range example: received:2023-06-08..2023-06-09 matching example: "
+ "from:amy OR from:david."
+ )
+ )
+ max_results: int = Field(
+ default=10,
+ description="The maximum number of results to return.",
+ )
+ truncate: bool = Field(
+ default=True,
+ description=(
+ "Whether the email body is truncated to meet token number limits. Set to "
+ "False for searches that will retrieve small messages, otherwise, set to "
+ "True"
+ ),
+ )
+
+
+class O365SearchEmails(O365BaseTool):
+ """Class for searching email messages in Office 365
+
+ Free, but setup is required
+ """
+
+ name: str = "messages_search"
+ args_schema: Type[BaseModel] = SearchEmailsInput
+ description: str = (
+ "Use this tool to search for email messages."
+ " The input must be a valid Microsoft Graph v1.0 $search query."
+ " The output is a JSON list of the requested resource."
+ )
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _run(
+ self,
+ query: str,
+ folder: str = "",
+ max_results: int = 10,
+ truncate: bool = True,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ truncate_limit: int = 150,
+ ) -> List[Dict[str, Any]]:
+ # Get mailbox object
+ mailbox = self.account.mailbox()
+
+ # Pull the folder if the user wants to search in a folder
+ if folder != "":
+ mailbox = mailbox.get_folder(folder_name=folder)
+
+ # Retrieve messages based on query
+ query = mailbox.q().search(query)
+ messages = mailbox.get_messages(limit=max_results, query=query)
+
+ # Generate output dict
+ output_messages = []
+ for message in messages:
+ output_message = {}
+ output_message["from"] = message.sender
+
+ if truncate:
+ output_message["body"] = message.body_preview[:truncate_limit]
+ else:
+ output_message["body"] = clean_body(message.body)
+
+ output_message["subject"] = message.subject
+
+ output_message["date"] = message.modified.strftime(UTC_FORMAT)
+
+ output_message["to"] = []
+ for recipient in message.to._recipients:
+ output_message["to"].append(str(recipient))
+
+ output_message["cc"] = []
+ for recipient in message.cc._recipients:
+ output_message["cc"].append(str(recipient))
+
+ output_message["bcc"] = []
+ for recipient in message.bcc._recipients:
+ output_message["bcc"].append(str(recipient))
+
+ output_messages.append(output_message)
+
+ return output_messages
diff --git a/libs/community/langchain_community/tools/office365/send_event.py b/libs/community/langchain_community/tools/office365/send_event.py
new file mode 100644
index 00000000000..e8b9d023af7
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/send_event.py
@@ -0,0 +1,85 @@
+"""Util that sends calendar events in Office 365.
+
+Free, but setup is required. See link below.
+https://learn.microsoft.com/en-us/graph/auth/
+"""
+
+from datetime import datetime as dt
+from typing import List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.office365.base import O365BaseTool
+from langchain_community.tools.office365.utils import UTC_FORMAT
+
+
+class SendEventSchema(BaseModel):
+ """Input for SendEvent Tool."""
+
+ body: str = Field(
+ ...,
+ description="The message body to include in the event.",
+ )
+ attendees: List[str] = Field(
+ ...,
+ description="The list of attendees for the event.",
+ )
+ subject: str = Field(
+ ...,
+ description="The subject of the event.",
+ )
+ start_datetime: str = Field(
+ description=" The start datetime for the event in the following format: "
+ ' YYYY-MM-DDTHH:MM:SSΒ±hh:mm, where "T" separates the date and time '
+ " components, and the time zone offset is specified as Β±hh:mm. "
+ ' For example: "2023-06-09T10:30:00+03:00" represents June 9th, '
+ " 2023, at 10:30 AM in a time zone with a positive offset of 3 "
+ " hours from Coordinated Universal Time (UTC).",
+ )
+ end_datetime: str = Field(
+ description=" The end datetime for the event in the following format: "
+ ' YYYY-MM-DDTHH:MM:SSΒ±hh:mm, where "T" separates the date and time '
+ " components, and the time zone offset is specified as Β±hh:mm. "
+ ' For example: "2023-06-09T10:30:00+03:00" represents June 9th, '
+ " 2023, at 10:30 AM in a time zone with a positive offset of 3 "
+ " hours from Coordinated Universal Time (UTC).",
+ )
+
+
+class O365SendEvent(O365BaseTool):
+ """Tool for sending calendar events in Office 365."""
+
+ name: str = "send_event"
+ description: str = (
+ "Use this tool to create and send an event with the provided event fields."
+ )
+ args_schema: Type[SendEventSchema] = SendEventSchema
+
+ def _run(
+ self,
+ body: str,
+ attendees: List[str],
+ subject: str,
+ start_datetime: str,
+ end_datetime: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ # Get calendar object
+ schedule = self.account.schedule()
+ calendar = schedule.get_default_calendar()
+
+ event = calendar.new_event()
+
+ event.body = body
+ event.subject = subject
+ event.start = dt.strptime(start_datetime, UTC_FORMAT)
+ event.end = dt.strptime(end_datetime, UTC_FORMAT)
+ for attendee in attendees:
+ event.attendees.add(attendee)
+
+ # TO-DO: Look into PytzUsageWarning
+ event.save()
+
+ output = "Event sent: " + str(event)
+ return output
diff --git a/libs/community/langchain_community/tools/office365/send_message.py b/libs/community/langchain_community/tools/office365/send_message.py
new file mode 100644
index 00000000000..cd7abef976b
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/send_message.py
@@ -0,0 +1,68 @@
+from typing import List, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.office365.base import O365BaseTool
+
+
+class SendMessageSchema(BaseModel):
+ """Input for SendMessageTool."""
+
+ body: str = Field(
+ ...,
+ description="The message body to be sent.",
+ )
+ to: List[str] = Field(
+ ...,
+ description="The list of recipients.",
+ )
+ subject: str = Field(
+ ...,
+ description="The subject of the message.",
+ )
+ cc: Optional[List[str]] = Field(
+ None,
+ description="The list of CC recipients.",
+ )
+ bcc: Optional[List[str]] = Field(
+ None,
+ description="The list of BCC recipients.",
+ )
+
+
+class O365SendMessage(O365BaseTool):
+ """Tool for sending an email in Office 365."""
+
+ name: str = "send_email"
+ description: str = (
+ "Use this tool to send an email with the provided message fields."
+ )
+ args_schema: Type[SendMessageSchema] = SendMessageSchema
+
+ def _run(
+ self,
+ body: str,
+ to: List[str],
+ subject: str,
+ cc: Optional[List[str]] = None,
+ bcc: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ # Get mailbox object
+ mailbox = self.account.mailbox()
+ message = mailbox.new_message()
+
+ # Assign message values
+ message.body = body
+ message.subject = subject
+ message.to.add(to)
+ if cc is not None:
+ message.cc.add(cc)
+ if bcc is not None:
+ message.bcc.add(bcc)
+
+ message.send()
+
+ output = "Message sent: " + str(message)
+ return output
diff --git a/libs/community/langchain_community/tools/office365/utils.py b/libs/community/langchain_community/tools/office365/utils.py
new file mode 100644
index 00000000000..127fb6dba1f
--- /dev/null
+++ b/libs/community/langchain_community/tools/office365/utils.py
@@ -0,0 +1,78 @@
+"""O365 tool utils."""
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from O365 import Account
+
+logger = logging.getLogger(__name__)
+
+
+def clean_body(body: str) -> str:
+ """Clean body of a message or event."""
+ try:
+ from bs4 import BeautifulSoup
+
+ try:
+ # Remove HTML
+ soup = BeautifulSoup(str(body), "html.parser")
+ body = soup.get_text()
+
+ # Remove return characters
+ body = "".join(body.splitlines())
+
+ # Remove extra spaces
+ body = " ".join(body.split())
+
+ return str(body)
+ except Exception:
+ return str(body)
+ except ImportError:
+ return str(body)
+
+
+def authenticate() -> Account:
+ """Authenticate using the Microsoft Graph API."""
+ try:
+ from O365 import Account
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import O365. Please install the package with `pip install O365`."
+ ) from e
+
+ if "CLIENT_ID" in os.environ and "CLIENT_SECRET" in os.environ:
+ client_id = os.environ["CLIENT_ID"]
+ client_secret = os.environ["CLIENT_SECRET"]
+ credentials = (client_id, client_secret)
+ else:
+ logger.error(
+ "Error: The CLIENT_ID and CLIENT_SECRET environment variables have not "
+ "been set. Visit the following link on how to acquire these authorization "
+ "tokens: https://learn.microsoft.com/en-us/graph/auth/"
+ )
+ return None
+
+ account = Account(credentials)
+
+ if account.is_authenticated is False:
+ if not account.authenticate(
+ scopes=[
+ "https://graph.microsoft.com/Mail.ReadWrite",
+ "https://graph.microsoft.com/Mail.Send",
+ "https://graph.microsoft.com/Calendars.ReadWrite",
+ "https://graph.microsoft.com/MailboxSettings.ReadWrite",
+ ]
+ ):
+ print("Error: Could not authenticate")
+ return None
+ else:
+ return account
+ else:
+ return account
+
+
+UTC_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
+"""UTC format for datetime objects."""
diff --git a/libs/langchain/tests/integration_tests/utilities/__init__.py b/libs/community/langchain_community/tools/openapi/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/utilities/__init__.py
rename to libs/community/langchain_community/tools/openapi/__init__.py
diff --git a/libs/langchain/tests/integration_tests/vectorstores/docarray/__init__.py b/libs/community/langchain_community/tools/openapi/utils/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/vectorstores/docarray/__init__.py
rename to libs/community/langchain_community/tools/openapi/utils/__init__.py
diff --git a/libs/community/langchain_community/tools/openapi/utils/api_models.py b/libs/community/langchain_community/tools/openapi/utils/api_models.py
new file mode 100644
index 00000000000..38ae4afca17
--- /dev/null
+++ b/libs/community/langchain_community/tools/openapi/utils/api_models.py
@@ -0,0 +1,631 @@
+"""Pydantic models for parsing an OpenAPI spec."""
+from __future__ import annotations
+
+import logging
+from enum import Enum
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
+
+logger = logging.getLogger(__name__)
+PRIMITIVE_TYPES = {
+ "integer": int,
+ "number": float,
+ "string": str,
+ "boolean": bool,
+ "array": List,
+ "object": Dict,
+ "null": None,
+}
+
+
+# See https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#parameterIn
+# for more info.
+class APIPropertyLocation(Enum):
+ """The location of the property."""
+
+ QUERY = "query"
+ PATH = "path"
+ HEADER = "header"
+ COOKIE = "cookie" # Not yet supported
+
+ @classmethod
+ def from_str(cls, location: str) -> "APIPropertyLocation":
+ """Parse an APIPropertyLocation."""
+ try:
+ return cls(location)
+ except ValueError:
+ raise ValueError(
+ f"Invalid APIPropertyLocation. Valid values are {cls.__members__}"
+ )
+
+
+_SUPPORTED_MEDIA_TYPES = ("application/json",)
+
+SUPPORTED_LOCATIONS = {
+ APIPropertyLocation.QUERY,
+ APIPropertyLocation.PATH,
+}
+INVALID_LOCATION_TEMPL = (
+ 'Unsupported APIPropertyLocation "{location}"'
+ " for parameter {name}. "
+ + f"Valid values are {[loc.value for loc in SUPPORTED_LOCATIONS]}"
+)
+
+SCHEMA_TYPE = Union[str, Type, tuple, None, Enum]
+
+
+class APIPropertyBase(BaseModel):
+ """Base model for an API property."""
+
+ # The name of the parameter is required and is case-sensitive.
+ # If "in" is "path", the "name" field must correspond to a template expression
+ # within the path field in the Paths Object.
+ # If "in" is "header" and the "name" field is "Accept", "Content-Type",
+ # or "Authorization", the parameter definition is ignored.
+ # For all other cases, the "name" corresponds to the parameter
+ # name used by the "in" property.
+ name: str = Field(alias="name")
+ """The name of the property."""
+
+ required: bool = Field(alias="required")
+ """Whether the property is required."""
+
+ type: SCHEMA_TYPE = Field(alias="type")
+ """The type of the property.
+
+ Either a primitive type, a component/parameter type,
+ or an array or 'object' (dict) of the above."""
+
+ default: Optional[Any] = Field(alias="default", default=None)
+ """The default value of the property."""
+
+ description: Optional[str] = Field(alias="description", default=None)
+ """The description of the property."""
+
+
+if TYPE_CHECKING:
+ from openapi_pydantic import (
+ MediaType,
+ Parameter,
+ RequestBody,
+ Schema,
+ )
+
+
+class APIProperty(APIPropertyBase):
+ """A model for a property in the query, path, header, or cookie params."""
+
+ location: APIPropertyLocation = Field(alias="location")
+ """The path/how it's being passed to the endpoint."""
+
+ @staticmethod
+ def _cast_schema_list_type(
+ schema: Schema,
+ ) -> Optional[Union[str, Tuple[str, ...]]]:
+ type_ = schema.type
+ if not isinstance(type_, list):
+ return type_
+ else:
+ return tuple(type_)
+
+ @staticmethod
+ def _get_schema_type_for_enum(parameter: Parameter, schema: Schema) -> Enum:
+ """Get the schema type when the parameter is an enum."""
+ param_name = f"{parameter.name}Enum"
+ return Enum(param_name, {str(v): v for v in schema.enum})
+
+ @staticmethod
+ def _get_schema_type_for_array(
+ schema: Schema,
+ ) -> Optional[Union[str, Tuple[str, ...]]]:
+ from openapi_pydantic import (
+ Reference,
+ Schema,
+ )
+
+ items = schema.items
+ if isinstance(items, Schema):
+ schema_type = APIProperty._cast_schema_list_type(items)
+ elif isinstance(items, Reference):
+ ref_name = items.ref.split("/")[-1]
+ schema_type = ref_name # TODO: Add ref definitions to make this valid
+ else:
+ raise ValueError(f"Unsupported array items: {items}")
+
+ if isinstance(schema_type, str):
+ # TODO: recurse
+ schema_type = (schema_type,)
+
+ return schema_type
+
+ @staticmethod
+ def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
+ if schema is None:
+ return None
+ schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
+ if schema_type == "array":
+ schema_type = APIProperty._get_schema_type_for_array(schema)
+ elif schema_type == "object":
+ # TODO: Resolve array and object types to components.
+ raise NotImplementedError("Objects not yet supported")
+ elif schema_type in PRIMITIVE_TYPES:
+ if schema.enum:
+ schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
+ else:
+ # Directly use the primitive type
+ pass
+ else:
+ raise NotImplementedError(f"Unsupported type: {schema_type}")
+
+ return schema_type
+
+ @staticmethod
+ def _validate_location(location: APIPropertyLocation, name: str) -> None:
+ if location not in SUPPORTED_LOCATIONS:
+ raise NotImplementedError(
+ INVALID_LOCATION_TEMPL.format(location=location, name=name)
+ )
+
+ @staticmethod
+ def _validate_content(content: Optional[Dict[str, MediaType]]) -> None:
+ if content:
+ raise ValueError(
+ "API Properties with media content not supported. "
+ "Media content only supported within APIRequestBodyProperty's"
+ )
+
+ @staticmethod
+ def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]:
+ from openapi_pydantic import (
+ Reference,
+ Schema,
+ )
+
+ schema = parameter.param_schema
+ if isinstance(schema, Reference):
+ schema = spec.get_referenced_schema(schema)
+ elif schema is None:
+ return None
+ elif not isinstance(schema, Schema):
+ raise ValueError(f"Error dereferencing schema: {schema}")
+
+ return schema
+
+ @staticmethod
+ def is_supported_location(location: str) -> bool:
+ """Return whether the provided location is supported."""
+ try:
+ return APIPropertyLocation.from_str(location) in SUPPORTED_LOCATIONS
+ except ValueError:
+ return False
+
+ @classmethod
+ def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
+ """Instantiate from an OpenAPI Parameter."""
+ location = APIPropertyLocation.from_str(parameter.param_in)
+ cls._validate_location(
+ location,
+ parameter.name,
+ )
+ cls._validate_content(parameter.content)
+ schema = cls._get_schema(parameter, spec)
+ schema_type = cls._get_schema_type(parameter, schema)
+ default_val = schema.default if schema is not None else None
+ return cls(
+ name=parameter.name,
+ location=location,
+ default=default_val,
+ description=parameter.description,
+ required=parameter.required,
+ type=schema_type,
+ )
+
+
+class APIRequestBodyProperty(APIPropertyBase):
+ """A model for a request body property."""
+
+ properties: List["APIRequestBodyProperty"] = Field(alias="properties")
+ """The sub-properties of the property."""
+
+ # This is useful for handling nested property cycles.
+ # We can define separate types in that case.
+ references_used: List[str] = Field(alias="references_used")
+ """The references used by the property."""
+
+ @classmethod
+ def _process_object_schema(
+ cls, schema: Schema, spec: OpenAPISpec, references_used: List[str]
+ ) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]:
+ from openapi_pydantic import (
+ Reference,
+ )
+
+ properties = []
+ required_props = schema.required or []
+ if schema.properties is None:
+ raise ValueError(
+ f"No properties found when processing object schema: {schema}"
+ )
+ for prop_name, prop_schema in schema.properties.items():
+ if isinstance(prop_schema, Reference):
+ ref_name = prop_schema.ref.split("/")[-1]
+ if ref_name not in references_used:
+ references_used.append(ref_name)
+ prop_schema = spec.get_referenced_schema(prop_schema)
+ else:
+ continue
+
+ properties.append(
+ cls.from_schema(
+ schema=prop_schema,
+ name=prop_name,
+ required=prop_name in required_props,
+ spec=spec,
+ references_used=references_used,
+ )
+ )
+ return schema.type, properties
+
+ @classmethod
+ def _process_array_schema(
+ cls,
+ schema: Schema,
+ name: str,
+ spec: OpenAPISpec,
+ references_used: List[str],
+ ) -> str:
+ from openapi_pydantic import Reference, Schema
+
+ items = schema.items
+ if items is not None:
+ if isinstance(items, Reference):
+ ref_name = items.ref.split("/")[-1]
+ if ref_name not in references_used:
+ references_used.append(ref_name)
+ items = spec.get_referenced_schema(items)
+ else:
+ pass
+ return f"Array<{ref_name}>"
+ else:
+ pass
+
+ if isinstance(items, Schema):
+ array_type = cls.from_schema(
+ schema=items,
+ name=f"{name}Item",
+ required=True, # TODO: Add required
+ spec=spec,
+ references_used=references_used,
+ )
+ return f"Array<{array_type.type}>"
+
+ return "array"
+
+ @classmethod
+ def from_schema(
+ cls,
+ schema: Schema,
+ name: str,
+ required: bool,
+ spec: OpenAPISpec,
+ references_used: Optional[List[str]] = None,
+ ) -> "APIRequestBodyProperty":
+ """Recursively populate from an OpenAPI Schema."""
+ if references_used is None:
+ references_used = []
+
+ schema_type = schema.type
+ properties: List[APIRequestBodyProperty] = []
+ if schema_type == "object" and schema.properties:
+ schema_type, properties = cls._process_object_schema(
+ schema, spec, references_used
+ )
+ elif schema_type == "array":
+ schema_type = cls._process_array_schema(schema, name, spec, references_used)
+ elif schema_type in PRIMITIVE_TYPES:
+ # Use the primitive type directly
+ pass
+ elif schema_type is None:
+ # No typing specified/parsed. Will map to 'any'
+ pass
+ else:
+ raise ValueError(f"Unsupported type: {schema_type}")
+
+ return cls(
+ name=name,
+ required=required,
+ type=schema_type,
+ default=schema.default,
+ description=schema.description,
+ properties=properties,
+ references_used=references_used,
+ )
+
+
+# class APIRequestBodyProperty(APIPropertyBase):
+class APIRequestBody(BaseModel):
+ """A model for a request body."""
+
+ description: Optional[str] = Field(alias="description")
+ """The description of the request body."""
+
+ properties: List[APIRequestBodyProperty] = Field(alias="properties")
+
+ # E.g., application/json - we only support JSON at the moment.
+ media_type: str = Field(alias="media_type")
+ """The media type of the request body."""
+
+ @classmethod
+ def _process_supported_media_type(
+ cls,
+ media_type_obj: MediaType,
+ spec: OpenAPISpec,
+ ) -> List[APIRequestBodyProperty]:
+ """Process the media type of the request body."""
+ from openapi_pydantic import Reference
+
+ references_used = []
+ schema = media_type_obj.media_type_schema
+ if isinstance(schema, Reference):
+ references_used.append(schema.ref.split("/")[-1])
+ schema = spec.get_referenced_schema(schema)
+ if schema is None:
+ raise ValueError(
+ f"Could not resolve schema for media type: {media_type_obj}"
+ )
+ api_request_body_properties = []
+ required_properties = schema.required or []
+ if schema.type == "object" and schema.properties:
+ for prop_name, prop_schema in schema.properties.items():
+ if isinstance(prop_schema, Reference):
+ prop_schema = spec.get_referenced_schema(prop_schema)
+
+ api_request_body_properties.append(
+ APIRequestBodyProperty.from_schema(
+ schema=prop_schema,
+ name=prop_name,
+ required=prop_name in required_properties,
+ spec=spec,
+ )
+ )
+ else:
+ api_request_body_properties.append(
+ APIRequestBodyProperty(
+ name="body",
+ required=True,
+ type=schema.type,
+ default=schema.default,
+ description=schema.description,
+ properties=[],
+ references_used=references_used,
+ )
+ )
+
+ return api_request_body_properties
+
+ @classmethod
+ def from_request_body(
+ cls, request_body: RequestBody, spec: OpenAPISpec
+ ) -> "APIRequestBody":
+ """Instantiate from an OpenAPI RequestBody."""
+ properties = []
+ for media_type, media_type_obj in request_body.content.items():
+ if media_type not in _SUPPORTED_MEDIA_TYPES:
+ continue
+ api_request_body_properties = cls._process_supported_media_type(
+ media_type_obj,
+ spec,
+ )
+ properties.extend(api_request_body_properties)
+
+ return cls(
+ description=request_body.description,
+ properties=properties,
+ media_type=media_type,
+ )
+
+
+# class APIRequestBodyProperty(APIPropertyBase):
+# class APIRequestBody(BaseModel):
+class APIOperation(BaseModel):
+ """A model for a single API operation."""
+
+ operation_id: str = Field(alias="operation_id")
+ """The unique identifier of the operation."""
+
+ description: Optional[str] = Field(alias="description")
+ """The description of the operation."""
+
+ base_url: str = Field(alias="base_url")
+ """The base URL of the operation."""
+
+ path: str = Field(alias="path")
+ """The path of the operation."""
+
+ method: HTTPVerb = Field(alias="method")
+ """The HTTP method of the operation."""
+
+ properties: Sequence[APIProperty] = Field(alias="properties")
+
+ # TODO: Add parse in used components to be able to specify what type of
+ # referenced object it is.
+ # """The properties of the operation."""
+ # components: Dict[str, BaseModel] = Field(alias="components")
+
+ request_body: Optional[APIRequestBody] = Field(alias="request_body")
+ """The request body of the operation."""
+
+ @staticmethod
+ def _get_properties_from_parameters(
+ parameters: List[Parameter], spec: OpenAPISpec
+ ) -> List[APIProperty]:
+ """Get the properties of the operation."""
+ properties = []
+ for param in parameters:
+ if APIProperty.is_supported_location(param.param_in):
+ properties.append(APIProperty.from_parameter(param, spec))
+ elif param.required:
+ raise ValueError(
+ INVALID_LOCATION_TEMPL.format(
+ location=param.param_in, name=param.name
+ )
+ )
+ else:
+ logger.warning(
+ INVALID_LOCATION_TEMPL.format(
+ location=param.param_in, name=param.name
+ )
+ + " Ignoring optional parameter"
+ )
+ pass
+ return properties
+
+ @classmethod
+ def from_openapi_url(
+ cls,
+ spec_url: str,
+ path: str,
+ method: str,
+ ) -> "APIOperation":
+ """Create an APIOperation from an OpenAPI URL."""
+ spec = OpenAPISpec.from_url(spec_url)
+ return cls.from_openapi_spec(spec, path, method)
+
+ @classmethod
+ def from_openapi_spec(
+ cls,
+ spec: OpenAPISpec,
+ path: str,
+ method: str,
+ ) -> "APIOperation":
+ """Create an APIOperation from an OpenAPI spec."""
+ operation = spec.get_operation(path, method)
+ parameters = spec.get_parameters_for_operation(operation)
+ properties = cls._get_properties_from_parameters(parameters, spec)
+ operation_id = OpenAPISpec.get_cleaned_operation_id(operation, path, method)
+ request_body = spec.get_request_body_for_operation(operation)
+ api_request_body = (
+ APIRequestBody.from_request_body(request_body, spec)
+ if request_body is not None
+ else None
+ )
+ description = operation.description or operation.summary
+ if not description and spec.paths is not None:
+ description = spec.paths[path].description or spec.paths[path].summary
+ return cls(
+ operation_id=operation_id,
+ description=description or "",
+ base_url=spec.base_url,
+ path=path,
+ method=method,
+ properties=properties,
+ request_body=api_request_body,
+ )
+
+ @staticmethod
+ def ts_type_from_python(type_: SCHEMA_TYPE) -> str:
+ if type_ is None:
+ # TODO: Handle Nones better. These often result when
+ # parsing specs that are < v3
+ return "any"
+ elif isinstance(type_, str):
+ return {
+ "str": "string",
+ "integer": "number",
+ "float": "number",
+ "date-time": "string",
+ }.get(type_, type_)
+ elif isinstance(type_, tuple):
+ return f"Array<{APIOperation.ts_type_from_python(type_[0])}>"
+ elif isinstance(type_, type) and issubclass(type_, Enum):
+ return " | ".join([f"'{e.value}'" for e in type_])
+ else:
+ return str(type_)
+
+ def _format_nested_properties(
+ self, properties: List[APIRequestBodyProperty], indent: int = 2
+ ) -> str:
+ """Format nested properties."""
+ formatted_props = []
+
+ for prop in properties:
+ prop_name = prop.name
+ prop_type = self.ts_type_from_python(prop.type)
+ prop_required = "" if prop.required else "?"
+ prop_desc = f"/* {prop.description} */" if prop.description else ""
+
+ if prop.properties:
+ nested_props = self._format_nested_properties(
+ prop.properties, indent + 2
+ )
+ prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
+
+ formatted_props.append(
+ f"{prop_desc}\n{' ' * indent}{prop_name}"
+ f"{prop_required}: {prop_type},"
+ )
+
+ return "\n".join(formatted_props)
+
+ def to_typescript(self) -> str:
+ """Get typescript string representation of the operation."""
+ operation_name = self.operation_id
+ params = []
+
+ if self.request_body:
+ formatted_request_body_props = self._format_nested_properties(
+ self.request_body.properties
+ )
+ params.append(formatted_request_body_props)
+
+ for prop in self.properties:
+ prop_name = prop.name
+ prop_type = self.ts_type_from_python(prop.type)
+ prop_required = "" if prop.required else "?"
+ prop_desc = f"/* {prop.description} */" if prop.description else ""
+ params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
+
+ formatted_params = "\n".join(params).strip()
+ description_str = f"/* {self.description} */" if self.description else ""
+ typescript_definition = f"""
+{description_str}
+type {operation_name} = (_: {{
+{formatted_params}
+}}) => any;
+"""
+ return typescript_definition.strip()
+
+ @property
+ def query_params(self) -> List[str]:
+ return [
+ property.name
+ for property in self.properties
+ if property.location == APIPropertyLocation.QUERY
+ ]
+
+ @property
+ def path_params(self) -> List[str]:
+ return [
+ property.name
+ for property in self.properties
+ if property.location == APIPropertyLocation.PATH
+ ]
+
+ @property
+ def body_params(self) -> List[str]:
+ if self.request_body is None:
+ return []
+ return [prop.name for prop in self.request_body.properties]
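As a rough sketch of how these models fit together, an APIOperation can be built straight from a spec URL and rendered as a TypeScript-style signature for prompting; the spec URL, path, and method below are hypothetical:

from langchain_community.tools.openapi.utils.api_models import APIOperation

# Hypothetical spec location and endpoint, for illustration only.
operation = APIOperation.from_openapi_url(
    "https://example.com/openapi.json",
    path="/pets",
    method="get",
)
print(operation.operation_id)
print(operation.query_params)
print(operation.to_typescript())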
diff --git a/libs/community/langchain_community/tools/openapi/utils/openapi_utils.py b/libs/community/langchain_community/tools/openapi/utils/openapi_utils.py
new file mode 100644
index 00000000000..e958d19d838
--- /dev/null
+++ b/libs/community/langchain_community/tools/openapi/utils/openapi_utils.py
@@ -0,0 +1,4 @@
+"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
+from langchain_community.utilities.openapi import HTTPVerb, OpenAPISpec
+
+__all__ = ["HTTPVerb", "OpenAPISpec"]
diff --git a/libs/community/langchain_community/tools/openweathermap/__init__.py b/libs/community/langchain_community/tools/openweathermap/__init__.py
new file mode 100644
index 00000000000..b2ff9b36515
--- /dev/null
+++ b/libs/community/langchain_community/tools/openweathermap/__init__.py
@@ -0,0 +1,8 @@
+"""OpenWeatherMap API toolkit."""
+
+
+from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun
+
+__all__ = [
+ "OpenWeatherMapQueryRun",
+]
diff --git a/libs/community/langchain_community/tools/openweathermap/tool.py b/libs/community/langchain_community/tools/openweathermap/tool.py
new file mode 100644
index 00000000000..6060321047c
--- /dev/null
+++ b/libs/community/langchain_community/tools/openweathermap/tool.py
@@ -0,0 +1,30 @@
+"""Tool for the OpenWeatherMap API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper
+
+
+class OpenWeatherMapQueryRun(BaseTool):
+ """Tool that queries the OpenWeatherMap API."""
+
+ api_wrapper: OpenWeatherMapAPIWrapper = Field(
+ default_factory=OpenWeatherMapAPIWrapper
+ )
+
+ name: str = "OpenWeatherMap"
+ description: str = (
+ "A wrapper around OpenWeatherMap API. "
+ "Useful for fetching current weather information for a specified location. "
+ "Input should be a location string (e.g. London,GB)."
+ )
+
+ def _run(
+ self, location: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Use the OpenWeatherMap tool."""
+ return self.api_wrapper.run(location)
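A minimal sketch of using the tool; it assumes the OPENWEATHERMAP_API_KEY environment variable is set and the pyowm dependency is installed, since the default OpenWeatherMapAPIWrapper is constructed from the environment:

from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun

weather = OpenWeatherMapQueryRun()  # reads OPENWEATHERMAP_API_KEY from the env
print(weather.run("London,GB"))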
diff --git a/libs/community/langchain_community/tools/playwright/__init__.py b/libs/community/langchain_community/tools/playwright/__init__.py
new file mode 100644
index 00000000000..f69ff8025d3
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/__init__.py
@@ -0,0 +1,21 @@
+"""Browser tools and toolkit."""
+
+from langchain_community.tools.playwright.click import ClickTool
+from langchain_community.tools.playwright.current_page import CurrentWebPageTool
+from langchain_community.tools.playwright.extract_hyperlinks import (
+ ExtractHyperlinksTool,
+)
+from langchain_community.tools.playwright.extract_text import ExtractTextTool
+from langchain_community.tools.playwright.get_elements import GetElementsTool
+from langchain_community.tools.playwright.navigate import NavigateTool
+from langchain_community.tools.playwright.navigate_back import NavigateBackTool
+
+__all__ = [
+ "NavigateTool",
+ "NavigateBackTool",
+ "ExtractTextTool",
+ "ExtractHyperlinksTool",
+ "GetElementsTool",
+ "ClickTool",
+ "CurrentWebPageTool",
+]
diff --git a/libs/community/langchain_community/tools/playwright/base.py b/libs/community/langchain_community/tools/playwright/base.py
new file mode 100644
index 00000000000..d906bfea51d
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/base.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional, Tuple, Type
+
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.tools import BaseTool
+
+if TYPE_CHECKING:
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.sync_api import Browser as SyncBrowser
+else:
+ try:
+ # We do this so pydantic can resolve the types when instantiating
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.sync_api import Browser as SyncBrowser
+ except ImportError:
+ pass
+
+
+def lazy_import_playwright_browsers() -> Tuple[Type[AsyncBrowser], Type[SyncBrowser]]:
+ """
+ Lazy import playwright browsers.
+
+ Returns:
+ Tuple[Type[AsyncBrowser], Type[SyncBrowser]]:
+ AsyncBrowser and SyncBrowser classes.
+ """
+ try:
+ from playwright.async_api import Browser as AsyncBrowser # noqa: F401
+ from playwright.sync_api import Browser as SyncBrowser # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "The 'playwright' package is required to use the playwright tools."
+ " Please install it with 'pip install playwright'."
+ )
+ return AsyncBrowser, SyncBrowser
+
+
+class BaseBrowserTool(BaseTool):
+ """Base class for browser tools."""
+
+ sync_browser: Optional["SyncBrowser"] = None
+ async_browser: Optional["AsyncBrowser"] = None
+
+ @root_validator
+ def validate_browser_provided(cls, values: dict) -> dict:
+ """Check that the arguments are valid."""
+ lazy_import_playwright_browsers()
+ if values.get("async_browser") is None and values.get("sync_browser") is None:
+ raise ValueError("Either async_browser or sync_browser must be specified.")
+ return values
+
+ @classmethod
+ def from_browser(
+ cls,
+ sync_browser: Optional[SyncBrowser] = None,
+ async_browser: Optional[AsyncBrowser] = None,
+ ) -> BaseBrowserTool:
+ """Instantiate the tool."""
+ lazy_import_playwright_browsers()
+ return cls(sync_browser=sync_browser, async_browser=async_browser)
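The root validator above means every browser tool must be given at least one browser. A brief sketch of the intended construction pattern, using the NavigateTool and browser helper defined later in this diff:

from langchain_community.tools.playwright.navigate import NavigateTool
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

browser = create_sync_playwright_browser()  # requires `playwright install`
tool = NavigateTool.from_browser(sync_browser=browser)
# Constructing a tool with neither browser raises the ValueError from the validator.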
diff --git a/libs/community/langchain_community/tools/playwright/click.py b/libs/community/langchain_community/tools/playwright/click.py
new file mode 100644
index 00000000000..5cf5cdda148
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/click.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+
+class ClickToolInput(BaseModel):
+ """Input for ClickTool."""
+
+ selector: str = Field(..., description="CSS selector for the element to click")
+
+
+class ClickTool(BaseBrowserTool):
+ """Tool for clicking on an element with the given CSS selector."""
+
+ name: str = "click_element"
+ description: str = "Click on an element with the given CSS selector"
+ args_schema: Type[BaseModel] = ClickToolInput
+
+ visible_only: bool = True
+ """Whether to consider only visible elements."""
+ playwright_strict: bool = False
+ """Whether to employ Playwright's strict mode when clicking on elements."""
+ playwright_timeout: float = 1_000
+ """Timeout (in ms) for Playwright to wait for element to be ready."""
+
+ def _selector_effective(self, selector: str) -> str:
+ if not self.visible_only:
+ return selector
+ return f"{selector} >> visible=1"
+
+ def _run(
+ self,
+ selector: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ # Navigate to the desired webpage before using this tool
+ selector_effective = self._selector_effective(selector=selector)
+ from playwright.sync_api import TimeoutError as PlaywrightTimeoutError
+
+ try:
+ page.click(
+ selector_effective,
+ strict=self.playwright_strict,
+ timeout=self.playwright_timeout,
+ )
+ except PlaywrightTimeoutError:
+ return f"Unable to click on element '{selector}'"
+ return f"Clicked element '{selector}'"
+
+ async def _arun(
+ self,
+ selector: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ # Navigate to the desired webpage before using this tool
+ selector_effective = self._selector_effective(selector=selector)
+ from playwright.async_api import TimeoutError as PlaywrightTimeoutError
+
+ try:
+ await page.click(
+ selector_effective,
+ strict=self.playwright_strict,
+ timeout=self.playwright_timeout,
+ )
+ except PlaywrightTimeoutError:
+ return f"Unable to click on element '{selector}'"
+ return f"Clicked element '{selector}'"
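A brief sketch of tuning the knobs documented above; the selector is purely illustrative:

from langchain_community.tools.playwright.click import ClickTool
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

browser = create_sync_playwright_browser()
click = ClickTool(
    sync_browser=browser,
    visible_only=False,        # also match hidden elements
    playwright_timeout=5_000,  # wait up to 5 seconds for the element
)
print(click.run({"selector": "#submit-button"}))  # placeholder selector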
diff --git a/libs/community/langchain_community/tools/playwright/current_page.py b/libs/community/langchain_community/tools/playwright/current_page.py
new file mode 100644
index 00000000000..9ea5e0e1f9d
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/current_page.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+
+class CurrentWebPageTool(BaseBrowserTool):
+ """Tool for getting the URL of the current webpage."""
+
+ name: str = "current_webpage"
+ description: str = "Returns the URL of the current page"
+ args_schema: Type[BaseModel] = BaseModel
+
+ def _run(
+ self,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ return str(page.url)
+
+ async def _arun(
+ self,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ return str(page.url)
diff --git a/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py b/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py
new file mode 100644
index 00000000000..9c6f64911cc
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+if TYPE_CHECKING:
+ pass
+
+
+class ExtractHyperlinksToolInput(BaseModel):
+ """Input for ExtractHyperlinksTool."""
+
+ absolute_urls: bool = Field(
+ default=False,
+ description="Return absolute URLs instead of relative URLs",
+ )
+
+
+class ExtractHyperlinksTool(BaseBrowserTool):
+ """Extract all hyperlinks on the page."""
+
+ name: str = "extract_hyperlinks"
+ description: str = "Extract all hyperlinks on the current webpage"
+ args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
+
+ @root_validator
+ def check_bs_import(cls, values: dict) -> dict:
+ """Check that the arguments are valid."""
+ try:
+ from bs4 import BeautifulSoup # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "The 'beautifulsoup4' package is required to use this tool."
+ " Please install it with 'pip install beautifulsoup4'."
+ )
+ return values
+
+ @staticmethod
+ def scrape_page(page: Any, html_content: str, absolute_urls: bool) -> str:
+ from urllib.parse import urljoin
+
+ from bs4 import BeautifulSoup
+
+ # Parse the HTML content with BeautifulSoup
+ soup = BeautifulSoup(html_content, "lxml")
+
+ # Find all the anchor elements and extract their href attributes
+ anchors = soup.find_all("a")
+ if absolute_urls:
+ base_url = page.url
+ links = [urljoin(base_url, anchor.get("href", "")) for anchor in anchors]
+ else:
+ links = [anchor.get("href", "") for anchor in anchors]
+ # Return the list of links as a JSON string
+ return json.dumps(links)
+
+ def _run(
+ self,
+ absolute_urls: bool = False,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ html_content = page.content()
+ return self.scrape_page(page, html_content, absolute_urls)
+
+ async def _arun(
+ self,
+ absolute_urls: bool = False,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ html_content = await page.content()
+ return self.scrape_page(page, html_content, absolute_urls)
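A short sketch of the expected output: a JSON-encoded list of hrefs, optionally resolved against the current page URL. It assumes a page has already been opened, for example with the navigate tool:

import json

from langchain_community.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

browser = create_sync_playwright_browser()
extract_links = ExtractHyperlinksTool.from_browser(sync_browser=browser)
links = json.loads(extract_links.run({"absolute_urls": True}))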
diff --git a/libs/community/langchain_community/tools/playwright/extract_text.py b/libs/community/langchain_community/tools/playwright/extract_text.py
new file mode 100644
index 00000000000..4f01ea925d0
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/extract_text.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+
+class ExtractTextTool(BaseBrowserTool):
+ """Tool for extracting all the text on the current webpage."""
+
+ name: str = "extract_text"
+ description: str = "Extract all the text on the current webpage"
+ args_schema: Type[BaseModel] = BaseModel
+
+ @root_validator
+ def check_bs_import(cls, values: dict) -> dict:
+ """Check that the arguments are valid."""
+ try:
+ from bs4 import BeautifulSoup # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "The 'beautifulsoup4' package is required to use this tool."
+ " Please install it with 'pip install beautifulsoup4'."
+ )
+ return values
+
+ def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+ """Use the tool."""
+ # Use Beautiful Soup since it's faster than looping through the elements
+ from bs4 import BeautifulSoup
+
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+
+ page = get_current_page(self.sync_browser)
+ html_content = page.content()
+
+ # Parse the HTML content with BeautifulSoup
+ soup = BeautifulSoup(html_content, "lxml")
+
+ return " ".join(text for text in soup.stripped_strings)
+
+ async def _arun(
+ self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ # Use Beautiful Soup since it's faster than looping through the elements
+ from bs4 import BeautifulSoup
+
+ page = await aget_current_page(self.async_browser)
+ html_content = await page.content()
+
+ # Parse the HTML content with BeautifulSoup
+ soup = BeautifulSoup(html_content, "lxml")
+
+ return " ".join(text for text in soup.stripped_strings)
diff --git a/libs/community/langchain_community/tools/playwright/get_elements.py b/libs/community/langchain_community/tools/playwright/get_elements.py
new file mode 100644
index 00000000000..3b88529fc75
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/get_elements.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, List, Optional, Sequence, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+if TYPE_CHECKING:
+ from playwright.async_api import Page as AsyncPage
+ from playwright.sync_api import Page as SyncPage
+
+
+class GetElementsToolInput(BaseModel):
+ """Input for GetElementsTool."""
+
+ selector: str = Field(
+ ...,
+ description="CSS selector, such as '*', 'div', 'p', 'a', #id, .classname",
+ )
+ attributes: List[str] = Field(
+ default_factory=lambda: ["innerText"],
+ description="Set of attributes to retrieve for each element",
+ )
+
+
+async def _aget_elements(
+ page: AsyncPage, selector: str, attributes: Sequence[str]
+) -> List[dict]:
+ """Get elements matching the given CSS selector."""
+ elements = await page.query_selector_all(selector)
+ results = []
+ for element in elements:
+ result = {}
+ for attribute in attributes:
+ if attribute == "innerText":
+ val: Optional[str] = await element.inner_text()
+ else:
+ val = await element.get_attribute(attribute)
+ if val is not None and val.strip() != "":
+ result[attribute] = val
+ if result:
+ results.append(result)
+ return results
+
+
+def _get_elements(
+ page: SyncPage, selector: str, attributes: Sequence[str]
+) -> List[dict]:
+ """Get elements matching the given CSS selector."""
+ elements = page.query_selector_all(selector)
+ results = []
+ for element in elements:
+ result = {}
+ for attribute in attributes:
+ if attribute == "innerText":
+ val: Optional[str] = element.inner_text()
+ else:
+ val = element.get_attribute(attribute)
+ if val is not None and val.strip() != "":
+ result[attribute] = val
+ if result:
+ results.append(result)
+ return results
+
+
+class GetElementsTool(BaseBrowserTool):
+ """Tool for getting elements in the current web page matching a CSS selector."""
+
+ name: str = "get_elements"
+ description: str = (
+ "Retrieve elements in the current web page matching the given CSS selector"
+ )
+ args_schema: Type[BaseModel] = GetElementsToolInput
+
+ def _run(
+ self,
+ selector: str,
+ attributes: Sequence[str] = ["innerText"],
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ # Navigate to the desired webpage before using this tool
+ results = _get_elements(page, selector, attributes)
+ return json.dumps(results, ensure_ascii=False)
+
+ async def _arun(
+ self,
+ selector: str,
+ attributes: Sequence[str] = ["innerText"],
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ # Navigate to the desired webpage before using this tool
+ results = await _aget_elements(page, selector, attributes)
+ return json.dumps(results, ensure_ascii=False)
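A small sketch of the JSON shape this tool returns: one dict per matched element, keyed by the requested attributes. The selector and attributes are illustrative:

from langchain_community.tools.playwright.get_elements import GetElementsTool
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

browser = create_sync_playwright_browser()
get_elements = GetElementsTool.from_browser(sync_browser=browser)
# After navigating to a page, returns e.g. [{"innerText": "Docs", "href": "/docs"}, ...]
print(get_elements.run({"selector": "a", "attributes": ["innerText", "href"]}))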
diff --git a/libs/community/langchain_community/tools/playwright/navigate.py b/libs/community/langchain_community/tools/playwright/navigate.py
new file mode 100644
index 00000000000..82f6349ff06
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/navigate.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Optional, Type
+from urllib.parse import urlparse
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+
+class NavigateToolInput(BaseModel):
+ """Input for NavigateToolInput."""
+
+ url: str = Field(..., description="url to navigate to")
+
+ @validator("url")
+ def validate_url_scheme(cls, url: str) -> str:
+ """Check that the URL scheme is valid."""
+ parsed_url = urlparse(url)
+ if parsed_url.scheme not in ("http", "https"):
+ raise ValueError("URL scheme must be 'http' or 'https'")
+ return url
+
+
+class NavigateTool(BaseBrowserTool):
+ """Tool for navigating a browser to a URL.
+
+ **Security Note**: This tool provides code to control web-browser navigation.
+
+ This tool can navigate to any URL, including internal network URLs, and
+ URLs exposed on the server itself.
+
+ However, if exposing this tool to end-users, consider limiting network
+ access to the server that hosts the agent.
+
+ By default, the URL scheme has been limited to 'http' and 'https' to
+ prevent navigation to local file system URLs (or other schemes).
+
+ If access to the local file system is required, consider creating a custom
+ tool or providing a custom args_schema that allows the desired URL schemes.
+
+ See https://python.langchain.com/docs/security for more information.
+ """
+
+ name: str = "navigate_browser"
+ description: str = "Navigate a browser to the specified URL"
+ args_schema: Type[BaseModel] = NavigateToolInput
+
+ def _run(
+ self,
+ url: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ response = page.goto(url)
+ status = response.status if response else "unknown"
+ return f"Navigating to {url} returned status code {status}"
+
+ async def _arun(
+ self,
+ url: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ response = await page.goto(url)
+ status = response.status if response else "unknown"
+ return f"Navigating to {url} returned status code {status}"
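A minimal sketch of the navigate tool; the URL is an example, and a non-http(s) scheme such as file:// would be rejected by the validator above:

from langchain_community.tools.playwright.navigate import NavigateTool
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

browser = create_sync_playwright_browser()
navigate = NavigateTool.from_browser(sync_browser=browser)
print(navigate.run({"url": "https://example.com"}))  # example URL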
diff --git a/libs/community/langchain_community/tools/playwright/navigate_back.py b/libs/community/langchain_community/tools/playwright/navigate_back.py
new file mode 100644
index 00000000000..5988fa7fe89
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/navigate_back.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_community.tools.playwright.base import BaseBrowserTool
+from langchain_community.tools.playwright.utils import (
+ aget_current_page,
+ get_current_page,
+)
+
+
+class NavigateBackTool(BaseBrowserTool):
+ """Navigate back to the previous page in the browser history."""
+
+ name: str = "previous_webpage"
+ description: str = "Navigate back to the previous page in the browser history"
+ args_schema: Type[BaseModel] = BaseModel
+
+ def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+ """Use the tool."""
+ if self.sync_browser is None:
+ raise ValueError(f"Synchronous browser not provided to {self.name}")
+ page = get_current_page(self.sync_browser)
+ response = page.go_back()
+
+ if response:
+ return (
+ f"Navigated back to the previous page with URL '{response.url}'."
+ f" Status code {response.status}"
+ )
+ else:
+ return "Unable to navigate back; no previous page in the history"
+
+ async def _arun(
+ self,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ if self.async_browser is None:
+ raise ValueError(f"Asynchronous browser not provided to {self.name}")
+ page = await aget_current_page(self.async_browser)
+ response = await page.go_back()
+
+ if response:
+ return (
+ f"Navigated back to the previous page with URL '{response.url}'."
+ f" Status code {response.status}"
+ )
+ else:
+ return "Unable to navigate back; no previous page in the history"
diff --git a/libs/community/langchain_community/tools/playwright/utils.py b/libs/community/langchain_community/tools/playwright/utils.py
new file mode 100644
index 00000000000..692288fdde3
--- /dev/null
+++ b/libs/community/langchain_community/tools/playwright/utils.py
@@ -0,0 +1,104 @@
+"""Utilities for the Playwright browser tools."""
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, TypeVar
+
+if TYPE_CHECKING:
+ from playwright.async_api import Browser as AsyncBrowser
+ from playwright.async_api import Page as AsyncPage
+ from playwright.sync_api import Browser as SyncBrowser
+ from playwright.sync_api import Page as SyncPage
+
+
+async def aget_current_page(browser: AsyncBrowser) -> AsyncPage:
+ """
+ Asynchronously get the current page of the browser.
+
+ Args:
+ browser: The browser (AsyncBrowser) to get the current page from.
+
+ Returns:
+ AsyncPage: The current page.
+ """
+ if not browser.contexts:
+ context = await browser.new_context()
+ return await context.new_page()
+ context = browser.contexts[0] # Assuming you're using the default browser context
+ if not context.pages:
+ return await context.new_page()
+ # Assuming the last page in the list is the active one
+ return context.pages[-1]
+
+
+def get_current_page(browser: SyncBrowser) -> SyncPage:
+ """
+ Get the current page of the browser.
+
+ Args:
+ browser: The browser to get the current page from.
+
+ Returns:
+ SyncPage: The current page.
+ """
+ if not browser.contexts:
+ context = browser.new_context()
+ return context.new_page()
+ context = browser.contexts[0] # Assuming you're using the default browser context
+ if not context.pages:
+ return context.new_page()
+ # Assuming the last page in the list is the active one
+ return context.pages[-1]
+
+
+def create_async_playwright_browser(
+ headless: bool = True, args: Optional[List[str]] = None
+) -> AsyncBrowser:
+ """
+ Create an async playwright browser.
+
+ Args:
+ headless: Whether to run the browser in headless mode. Defaults to True.
+ args: arguments to pass to browser.chromium.launch
+
+ Returns:
+ AsyncBrowser: The playwright browser.
+ """
+ from playwright.async_api import async_playwright
+
+ browser = run_async(async_playwright().start())
+ return run_async(browser.chromium.launch(headless=headless, args=args))
+
+
+def create_sync_playwright_browser(
+ headless: bool = True, args: Optional[List[str]] = None
+) -> SyncBrowser:
+ """
+ Create a playwright browser.
+
+ Args:
+ headless: Whether to run the browser in headless mode. Defaults to True.
+ args: arguments to pass to browser.chromium.launch
+
+ Returns:
+ SyncBrowser: The playwright browser.
+ """
+ from playwright.sync_api import sync_playwright
+
+ browser = sync_playwright().start()
+ return browser.chromium.launch(headless=headless, args=args)
+
+
+T = TypeVar("T")
+
+
+def run_async(coro: Coroutine[Any, Any, T]) -> T:
+ """Run an async coroutine.
+
+ Args:
+ coro: The coroutine to run. Coroutine[Any, Any, T]
+
+ Returns:
+ T: The result of the coroutine.
+ """
+ event_loop = asyncio.get_event_loop()
+ return event_loop.run_until_complete(coro)
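A short sketch tying the helpers above together; it assumes the playwright package is installed and browser binaries have been provisioned with `playwright install`:

from langchain_community.tools.playwright.utils import (
    create_sync_playwright_browser,
    get_current_page,
)

browser = create_sync_playwright_browser(headless=True)
page = get_current_page(browser)  # creates a context and page on first use
page.goto("https://example.com")  # example URL
print(page.url)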
diff --git a/libs/community/langchain_community/tools/plugin.py b/libs/community/langchain_community/tools/plugin.py
new file mode 100644
index 00000000000..0966ed5bede
--- /dev/null
+++ b/libs/community/langchain_community/tools/plugin.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import json
+from typing import Optional, Type
+
+import requests
+import yaml
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import BaseTool
+
+
+class ApiConfig(BaseModel):
+ """API Configuration."""
+
+ type: str
+ url: str
+ has_user_authentication: Optional[bool] = False
+
+
+class AIPlugin(BaseModel):
+ """AI Plugin Definition."""
+
+ schema_version: str
+ name_for_model: str
+ name_for_human: str
+ description_for_model: str
+ description_for_human: str
+ auth: Optional[dict] = None
+ api: ApiConfig
+ logo_url: Optional[str]
+ contact_email: Optional[str]
+ legal_info_url: Optional[str]
+
+ @classmethod
+ def from_url(cls, url: str) -> AIPlugin:
+ """Instantiate AIPlugin from a URL."""
+ response = requests.get(url).json()
+ return cls(**response)
+
+
+def marshal_spec(txt: str) -> dict:
+ """Convert the yaml or json serialized spec to a dict.
+
+ Args:
+ txt: The yaml or json serialized spec.
+
+ Returns:
+ dict: The spec as a dict.
+ """
+ try:
+ return json.loads(txt)
+ except json.JSONDecodeError:
+ return yaml.safe_load(txt)
+
+
+class AIPluginToolSchema(BaseModel):
+ """Schema for AIPluginTool."""
+
+ tool_input: Optional[str] = ""
+
+
+class AIPluginTool(BaseTool):
+ """Tool for getting the OpenAPI spec for an AI Plugin."""
+
+ plugin: AIPlugin
+ api_spec: str
+ args_schema: Type[AIPluginToolSchema] = AIPluginToolSchema
+
+ @classmethod
+ def from_plugin_url(cls, url: str) -> AIPluginTool:
+ plugin = AIPlugin.from_url(url)
+ description = (
+ f"Call this tool to get the OpenAPI spec (and usage guide) "
+ f"for interacting with the {plugin.name_for_human} API. "
+ f"You should only call this ONCE! What is the "
+ f"{plugin.name_for_human} API useful for? "
+ ) + plugin.description_for_human
+ open_api_spec_str = requests.get(plugin.api.url).text
+ open_api_spec = marshal_spec(open_api_spec_str)
+ api_spec = (
+ f"Usage Guide: {plugin.description_for_model}\n\n"
+ f"OpenAPI Spec: {open_api_spec}"
+ )
+
+ return cls(
+ name=plugin.name_for_model,
+ description=description,
+ plugin=plugin,
+ api_spec=api_spec,
+ )
+
+ def _run(
+ self,
+ tool_input: Optional[str] = "",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_spec
+
+ async def _arun(
+ self,
+ tool_input: Optional[str] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ return self.api_spec
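A minimal sketch of building the tool from a plugin manifest; the URL is a hypothetical example of the conventional /.well-known/ai-plugin.json location:

from langchain_community.tools.plugin import AIPluginTool

tool = AIPluginTool.from_plugin_url(
    "https://example.com/.well-known/ai-plugin.json"  # hypothetical manifest URL
)
print(tool.name)     # name_for_model from the manifest
print(tool.run(""))  # usage guide plus the fetched OpenAPI spec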
diff --git a/libs/community/langchain_community/tools/powerbi/__init__.py b/libs/community/langchain_community/tools/powerbi/__init__.py
new file mode 100644
index 00000000000..3ecc25a12f5
--- /dev/null
+++ b/libs/community/langchain_community/tools/powerbi/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with a PowerBI dataset."""
diff --git a/libs/community/langchain_community/tools/powerbi/prompt.py b/libs/community/langchain_community/tools/powerbi/prompt.py
new file mode 100644
index 00000000000..caf32756aca
--- /dev/null
+++ b/libs/community/langchain_community/tools/powerbi/prompt.py
@@ -0,0 +1,70 @@
+# flake8: noqa
+QUESTION_TO_QUERY_BASE = """
+Answer the question below with a DAX query that can be sent to Power BI. DAX queries have a simple syntax comprised of just one required keyword, EVALUATE, and several optional keywords: ORDER BY, START AT, DEFINE, MEASURE, VAR, TABLE, and COLUMN. Each keyword defines a statement used for the duration of the query. Any time < or > are used in the text below it means that those values need to be replaced by table, columns or other things. If the question is not something you can answer with a DAX query, reply with "I cannot answer this" and the question will be escalated to a human.
+
+Some DAX functions return a table instead of a scalar, and must be wrapped in a function that evaluates the table and returns a scalar; unless the table is a single column, single row table, then it is treated as a scalar value. Most DAX functions require one or more arguments, which can include tables, columns, expressions, and values. However, some functions, such as PI, do not require any arguments, but always require parentheses to indicate the null argument. For example, you must always type PI(), not PI. You can also nest functions within other functions.
+
+Some commonly used functions are:
+EVALUATE <table> - At the most basic level, a DAX query is an EVALUATE statement containing a table expression. At least one EVALUATE statement is required, however, a query can contain any number of EVALUATE statements.
+EVALUATE <table> ORDER BY <expression> ASC or DESC - The optional ORDER BY keyword defines one or more expressions used to sort query results. Any expression that can be evaluated for each row of the result is valid.
+EVALUATE <table> ORDER BY <expression> ASC or DESC START AT <value> or <parameter> - The optional START AT keyword is used inside an ORDER BY clause. It defines the value at which the query results begin.
+DEFINE MEASURE | VAR; EVALUATE <table> - The optional DEFINE keyword introduces one or more calculated entity definitions that exist only for the duration of the query. Definitions precede the EVALUATE statement and are valid for all EVALUATE statements in the query. Definitions can be variables, measures, tables, and columns. Definitions can reference other definitions that appear before or after the current definition. At least one definition is required if the DEFINE keyword is included in a query.
+MEASURE <table name>[<measure name>] = <scalar expression> - Introduces a measure definition in a DEFINE statement of a DAX query.
+VAR <name> = <expression> - Stores the result of an expression as a named variable, which can then be passed as an argument to other measure expressions. Once resultant values have been calculated for a variable expression, those values do not change, even if the variable is referenced in another expression.
+
+FILTER(<table>,<filter>) - Returns a table that represents a subset of another table or expression, where <filter> is a Boolean expression that is to be evaluated for each row of the table. For example, [Amount] > 0 or [Region] = "France"
+ROW(<name>, <expression>) - Returns a table with a single row containing values that result from the expressions given to each column.
+TOPN(<N>, <table>, <orderBy_expression>, <order>) - Returns a table with the top N rows from the specified table, sorted by the specified expression, in the order specified by 0 for descending, 1 for ascending, the default is 0. Multiple OrderBy_Expressions and Order pairs can be given, separated by a comma.
+DISTINCT(<column>) - Returns a one-column table that contains the distinct values from the specified column. In other words, duplicate values are removed and only unique values are returned. This function cannot be used to return values into a cell or column on a worksheet; rather, you nest the DISTINCT function within a formula, to get a list of distinct values that can be passed to another function and then counted, summed, or used for other operations.
+DISTINCT(<table>) - Returns a table by removing duplicate rows from another table or expression.
+
+Aggregation functions, names with an A in them, handle booleans and empty strings in appropriate ways, while the same function without an A only uses the numeric values in a column. Function names with an X in them can include an expression as an argument; this will be evaluated for each row in the table and the result will be used in the regular function calculation. These are the functions:
+COUNT(<column>), COUNTA(<column>), COUNTX(<table>,<expression>), COUNTAX(<table>,<expression>), COUNTROWS([<table>]), COUNTBLANK(<column>), DISTINCTCOUNT(<column>), DISTINCTCOUNTNOBLANK (<column>) - these are all variations of count functions.
+AVERAGE(<column>), AVERAGEA(<column>), AVERAGEX(<table>,<expression>) - these are all variations of average functions.
+MAX(<column>), MAXA(<column>), MAXX(<table>,<expression>) - these are all variations of max functions.
+MIN(<column>), MINA(<column>), MINX(<table>,<expression>) - these are all variations of min functions.
+PRODUCT(<column>), PRODUCTX(<table>,<expression>) - these are all variations of product functions.
+SUM(<column>), SUMX(<table>,<expression>) - these are all variations of sum functions.
+
+Date and time functions:
+DATE(year, month, day) - Returns a date value that represents the specified year, month, and day.
+DATEDIFF(date1, date2, <interval>) - Returns the difference between two date values, in the specified interval, which can be SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR.
+DATEVALUE(<date_text>) - Returns a date value that represents the specified date.
+YEAR(<date>), QUARTER(<date>), MONTH(<date>), DAY(<date>), HOUR(<date>), MINUTE(<date>), SECOND(<date>) - Returns the part of the date for the specified date.
+
+Finally, make sure to escape double quotes with a single backslash, and make sure that only table names have single quotes around them, while names of measures or the values of columns that you want to compare against are in escaped double quotes. Newlines are not necessary and can be skipped. The queries are serialized as JSON and so have to be compliant with JSON syntax. Sometimes you will get a question, a DAX query and an error; in that case you need to rewrite the DAX query to get the correct answer.
+
+The following tables exist: {tables}
+
+and the schema's for some are given here:
+{schemas}
+
+Examples:
+{examples}
+"""
+
+USER_INPUT = """
+Question: {tool_input}
+DAX:
+"""
+
+SINGLE_QUESTION_TO_QUERY = f"{QUESTION_TO_QUERY_BASE}{USER_INPUT}"
+
+DEFAULT_FEWSHOT_EXAMPLES = """
+Question: How many rows are in the table <table>?
+DAX: EVALUATE ROW(\"Number of rows\", COUNTROWS(<table>))
+----
+Question: How many rows are in the table <table> where <column> is not empty?
+DAX: EVALUATE ROW(\"Number of rows\", COUNTROWS(FILTER(<table>, <table>[<column>] <> \"\")))
+----
+Question: What was the average of <column> in <table>?
+DAX: EVALUATE ROW(\"Average\", AVERAGE(<table>[<column>]))
+----
+"""
+
+RETRY_RESPONSE = (
+ "{tool_input} DAX: {query} Error: {error}. Please supply a new DAX query."
+)
+BAD_REQUEST_RESPONSE = "Error on this question, the error was {error}, you can try to rephrase the question."
+SCHEMA_ERROR_RESPONSE = "Bad request, are you sure the table name is correct?"
+UNAUTHORIZED_RESPONSE = "Unauthorized. Try changing your authentication, do not retry."
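A small sketch of how these templates are meant to be filled in; the table name and schema string are illustrative placeholders:

from langchain_community.tools.powerbi.prompt import (
    DEFAULT_FEWSHOT_EXAMPLES,
    SINGLE_QUESTION_TO_QUERY,
)

prompt = SINGLE_QUESTION_TO_QUERY.format(
    tables=["Sales"],                              # placeholder table list
    schemas="Sales: [Region, Amount, OrderDate]",  # placeholder schema summary
    examples=DEFAULT_FEWSHOT_EXAMPLES,
    tool_input="What was the average of Amount in Sales?",
)
print(prompt)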
diff --git a/libs/community/langchain_community/tools/powerbi/tool.py b/libs/community/langchain_community/tools/powerbi/tool.py
new file mode 100644
index 00000000000..9f54ec453d4
--- /dev/null
+++ b/libs/community/langchain_community/tools/powerbi/tool.py
@@ -0,0 +1,276 @@
+"""Tools for interacting with a Power BI dataset."""
+import logging
+from time import perf_counter
+from typing import Any, Dict, Optional, Tuple
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field, validator
+from langchain_core.tools import BaseTool
+
+from langchain_community.chat_models.openai import _import_tiktoken
+from langchain_community.tools.powerbi.prompt import (
+ BAD_REQUEST_RESPONSE,
+ DEFAULT_FEWSHOT_EXAMPLES,
+ RETRY_RESPONSE,
+)
+from langchain_community.utilities.powerbi import PowerBIDataset, json_to_md
+
+logger = logging.getLogger(__name__)
+
+
+class QueryPowerBITool(BaseTool):
+ """Tool for querying a Power BI Dataset."""
+
+ name: str = "query_powerbi"
+ description: str = """
+ Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification.
+
+ Example Input: "How many rows are in table1?"
+ """ # noqa: E501
+ llm_chain: Any
+ powerbi: PowerBIDataset = Field(exclude=True)
+ examples: Optional[str] = DEFAULT_FEWSHOT_EXAMPLES
+ session_cache: Dict[str, Any] = Field(default_factory=dict, exclude=True)
+ max_iterations: int = 5
+ output_token_limit: int = 4000
+ tiktoken_model_name: Optional[str] = None # "cl100k_base"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @validator("llm_chain")
+ def validate_llm_chain_input_variables( # pylint: disable=E0213
+ cls, llm_chain: Any
+ ) -> Any:
+ """Make sure the LLM chain has the correct input variables."""
+ for var in llm_chain.prompt.input_variables:
+ if var not in ["tool_input", "tables", "schemas", "examples"]:
+ raise ValueError(
+ "LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: C0301 E501 # pylint: disable=C0301
+ llm_chain.prompt.input_variables,
+ )
+ return llm_chain
+
+ def _check_cache(self, tool_input: str) -> Optional[str]:
+ """Check if the input is present in the cache.
+
+ If the value is a bad request, overwrite with the escalated version,
+ if not present return None."""
+ if tool_input not in self.session_cache:
+ return None
+ return self.session_cache[tool_input]
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Execute the query, return the results or an error message."""
+ if cache := self._check_cache(tool_input):
+ logger.debug("Found cached result for %s: %s", tool_input, cache)
+ return cache
+
+ try:
+ logger.info("Running PBI Query Tool with input: %s", tool_input)
+ query = self.llm_chain.predict(
+ tool_input=tool_input,
+ tables=self.powerbi.get_table_names(),
+ schemas=self.powerbi.get_schemas(),
+ examples=self.examples,
+ callbacks=run_manager.get_child() if run_manager else None,
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
+ return self.session_cache[tool_input]
+ if query == "I cannot answer this":
+ self.session_cache[tool_input] = query
+ return self.session_cache[tool_input]
+ logger.info("PBI Query:\n%s", query)
+ start_time = perf_counter()
+ pbi_result = self.powerbi.run(command=query)
+ end_time = perf_counter()
+ logger.debug("PBI Result: %s", pbi_result)
+ logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}")
+ result, error = self._parse_output(pbi_result)
+ if error is not None and "TokenExpired" in error:
+ self.session_cache[
+ tool_input
+ ] = "Authentication token expired or invalid, please try reauthenticate."
+ return self.session_cache[tool_input]
+
+ iterations = kwargs.get("iterations", 0)
+ if error and iterations < self.max_iterations:
+ return self._run(
+ tool_input=RETRY_RESPONSE.format(
+ tool_input=tool_input, query=query, error=error
+ ),
+ run_manager=run_manager,
+ iterations=iterations + 1,
+ )
+
+ self.session_cache[tool_input] = (
+ result if result else BAD_REQUEST_RESPONSE.format(error=error)
+ )
+ return self.session_cache[tool_input]
+
+ async def _arun(
+ self,
+ tool_input: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Execute the query, return the results or an error message."""
+ if cache := self._check_cache(tool_input):
+ logger.debug("Found cached result for %s: %s", tool_input, cache)
+ return f"{cache}, from cache, you have already asked this question."
+ try:
+ logger.info("Running PBI Query Tool with input: %s", tool_input)
+ query = await self.llm_chain.apredict(
+ tool_input=tool_input,
+ tables=self.powerbi.get_table_names(),
+ schemas=self.powerbi.get_schemas(),
+ examples=self.examples,
+ callbacks=run_manager.get_child() if run_manager else None,
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
+ return self.session_cache[tool_input]
+
+ if query == "I cannot answer this":
+ self.session_cache[tool_input] = query
+ return self.session_cache[tool_input]
+ logger.info("PBI Query: %s", query)
+ start_time = perf_counter()
+ pbi_result = await self.powerbi.arun(command=query)
+ end_time = perf_counter()
+ logger.debug("PBI Result: %s", pbi_result)
+ logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}")
+ result, error = self._parse_output(pbi_result)
+ if error is not None and ("TokenExpired" in error or "TokenError" in error):
+ self.session_cache[
+ tool_input
+ ] = "Authentication token expired or invalid, please try to reauthenticate or check the scope of the credential." # noqa: E501
+ return self.session_cache[tool_input]
+
+ iterations = kwargs.get("iterations", 0)
+ if error and iterations < self.max_iterations:
+ return await self._arun(
+ tool_input=RETRY_RESPONSE.format(
+ tool_input=tool_input, query=query, error=error
+ ),
+ run_manager=run_manager,
+ iterations=iterations + 1,
+ )
+
+ self.session_cache[tool_input] = (
+ result if result else BAD_REQUEST_RESPONSE.format(error=error)
+ )
+ return self.session_cache[tool_input]
+
+ def _parse_output(
+ self, pbi_result: Dict[str, Any]
+ ) -> Tuple[Optional[str], Optional[Any]]:
+ """Parse the output of the query to a markdown table."""
+ if "results" in pbi_result:
+ rows = pbi_result["results"][0]["tables"][0]["rows"]
+ if len(rows) == 0:
+ logger.info("0 records in result, query was valid.")
+ return (
+ None,
+ "0 rows returned, this might be correct, but please validate if all filter values were correct?", # noqa: E501
+ )
+ result = json_to_md(rows)
+ too_long, length = self._result_too_large(result)
+ if too_long:
+ return (
+ f"Result too large, please try to be more specific or use the `TOPN` function. The result is {length} tokens long, the limit is {self.output_token_limit} tokens.", # noqa: E501
+ None,
+ )
+ return result, None
+
+ if "error" in pbi_result:
+ if (
+ "pbi.error" in pbi_result["error"]
+ and "details" in pbi_result["error"]["pbi.error"]
+ ):
+ return None, pbi_result["error"]["pbi.error"]["details"][0]["detail"]
+ return None, pbi_result["error"]
+ return None, pbi_result
+
+ def _result_too_large(self, result: str) -> Tuple[bool, int]:
+ """Tokenize the output of the query."""
+ if self.tiktoken_model_name:
+ tiktoken_ = _import_tiktoken()
+ encoding = tiktoken_.encoding_for_model(self.tiktoken_model_name)
+ length = len(encoding.encode(result))
+ logger.info("Result length: %s", length)
+ return length > self.output_token_limit, length
+ return False, 0
+
+
+class InfoPowerBITool(BaseTool):
+ """Tool for getting metadata about a PowerBI Dataset."""
+
+ name: str = "schema_powerbi"
+ description: str = """
+ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
+ Be sure that the tables actually exist by calling list_tables_powerbi first!
+
+ Example Input: "table1, table2, table3"
+ """ # noqa: E501
+ powerbi: PowerBIDataset = Field(exclude=True)
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def _run(
+ self,
+ tool_input: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the schema for tables in a comma-separated list."""
+ return self.powerbi.get_table_info(tool_input.split(", "))
+
+ async def _arun(
+ self,
+ tool_input: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ return await self.powerbi.aget_table_info(tool_input.split(", "))
+
+
+class ListPowerBITool(BaseTool):
+ """Tool for getting tables names."""
+
+ name: str = "list_tables_powerbi"
+ description: str = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301
+ powerbi: PowerBIDataset = Field(exclude=True)
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def _run(
+ self,
+ tool_input: Optional[str] = None,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the names of the tables."""
+ return ", ".join(self.powerbi.get_table_names())
+
+ async def _arun(
+ self,
+ tool_input: Optional[str] = None,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the names of the tables."""
+ return ", ".join(self.powerbi.get_table_names())
diff --git a/libs/community/langchain_community/tools/pubmed/__init__.py b/libs/community/langchain_community/tools/pubmed/__init__.py
new file mode 100644
index 00000000000..687e908ee13
--- /dev/null
+++ b/libs/community/langchain_community/tools/pubmed/__init__.py
@@ -0,0 +1 @@
+"""PubMed API toolkit."""
diff --git a/libs/community/langchain_community/tools/pubmed/tool.py b/libs/community/langchain_community/tools/pubmed/tool.py
new file mode 100644
index 00000000000..cd9e1a38884
--- /dev/null
+++ b/libs/community/langchain_community/tools/pubmed/tool.py
@@ -0,0 +1,29 @@
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.pubmed import PubMedAPIWrapper
+
+
+class PubmedQueryRun(BaseTool):
+ """Tool that searches the PubMed API."""
+
+ name: str = "PubMed"
+ description: str = (
+ "A wrapper around PubMed. "
+ "Useful for when you need to answer questions about medicine, health, "
+ "and biomedical topics "
+ "from biomedical literature, MEDLINE, life science journals, and online books. "
+ "Input should be a search query."
+ )
+ api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the PubMed tool."""
+ return self.api_wrapper.run(query)
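A quick usage sketch; `PubMedAPIWrapper` is used with its defaults here and may require the optional `xmltodict` dependency at runtime.

```python
from langchain_community.tools.pubmed.tool import PubmedQueryRun

pubmed = PubmedQueryRun()  # default PubMedAPIWrapper
print(pubmed.run("What are the effects of metformin on longevity?"))
```
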
diff --git a/libs/community/langchain_community/tools/reddit_search/tool.py b/libs/community/langchain_community/tools/reddit_search/tool.py
new file mode 100644
index 00000000000..7c0da2a01e7
--- /dev/null
+++ b/libs/community/langchain_community/tools/reddit_search/tool.py
@@ -0,0 +1,64 @@
+"""Tool for the Reddit search API."""
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
+
+
+class RedditSearchSchema(BaseModel):
+ """Input for Reddit search."""
+
+ query: str = Field(
+ description="should be query string that post title should \
+ contain, or '*' if anything is allowed."
+ )
+ sort: str = Field(
+ description='should be sort method, which is one of: "relevance" \
+ , "hot", "top", "new", or "comments".'
+ )
+ time_filter: str = Field(
+ description='should be time period to filter by, which is \
+ one of "all", "day", "hour", "month", "week", or "year"'
+ )
+ subreddit: str = Field(
+ description='should be name of subreddit, like "all" for \
+ r/all'
+ )
+ limit: str = Field(
+ description="a positive integer indicating the maximum number \
+ of results to return"
+ )
+
+
+class RedditSearchRun(BaseTool):
+ """Tool that queries for posts on a subreddit."""
+
+ name: str = "reddit_search"
+ description: str = (
+ "A tool that searches for posts on Reddit."
+ "Useful when you need to know post information on a subreddit."
+ )
+ api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)
+ args_schema: Type[BaseModel] = RedditSearchSchema
+
+ def _run(
+ self,
+ query: str,
+ sort: str,
+ time_filter: str,
+ subreddit: str,
+ limit: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(
+ query=query,
+ sort=sort,
+ time_filter=time_filter,
+ subreddit=subreddit,
+ limit=int(limit),
+ )
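A usage sketch. The `RedditSearchAPIWrapper` credential fields shown here (`reddit_client_id`, `reddit_client_secret`, `reddit_user_agent`) are assumptions about how the wrapper is configured; the search arguments match `RedditSearchSchema` above.

```python
from langchain_community.tools.reddit_search.tool import RedditSearchRun
from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper

reddit_search = RedditSearchRun(
    api_wrapper=RedditSearchAPIWrapper(
        reddit_client_id="...",          # hypothetical credential fields
        reddit_client_secret="...",
        reddit_user_agent="langchain-example",
    )
)

# The tool takes a dict matching RedditSearchSchema; limit is a string and is
# converted to int internally.
print(
    reddit_search.run(
        {
            "query": "langchain",
            "sort": "new",
            "time_filter": "week",
            "subreddit": "all",
            "limit": "5",
        }
    )
)
```
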
diff --git a/libs/community/langchain_community/tools/render.py b/libs/community/langchain_community/tools/render.py
new file mode 100644
index 00000000000..7a5c3b1ed47
--- /dev/null
+++ b/libs/community/langchain_community/tools/render.py
@@ -0,0 +1,77 @@
+"""Different methods for rendering Tools to be passed to LLMs.
+
+Depending on the LLM you are using and the prompting strategy you are using,
+you may want Tools to be rendered in a different way.
+This module contains various ways to render tools.
+"""
+from typing import List
+
+from langchain_core.tools import BaseTool
+
+from langchain_community.utils.openai_functions import (
+ FunctionDescription,
+ ToolDescription,
+ convert_pydantic_to_openai_function,
+)
+
+
+def render_text_description(tools: List[BaseTool]) -> str:
+ """Render the tool name and description in plain text.
+
+ Output will be in the format of:
+
+ .. code-block:: markdown
+
+ search: This tool is used for search
+ calculator: This tool is used for math
+ """
+ return "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+
+
+def render_text_description_and_args(tools: List[BaseTool]) -> str:
+ """Render the tool name, description, and args in plain text.
+
+ Output will be in the format of:
+
+ .. code-block:: markdown
+
+ search: This tool is used for search, args: {"query": {"type": "string"}}
+ calculator: This tool is used for math, \
+args: {"expression": {"type": "string"}}
+ """
+ tool_strings = []
+ for tool in tools:
+ args_schema = str(tool.args)
+ tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
+ return "\n".join(tool_strings)
+
+
+def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
+ """Format tool into the OpenAI function API."""
+ if tool.args_schema:
+ return convert_pydantic_to_openai_function(
+ tool.args_schema, name=tool.name, description=tool.description
+ )
+ else:
+ return {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": {
+ # This is a hack to get around the fact that some tools
+ # do not expose an args_schema, and expect an argument
+ # which is a string.
+ # And Open AI does not support an array type for the
+ # parameters.
+ "properties": {
+ "__arg1": {"title": "__arg1", "type": "string"},
+ },
+ "required": ["__arg1"],
+ "type": "object",
+ },
+ }
+
+
+def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
+ """Format tool into the OpenAI function API."""
+ function = format_tool_to_openai_function(tool)
+ return {"type": "function", "function": function}
diff --git a/libs/community/langchain_community/tools/requests/__init__.py b/libs/community/langchain_community/tools/requests/__init__.py
new file mode 100644
index 00000000000..ec421f18dbd
--- /dev/null
+++ b/libs/community/langchain_community/tools/requests/__init__.py
@@ -0,0 +1 @@
+"""Tools for making requests to an API endpoint."""
diff --git a/libs/community/langchain_community/tools/requests/tool.py b/libs/community/langchain_community/tools/requests/tool.py
new file mode 100644
index 00000000000..ae2a2ca9546
--- /dev/null
+++ b/libs/community/langchain_community/tools/requests/tool.py
@@ -0,0 +1,184 @@
+# flake8: noqa
+"""Tools for making requests to an API endpoint."""
+import json
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+
+from langchain_community.utilities.requests import TextRequestsWrapper
+from langchain_core.tools import BaseTool
+
+
+def _parse_input(text: str) -> Dict[str, Any]:
+ """Parse the json string into a dict."""
+ return json.loads(text)
+
+
+def _clean_url(url: str) -> str:
+ """Strips quotes from the url."""
+ return url.strip("\"'")
+
+
+class BaseRequestsTool(BaseModel):
+ """Base class for requests tools."""
+
+ requests_wrapper: TextRequestsWrapper
+
+
+class RequestsGetTool(BaseRequestsTool, BaseTool):
+ """Tool for making a GET request to an API endpoint."""
+
+ name: str = "requests_get"
+ description: str = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request."
+
+ def _run(
+ self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Run the tool."""
+ return self.requests_wrapper.get(_clean_url(url))
+
+ async def _arun(
+ self,
+ url: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool asynchronously."""
+ return await self.requests_wrapper.aget(_clean_url(url))
+
+
+class RequestsPostTool(BaseRequestsTool, BaseTool):
+ """Tool for making a POST request to an API endpoint."""
+
+ name: str = "requests_post"
+ description: str = """Use this when you want to POST to a website.
+ Input should be a json string with two keys: "url" and "data".
+ The value of "url" should be a string, and the value of "data" should be a dictionary of
+ key-value pairs you want to POST to the url.
+ Be careful to always use double quotes for strings in the json string
+ The output will be the text response of the POST request.
+ """
+
+ def _run(
+ self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Run the tool."""
+ try:
+ data = _parse_input(text)
+ return self.requests_wrapper.post(_clean_url(data["url"]), data["data"])
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ text: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool asynchronously."""
+ try:
+ data = _parse_input(text)
+ return await self.requests_wrapper.apost(
+ _clean_url(data["url"]), data["data"]
+ )
+ except Exception as e:
+ return repr(e)
+
+
+class RequestsPatchTool(BaseRequestsTool, BaseTool):
+ """Tool for making a PATCH request to an API endpoint."""
+
+ name: str = "requests_patch"
+ description: str = """Use this when you want to PATCH to a website.
+ Input should be a json string with two keys: "url" and "data".
+ The value of "url" should be a string, and the value of "data" should be a dictionary of
+ key-value pairs you want to PATCH to the url.
+ Be careful to always use double quotes for strings in the json string
+ The output will be the text response of the PATCH request.
+ """
+
+ def _run(
+ self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Run the tool."""
+ try:
+ data = _parse_input(text)
+ return self.requests_wrapper.patch(_clean_url(data["url"]), data["data"])
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ text: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool asynchronously."""
+ try:
+ data = _parse_input(text)
+ return await self.requests_wrapper.apatch(
+ _clean_url(data["url"]), data["data"]
+ )
+ except Exception as e:
+ return repr(e)
+
+
+class RequestsPutTool(BaseRequestsTool, BaseTool):
+ """Tool for making a PUT request to an API endpoint."""
+
+ name: str = "requests_put"
+ description: str = """Use this when you want to PUT to a website.
+ Input should be a json string with two keys: "url" and "data".
+ The value of "url" should be a string, and the value of "data" should be a dictionary of
+ key-value pairs you want to PUT to the url.
+ Be careful to always use double quotes for strings in the json string.
+ The output will be the text response of the PUT request.
+ """
+
+ def _run(
+ self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Run the tool."""
+ try:
+ data = _parse_input(text)
+ return self.requests_wrapper.put(_clean_url(data["url"]), data["data"])
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ text: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool asynchronously."""
+ try:
+ data = _parse_input(text)
+ return await self.requests_wrapper.aput(
+ _clean_url(data["url"]), data["data"]
+ )
+ except Exception as e:
+ return repr(e)
+
+
+class RequestsDeleteTool(BaseRequestsTool, BaseTool):
+ """Tool for making a DELETE request to an API endpoint."""
+
+ name: str = "requests_delete"
+ description: str = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request."
+
+ def _run(
+ self,
+ url: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool."""
+ return self.requests_wrapper.delete(_clean_url(url))
+
+ async def _arun(
+ self,
+ url: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run the tool asynchronously."""
+ return await self.requests_wrapper.adelete(_clean_url(url))
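A sketch of the GET and POST tools on top of the default `TextRequestsWrapper`; the URLs are placeholders. POST input must be a JSON string with `"url"` and `"data"` keys, exactly as the tool description above states.

```python
from langchain_community.tools.requests.tool import RequestsGetTool, RequestsPostTool
from langchain_community.utilities.requests import TextRequestsWrapper

wrapper = TextRequestsWrapper()
get_tool = RequestsGetTool(requests_wrapper=wrapper)
post_tool = RequestsPostTool(requests_wrapper=wrapper)

# GET takes a bare URL string.
page = get_tool.run("https://example.com")

# POST takes a JSON string; it is parsed with json.loads before the request.
response = post_tool.run('{"url": "https://example.com/api", "data": {"key": "value"}}')
```
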
diff --git a/libs/community/langchain_community/tools/scenexplain/__init__.py b/libs/community/langchain_community/tools/scenexplain/__init__.py
new file mode 100644
index 00000000000..2e6553b7356
--- /dev/null
+++ b/libs/community/langchain_community/tools/scenexplain/__init__.py
@@ -0,0 +1 @@
+"""SceneXplain API toolkit."""
diff --git a/libs/community/langchain_community/tools/scenexplain/tool.py b/libs/community/langchain_community/tools/scenexplain/tool.py
new file mode 100644
index 00000000000..5806cd7bb25
--- /dev/null
+++ b/libs/community/langchain_community/tools/scenexplain/tool.py
@@ -0,0 +1,32 @@
+"""Tool for the SceneXplain API."""
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.scenexplain import SceneXplainAPIWrapper
+
+
+class SceneXplainInput(BaseModel):
+ """Input for SceneXplain."""
+
+ query: str = Field(..., description="The link to the image to explain")
+
+
+class SceneXplainTool(BaseTool):
+ """Tool that explains images."""
+
+ name: str = "image_explainer"
+ description: str = (
+ "An Image Captioning Tool: Use this tool to generate a detailed caption "
+ "for an image. The input can be an image file of any format, and "
+ "the output will be a text description that covers every detail of the image."
+ )
+ api_wrapper: SceneXplainAPIWrapper = Field(default_factory=SceneXplainAPIWrapper)
+
+ def _run(
+ self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/searchapi/__init__.py b/libs/community/langchain_community/tools/searchapi/__init__.py
new file mode 100644
index 00000000000..7a89dfceffa
--- /dev/null
+++ b/libs/community/langchain_community/tools/searchapi/__init__.py
@@ -0,0 +1,6 @@
+"""SearchApi.io API toolkit.
+
+Tools for the SearchApi.io Google SERP API."""
+from langchain_community.tools.searchapi.tool import SearchAPIResults, SearchAPIRun
+
+__all__ = ["SearchAPIResults", "SearchAPIRun"]
diff --git a/libs/community/langchain_community/tools/searchapi/tool.py b/libs/community/langchain_community/tools/searchapi/tool.py
new file mode 100644
index 00000000000..5e7f5e8917c
--- /dev/null
+++ b/libs/community/langchain_community/tools/searchapi/tool.py
@@ -0,0 +1,69 @@
+"""Tool for the SearchApi.io search API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.searchapi import SearchApiAPIWrapper
+
+
+class SearchAPIRun(BaseTool):
+ """Tool that queries the SearchApi.io search API."""
+
+ name: str = "searchapi"
+ description: str = (
+ "Google search API provided by SearchApi.io."
+ "This tool is handy when you need to answer questions about current events."
+ "Input should be a search query."
+ )
+ api_wrapper: SearchApiAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ return await self.api_wrapper.arun(query)
+
+
+class SearchAPIResults(BaseTool):
+ """Tool that queries the SearchApi.io search API and returns JSON."""
+
+ name: str = "searchapi_results_json"
+ description: str = (
+ "Google search API provided by SearchApi.io."
+ "This tool is handy when you need to answer questions about current events."
+ "The input should be a search query and the output is a JSON object "
+ "with the query results."
+ )
+ api_wrapper: SearchApiAPIWrapper = Field(default_factory=SearchApiAPIWrapper)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.api_wrapper.results(query))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ return (await self.api_wrapper.aresults(query)).__str__()
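A usage sketch. `SearchApiAPIWrapper` normally reads its key from the `SEARCHAPI_API_KEY` environment variable; the keyword argument shown here is an assumption about the wrapper's field name.

```python
from langchain_community.tools.searchapi import SearchAPIResults, SearchAPIRun
from langchain_community.utilities.searchapi import SearchApiAPIWrapper

wrapper = SearchApiAPIWrapper(searchapi_api_key="...")  # or export SEARCHAPI_API_KEY

answer = SearchAPIRun(api_wrapper=wrapper).run("Who won the 2022 FIFA World Cup?")
raw_results = SearchAPIResults(api_wrapper=wrapper).run("Who won the 2022 FIFA World Cup?")
```
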
diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/__init__.py b/libs/community/langchain_community/tools/searx_search/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/vectorstores/qdrant/__init__.py
rename to libs/community/langchain_community/tools/searx_search/__init__.py
diff --git a/libs/community/langchain_community/tools/searx_search/tool.py b/libs/community/langchain_community/tools/searx_search/tool.py
new file mode 100644
index 00000000000..005faf239bf
--- /dev/null
+++ b/libs/community/langchain_community/tools/searx_search/tool.py
@@ -0,0 +1,77 @@
+"""Tool for the SearxNG search API."""
+from typing import Optional
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Extra, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.searx_search import SearxSearchWrapper
+
+
+class SearxSearchRun(BaseTool):
+ """Tool that queries a Searx instance."""
+
+ name: str = "searx_search"
+ description: str = (
+ "A meta search engine."
+ "Useful for when you need to answer questions about current events."
+ "Input should be a search query."
+ )
+ wrapper: SearxSearchWrapper
+ kwargs: dict = Field(default_factory=dict)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return self.wrapper.run(query, **self.kwargs)
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ return await self.wrapper.arun(query, **self.kwargs)
+
+
+class SearxSearchResults(BaseTool):
+ """Tool that queries a Searx instance and gets back json."""
+
+ name: str = "Searx-Search-Results"
+ description: str = (
+ "A meta search engine."
+ "Useful for when you need to answer questions about current events."
+ "Input should be a search query. Output is a JSON array of the query results"
+ )
+ wrapper: SearxSearchWrapper
+ num_results: int = 4
+ kwargs: dict = Field(default_factory=dict)
+
+ class Config:
+ """Pydantic config."""
+
+ extra = Extra.allow
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ return str(self.wrapper.results(query, self.num_results, **self.kwargs))
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool asynchronously."""
+ return (
+ await self.wrapper.aresults(query, self.num_results, **self.kwargs)
+ ).__str__()
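A usage sketch pointing the wrapper at a self-hosted SearxNG instance; the host URL is a placeholder.

```python
from langchain_community.tools.searx_search.tool import (
    SearxSearchResults,
    SearxSearchRun,
)
from langchain_community.utilities.searx_search import SearxSearchWrapper

wrapper = SearxSearchWrapper(searx_host="http://localhost:8888")

text_answer = SearxSearchRun(wrapper=wrapper).run("latest Python release")
json_answer = SearxSearchResults(wrapper=wrapper, num_results=3).run("latest Python release")
```
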
diff --git a/libs/community/langchain_community/tools/shell/__init__.py b/libs/community/langchain_community/tools/shell/__init__.py
new file mode 100644
index 00000000000..37e11d4b597
--- /dev/null
+++ b/libs/community/langchain_community/tools/shell/__init__.py
@@ -0,0 +1,5 @@
+"""Shell tool."""
+
+from langchain_community.tools.shell.tool import ShellTool
+
+__all__ = ["ShellTool"]
diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py
new file mode 100644
index 00000000000..e92d51445aa
--- /dev/null
+++ b/libs/community/langchain_community/tools/shell/tool.py
@@ -0,0 +1,89 @@
+import asyncio
+import platform
+import warnings
+from typing import Any, List, Optional, Type, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.tools import BaseTool
+
+
+class ShellInput(BaseModel):
+ """Commands for the Bash Shell tool."""
+
+ commands: Union[str, List[str]] = Field(
+ ...,
+ description="List of shell commands to run. Deserialized using json.loads",
+ )
+ """List of shell commands to run."""
+
+ @root_validator
+ def _validate_commands(cls, values: dict) -> dict:
+ """Validate commands."""
+ # TODO: Add real validators
+ commands = values.get("commands")
+ if not isinstance(commands, list):
+ values["commands"] = [commands]
+ # Warn that the bash tool is not safe
+ warnings.warn(
+ "The shell tool has no safeguards by default. Use at your own risk."
+ )
+ return values
+
+
+def _get_default_bash_process() -> Any:
+ """Get default bash process."""
+ try:
+ from langchain_experimental.llm_bash.bash import BashProcess
+ except ImportError:
+ raise ImportError(
+ "BashProcess has been moved to langchain experimental."
+ "To use this tool, install langchain-experimental "
+ "with `pip install langchain-experimental`."
+ )
+ return BashProcess(return_err_output=True)
+
+
+def _get_platform() -> str:
+ """Get platform."""
+ system = platform.system()
+ if system == "Darwin":
+ return "MacOS"
+ return system
+
+
+class ShellTool(BaseTool):
+ """Tool to run shell commands."""
+
+ process: Any = Field(default_factory=_get_default_bash_process)
+ """Bash process to run commands."""
+
+ name: str = "terminal"
+ """Name of tool."""
+
+ description: str = f"Run shell commands on this {_get_platform()} machine."
+ """Description of tool."""
+
+ args_schema: Type[BaseModel] = ShellInput
+ """Schema for input arguments."""
+
+ def _run(
+ self,
+ commands: Union[str, List[str]],
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run commands and return final output."""
+ return self.process.run(commands)
+
+ async def _arun(
+ self,
+ commands: Union[str, List[str]],
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Run commands asynchronously and return final output."""
+ return await asyncio.get_event_loop().run_in_executor(
+ None, self.process.run, commands
+ )
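A `ShellTool` example; the validator above warns that there are no safeguards, and the default process requires `langchain-experimental` to be installed.

```python
from langchain_community.tools.shell.tool import ShellTool

shell = ShellTool()  # default BashProcess comes from langchain-experimental

# Input matches ShellInput: a list of commands (a single string is also accepted).
print(shell.run({"commands": ["echo 'Hello from the shell tool'", "pwd"]}))
```
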
diff --git a/libs/community/langchain_community/tools/slack/__init__.py b/libs/community/langchain_community/tools/slack/__init__.py
new file mode 100644
index 00000000000..b77e61619c1
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/__init__.py
@@ -0,0 +1,15 @@
+"""Slack tools."""
+
+from langchain_community.tools.slack.get_channel import SlackGetChannel
+from langchain_community.tools.slack.get_message import SlackGetMessage
+from langchain_community.tools.slack.schedule_message import SlackScheduleMessage
+from langchain_community.tools.slack.send_message import SlackSendMessage
+from langchain_community.tools.slack.utils import login
+
+__all__ = [
+ "SlackGetChannel",
+ "SlackGetMessage",
+ "SlackScheduleMessage",
+ "SlackSendMessage",
+ "login",
+]
diff --git a/libs/community/langchain_community/tools/slack/base.py b/libs/community/langchain_community/tools/slack/base.py
new file mode 100644
index 00000000000..1d16450ed8e
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/base.py
@@ -0,0 +1,19 @@
+"""Base class for Slack tools."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.pydantic_v1 import Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.slack.utils import login
+
+if TYPE_CHECKING:
+ from slack_sdk import WebClient
+
+
+class SlackBaseTool(BaseTool):
+ """Base class for Slack tools."""
+
+ client: WebClient = Field(default_factory=login)
+ """The WebClient object."""
diff --git a/libs/community/langchain_community/tools/slack/get_channel.py b/libs/community/langchain_community/tools/slack/get_channel.py
new file mode 100644
index 00000000000..3cb4f513ee2
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/get_channel.py
@@ -0,0 +1,34 @@
+import json
+import logging
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools.slack.base import SlackBaseTool
+
+
+class SlackGetChannel(SlackBaseTool):
+ name: str = "get_channelid_name_dict"
+ description: str = "Use this tool to get channelid-name dict."
+
+ def _run(
+ self,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ logging.getLogger(__name__)
+
+ result = self.client.conversations_list()
+ channels = result["channels"]
+ filtered_result = [
+ {key: channel[key] for key in ("id", "name", "created", "num_members")}
+ for channel in channels
+ if "id" in channel
+ and "name" in channel
+ and "created" in channel
+ and "num_members" in channel
+ ]
+ return json.dumps(filtered_result)
+
+ except Exception as e:
+ return "Error creating conversation: {}".format(e)
diff --git a/libs/community/langchain_community/tools/slack/get_message.py b/libs/community/langchain_community/tools/slack/get_message.py
new file mode 100644
index 00000000000..a6504f38b76
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/get_message.py
@@ -0,0 +1,42 @@
+import json
+import logging
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.slack.base import SlackBaseTool
+
+
+class SlackGetMessageSchema(BaseModel):
+ """Input schema for SlackGetMessages."""
+
+ channel_id: str = Field(
+ ...,
+ description="The channel id, private group, or IM channel to send message to.",
+ )
+
+
+class SlackGetMessage(SlackBaseTool):
+ name: str = "get_messages"
+ description: str = "Use this tool to get messages from a channel."
+
+ args_schema: Type[SlackGetMessageSchema] = SlackGetMessageSchema
+
+ def _run(
+ self,
+ channel_id: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ logging.getLogger(__name__)
+ try:
+ result = self.client.conversations_history(channel=channel_id)
+ messages = result["messages"]
+ filtered_messages = [
+ {key: message[key] for key in ("user", "text", "ts")}
+ for message in messages
+ if "user" in message and "text" in message and "ts" in message
+ ]
+ return json.dumps(filtered_messages)
+ except Exception as e:
+ return "Error creating conversation: {}".format(e)
diff --git a/libs/community/langchain_community/tools/slack/schedule_message.py b/libs/community/langchain_community/tools/slack/schedule_message.py
new file mode 100644
index 00000000000..90edc9bcb9d
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/schedule_message.py
@@ -0,0 +1,60 @@
+import logging
+from datetime import datetime as dt
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.slack.base import SlackBaseTool
+from langchain_community.tools.slack.utils import UTC_FORMAT
+
+logger = logging.getLogger(__name__)
+
+
+class ScheduleMessageSchema(BaseModel):
+ """Input for ScheduleMessageTool."""
+
+ message: str = Field(
+ ...,
+ description="The message to be sent.",
+ )
+ channel: str = Field(
+ ...,
+ description="The channel, private group, or IM channel to send message to.",
+ )
+ timestamp: str = Field(
+ ...,
+ description="The datetime for when the message should be sent in the "
+        ' following format: YYYY-MM-DDTHH:MM:SS±hh:mm, where "T" separates the date '
+        " and time components, and the time zone offset is specified as ±hh:mm. "
+ ' For example: "2023-06-09T10:30:00+03:00" represents June 9th, '
+ " 2023, at 10:30 AM in a time zone with a positive offset of 3 "
+ " hours from Coordinated Universal Time (UTC).",
+ )
+
+
+class SlackScheduleMessage(SlackBaseTool):
+ """Tool for scheduling a message in Slack."""
+
+ name: str = "schedule_message"
+ description: str = (
+ "Use this tool to schedule a message to be sent on a specific date and time."
+ )
+ args_schema: Type[ScheduleMessageSchema] = ScheduleMessageSchema
+
+ def _run(
+ self,
+ message: str,
+ channel: str,
+ timestamp: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ unix_timestamp = dt.timestamp(dt.strptime(timestamp, UTC_FORMAT))
+ result = self.client.chat_scheduleMessage(
+ channel=channel, text=message, post_at=unix_timestamp
+ )
+ output = "Message scheduled: " + str(result)
+ return output
+ except Exception as e:
+ return "Error scheduling message: {}".format(e)
diff --git a/libs/community/langchain_community/tools/slack/send_message.py b/libs/community/langchain_community/tools/slack/send_message.py
new file mode 100644
index 00000000000..c5e8a875ad0
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/send_message.py
@@ -0,0 +1,42 @@
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.tools.slack.base import SlackBaseTool
+
+
+class SendMessageSchema(BaseModel):
+ """Input for SendMessageTool."""
+
+ message: str = Field(
+ ...,
+ description="The message to be sent.",
+ )
+ channel: str = Field(
+ ...,
+ description="The channel, private group, or IM channel to send message to.",
+ )
+
+
+class SlackSendMessage(SlackBaseTool):
+ """Tool for sending a message in Slack."""
+
+ name: str = "send_message"
+ description: str = (
+ "Use this tool to send a message with the provided message fields."
+ )
+ args_schema: Type[SendMessageSchema] = SendMessageSchema
+
+ def _run(
+ self,
+ message: str,
+ channel: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ try:
+ result = self.client.chat_postMessage(channel=channel, text=message)
+ output = "Message sent: " + str(result)
+ return output
+ except Exception as e:
+ return "Error creating conversation: {}".format(e)
diff --git a/libs/community/langchain_community/tools/slack/utils.py b/libs/community/langchain_community/tools/slack/utils.py
new file mode 100644
index 00000000000..1d614f6e9b3
--- /dev/null
+++ b/libs/community/langchain_community/tools/slack/utils.py
@@ -0,0 +1,42 @@
+"""Slack tool utils."""
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from slack_sdk import WebClient
+
+logger = logging.getLogger(__name__)
+
+
+def login() -> WebClient:
+ """Authenticate using the Slack API."""
+ try:
+ from slack_sdk import WebClient
+ except ImportError as e:
+ raise ImportError(
+ "Cannot import slack_sdk. Please install the package with \
+ `pip install slack_sdk`."
+ ) from e
+
+ if "SLACK_BOT_TOKEN" in os.environ:
+ token = os.environ["SLACK_BOT_TOKEN"]
+ client = WebClient(token=token)
+ logger.info("slack login success")
+ return client
+ elif "SLACK_USER_TOKEN" in os.environ:
+ token = os.environ["SLACK_USER_TOKEN"]
+ client = WebClient(token=token)
+ logger.info("slack login success")
+ return client
+ else:
+ logger.error(
+ "Error: The SLACK_BOT_TOKEN or SLACK_USER_TOKEN \
+            environment variable has not been set."
+ )
+
+
+UTC_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
+"""UTC format for datetime objects."""
diff --git a/libs/community/langchain_community/tools/sleep/__init__.py b/libs/community/langchain_community/tools/sleep/__init__.py
new file mode 100644
index 00000000000..4d6319e2640
--- /dev/null
+++ b/libs/community/langchain_community/tools/sleep/__init__.py
@@ -0,0 +1 @@
+"""Sleep tool."""
diff --git a/libs/community/langchain_community/tools/sleep/tool.py b/libs/community/langchain_community/tools/sleep/tool.py
new file mode 100644
index 00000000000..d56a114ace4
--- /dev/null
+++ b/libs/community/langchain_community/tools/sleep/tool.py
@@ -0,0 +1,43 @@
+"""Tool for agent to sleep."""
+from asyncio import sleep as asleep
+from time import sleep
+from typing import Optional, Type
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+
+class SleepInput(BaseModel):
+ """Input for CopyFileTool."""
+
+ sleep_time: int = Field(..., description="Time to sleep in seconds")
+
+
+class SleepTool(BaseTool):
+ """Tool that adds the capability to sleep."""
+
+ name: str = "sleep"
+ args_schema: Type[BaseModel] = SleepInput
+ description: str = "Make agent sleep for a specified number of seconds."
+
+ def _run(
+ self,
+ sleep_time: int,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Sleep tool."""
+ sleep(sleep_time)
+ return f"Agent slept for {sleep_time} seconds."
+
+ async def _arun(
+ self,
+ sleep_time: int,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the sleep tool asynchronously."""
+ await asleep(sleep_time)
+ return f"Agent slept for {sleep_time} seconds."
diff --git a/libs/community/langchain_community/tools/spark_sql/__init__.py b/libs/community/langchain_community/tools/spark_sql/__init__.py
new file mode 100644
index 00000000000..01039b772c6
--- /dev/null
+++ b/libs/community/langchain_community/tools/spark_sql/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with Spark SQL."""
diff --git a/libs/community/langchain_community/tools/spark_sql/prompt.py b/libs/community/langchain_community/tools/spark_sql/prompt.py
new file mode 100644
index 00000000000..98a523b88cf
--- /dev/null
+++ b/libs/community/langchain_community/tools/spark_sql/prompt.py
@@ -0,0 +1,14 @@
+# flake8: noqa
+QUERY_CHECKER = """
+{query}
+Double check the Spark SQL query above for common mistakes, including:
+- Using NOT IN with NULL values
+- Using UNION when UNION ALL should have been used
+- Using BETWEEN for exclusive ranges
+- Data type mismatch in predicates
+- Properly quoting identifiers
+- Using the correct number of arguments for functions
+- Casting to the correct data type
+- Using the proper columns for joins
+
+If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
diff --git a/libs/community/langchain_community/tools/spark_sql/tool.py b/libs/community/langchain_community/tools/spark_sql/tool.py
new file mode 100644
index 00000000000..9110d543fdf
--- /dev/null
+++ b/libs/community/langchain_community/tools/spark_sql/tool.py
@@ -0,0 +1,131 @@
+# flake8: noqa
+"""Tools for interacting with Spark SQL."""
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.prompts import PromptTemplate
+from langchain_community.utilities.spark_sql import SparkSQL
+from langchain_core.tools import BaseTool
+from langchain_community.tools.spark_sql.prompt import QUERY_CHECKER
+
+
+class BaseSparkSQLTool(BaseModel):
+ """Base tool for interacting with Spark SQL."""
+
+ db: SparkSQL = Field(exclude=True)
+
+ class Config(BaseTool.Config):
+ pass
+
+
+class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
+ """Tool for querying a Spark SQL."""
+
+ name: str = "query_sql_db"
+ description: str = """
+ Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
+ If the query is not correct, an error message will be returned.
+ If an error is returned, rewrite the query, check the query, and try again.
+ """
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Execute the query, return the results or an error message."""
+ return self.db.run_no_throw(query)
+
+
+class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
+ """Tool for getting metadata about a Spark SQL."""
+
+ name: str = "schema_sql_db"
+ description: str = """
+ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
+ Be sure that the tables actually exist by calling list_tables_sql_db first!
+
+ Example Input: "table1, table2, table3"
+ """
+
+ def _run(
+ self,
+ table_names: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the schema for tables in a comma-separated list."""
+ return self.db.get_table_info_no_throw(table_names.split(", "))
+
+
+class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
+ """Tool for getting tables names."""
+
+ name: str = "list_tables_sql_db"
+ description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
+
+ def _run(
+ self,
+ tool_input: str = "",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the schema for a specific table."""
+ return ", ".join(self.db.get_usable_table_names())
+
+
+class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
+ """Use an LLM to check if a query is correct.
+ Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
+
+ template: str = QUERY_CHECKER
+ llm: BaseLanguageModel
+ llm_chain: Any = Field(init=False)
+ name: str = "query_checker_sql_db"
+ description: str = """
+ Use this tool to double check if your query is correct before executing it.
+ Always use this tool before executing a query with query_sql_db!
+ """
+
+ @root_validator(pre=True)
+ def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if "llm_chain" not in values:
+ from langchain.chains.llm import LLMChain
+
+ values["llm_chain"] = LLMChain(
+ llm=values.get("llm"),
+ prompt=PromptTemplate(
+ template=QUERY_CHECKER, input_variables=["query"]
+ ),
+ )
+
+ if values["llm_chain"].prompt.input_variables != ["query"]:
+ raise ValueError(
+ "LLM chain for QueryCheckerTool need to use ['query'] as input_variables "
+ "for the embedded prompt"
+ )
+
+ return values
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the LLM to check the query."""
+ return self.llm_chain.predict(
+ query=query, callbacks=run_manager.get_child() if run_manager else None
+ )
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ return await self.llm_chain.apredict(
+ query=query, callbacks=run_manager.get_child() if run_manager else None
+ )
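A wiring sketch for the Spark SQL tools; it assumes an active `SparkSession` and that the `SparkSQL` helper accepts it via `spark_session` (its constructor is not shown in this diff). The table name in the query is a placeholder.

```python
from pyspark.sql import SparkSession

from langchain_community.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QuerySparkSQLTool,
)
from langchain_community.utilities.spark_sql import SparkSQL

spark = SparkSession.builder.getOrCreate()
db = SparkSQL(spark_session=spark, schema="default")  # constructor args are assumptions

tables = ListSparkSQLTool(db=db).run("")          # comma-separated table names
schemas = InfoSparkSQLTool(db=db).run(tables)     # schema + sample rows
result = QuerySparkSQLTool(db=db).run("SELECT COUNT(*) AS n FROM my_table")
```
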
diff --git a/libs/community/langchain_community/tools/sql_database/__init__.py b/libs/community/langchain_community/tools/sql_database/__init__.py
new file mode 100644
index 00000000000..90fb3be1322
--- /dev/null
+++ b/libs/community/langchain_community/tools/sql_database/__init__.py
@@ -0,0 +1 @@
+"""Tools for interacting with a SQL database."""
diff --git a/libs/community/langchain_community/tools/sql_database/prompt.py b/libs/community/langchain_community/tools/sql_database/prompt.py
new file mode 100644
index 00000000000..34ab0fd3b16
--- /dev/null
+++ b/libs/community/langchain_community/tools/sql_database/prompt.py
@@ -0,0 +1,18 @@
+# flake8: noqa
+QUERY_CHECKER = """
+{query}
+Double check the {dialect} query above for common mistakes, including:
+- Using NOT IN with NULL values
+- Using UNION when UNION ALL should have been used
+- Using BETWEEN for exclusive ranges
+- Data type mismatch in predicates
+- Properly quoting identifiers
+- Using the correct number of arguments for functions
+- Casting to the correct data type
+- Using the proper columns for joins
+
+If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
+
+Output the final SQL query only.
+
+SQL Query: """
diff --git a/libs/community/langchain_community/tools/sql_database/tool.py b/libs/community/langchain_community/tools/sql_database/tool.py
new file mode 100644
index 00000000000..91dcffa04b0
--- /dev/null
+++ b/libs/community/langchain_community/tools/sql_database/tool.py
@@ -0,0 +1,135 @@
+# flake8: noqa
+"""Tools for interacting with a SQL database."""
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.prompts import PromptTemplate
+from langchain_community.utilities.sql_database import SQLDatabase
+from langchain_core.tools import BaseTool
+from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
+
+
+class BaseSQLDatabaseTool(BaseModel):
+ """Base tool for interacting with a SQL database."""
+
+ db: SQLDatabase = Field(exclude=True)
+
+ class Config(BaseTool.Config):
+ pass
+
+
+class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
+ """Tool for querying a SQL database."""
+
+ name: str = "sql_db_query"
+ description: str = """
+ Input to this tool is a detailed and correct SQL query, output is a result from the database.
+ If the query is not correct, an error message will be returned.
+ If an error is returned, rewrite the query, check the query, and try again.
+ """
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Execute the query, return the results or an error message."""
+ return self.db.run_no_throw(query)
+
+
+class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
+ """Tool for getting metadata about a SQL database."""
+
+ name: str = "sql_db_schema"
+ description: str = """
+ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
+
+ Example Input: "table1, table2, table3"
+ """
+
+ def _run(
+ self,
+ table_names: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the schema for tables in a comma-separated list."""
+ return self.db.get_table_info_no_throw(
+ [t.strip() for t in table_names.split(",")]
+ )
+
+
+class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
+ """Tool for getting tables names."""
+
+ name: str = "sql_db_list_tables"
+ description: str = "Input is an empty string, output is a comma separated list of tables in the database."
+
+ def _run(
+ self,
+ tool_input: str = "",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Get the schema for a specific table."""
+ return ", ".join(self.db.get_usable_table_names())
+
+
+class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
+ """Use an LLM to check if a query is correct.
+ Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
+
+ template: str = QUERY_CHECKER
+ llm: BaseLanguageModel
+ llm_chain: Any = Field(init=False)
+ name: str = "sql_db_query_checker"
+ description: str = """
+ Use this tool to double check if your query is correct before executing it.
+ Always use this tool before executing a query with sql_db_query!
+ """
+
+ @root_validator(pre=True)
+ def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if "llm_chain" not in values:
+ from langchain.chains.llm import LLMChain
+
+ values["llm_chain"] = LLMChain(
+ llm=values.get("llm"),
+ prompt=PromptTemplate(
+ template=QUERY_CHECKER, input_variables=["dialect", "query"]
+ ),
+ )
+
+ if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
+ raise ValueError(
+ "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
+ )
+
+ return values
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the LLM to check the query."""
+ return self.llm_chain.predict(
+ query=query,
+ dialect=self.db.dialect,
+ callbacks=run_manager.get_child() if run_manager else None,
+ )
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ return await self.llm_chain.apredict(
+ query=query,
+ dialect=self.db.dialect,
+ callbacks=run_manager.get_child() if run_manager else None,
+ )
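A sketch against a local SQLite file via SQLAlchemy. `SQLDatabase.from_uri` comes from the existing utilities; `llm` stands in for any `BaseLanguageModel` and is only needed by the checker tool, which builds its `LLMChain` from the `langchain` package.

```python
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
)
from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///example.db")

tables = ListSQLDatabaseTool(db=db).run("")       # comma-separated table names
schema = InfoSQLDatabaseTool(db=db).run(tables)   # schema + sample rows
rows = QuerySQLDataBaseTool(db=db).run("SELECT 1")

# `llm` is a stand-in for a model constructed elsewhere.
checker = QuerySQLCheckerTool(db=db, llm=llm)
checked_query = checker.run("SELECT * FROM users WHERE id = 1")
```
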
diff --git a/libs/community/langchain_community/tools/stackexchange/__init__.py b/libs/community/langchain_community/tools/stackexchange/__init__.py
new file mode 100644
index 00000000000..1fa9a483d10
--- /dev/null
+++ b/libs/community/langchain_community/tools/stackexchange/__init__.py
@@ -0,0 +1 @@
+"""StackExchange API toolkit."""
diff --git a/libs/community/langchain_community/tools/stackexchange/tool.py b/libs/community/langchain_community/tools/stackexchange/tool.py
new file mode 100644
index 00000000000..fa398b0f148
--- /dev/null
+++ b/libs/community/langchain_community/tools/stackexchange/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Wikipedia API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.stackexchange import StackExchangeAPIWrapper
+
+
+class StackExchangeTool(BaseTool):
+ """Tool that uses StackExchange"""
+
+ name: str = "StackExchange"
+ description: str = (
+ "A wrapper around StackExchange. "
+ "Useful for when you need to answer specific programming questions"
+ "code excerpts, code examples and solutions"
+ "Input should be a fully formed question."
+ )
+ api_wrapper: StackExchangeAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Stack Exchange tool."""
+ return self.api_wrapper.run(query)
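Minimal usage; `StackExchangeAPIWrapper` depends on the optional `stackapi` package being installed.

```python
from langchain_community.tools.stackexchange.tool import StackExchangeTool
from langchain_community.utilities.stackexchange import StackExchangeAPIWrapper

stackexchange = StackExchangeTool(api_wrapper=StackExchangeAPIWrapper())
print(stackexchange.run("How do I reverse a list in Python?"))
```
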
diff --git a/libs/community/langchain_community/tools/steam/__init__.py b/libs/community/langchain_community/tools/steam/__init__.py
new file mode 100644
index 00000000000..9367fd95b30
--- /dev/null
+++ b/libs/community/langchain_community/tools/steam/__init__.py
@@ -0,0 +1 @@
+"""Steam API toolkit"""
diff --git a/libs/community/langchain_community/tools/steam/prompt.py b/libs/community/langchain_community/tools/steam/prompt.py
new file mode 100644
index 00000000000..6f82e2ff4f2
--- /dev/null
+++ b/libs/community/langchain_community/tools/steam/prompt.py
@@ -0,0 +1,26 @@
+STEAM_GET_GAMES_DETAILS = """
+ This tool is a wrapper around python-steam-api's steam.apps.search_games API and
+ steam.apps.get_app_details API, useful when you need to search for a game.
+ The input to this tool is a string specifying the name of the game you want to
+ search for. For example, to search for a game called "Counter-Strike: Global
+ Offensive", you would input "Counter-Strike: Global Offensive" as the game name.
+ This input will be passed into steam.apps.search_games to find the game id, link
+ and price, and then the game id will be passed into steam.apps.get_app_details to
+ get the detailed description and supported languages of the game. Finally the
+ results are combined and returned as a string.
+"""
+
+STEAM_GET_RECOMMENDED_GAMES = """
+ This tool is a wrapper around python-steam-api's steam.users.get_owned_games API
+ and steamspypi's steamspypi.download API, useful when you need to get a list of
+ recommended games. The input to this tool is a string specifying the steam id of
+ the user you want to get recommended games for. For example, to get recommended
+ games for a user with steam id 76561197960435530, you would input
+ "76561197960435530" as the steam id. This steamid is then utilized to form a
+ data_request sent to steamspypi's steamspypi.download to retrieve genres of user's
+    owned games. The tool then calculates the frequency of each genre, identifies the
+    most popular one, and stores it in a dictionary. Subsequently, it uses
+    steamspypi.download to return all games in this genre and returns the 5 most-played
+    games that are not owned by the user.
+
+"""
diff --git a/libs/community/langchain_community/tools/steam/tool.py b/libs/community/langchain_community/tools/steam/tool.py
new file mode 100644
index 00000000000..69706c8ada9
--- /dev/null
+++ b/libs/community/langchain_community/tools/steam/tool.py
@@ -0,0 +1,30 @@
+"""Tool for Steam Web API"""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.steam import SteamWebAPIWrapper
+
+
+class SteamWebAPIQueryRun(BaseTool):
+ """Tool that searches the Steam Web API."""
+
+ mode: str
+ name: str = "Steam"
+ description: str = (
+ "A wrapper around Steam Web API."
+ "Steam Tool is useful for fetching User profiles and stats, Game data and more!"
+ "Input should be the User or Game you want to query."
+ )
+
+ api_wrapper: SteamWebAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Steam-WebAPI tool."""
+ return self.api_wrapper.run(self.mode, query)
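+
+
+# Example usage (illustrative sketch, not part of the module; the `mode` value and
+# the SteamWebAPIWrapper setup below are assumptions, so consult the Steam toolkit
+# and SteamWebAPIWrapper for the exact supported modes and required credentials):
+#
+#     from langchain_community.tools.steam.tool import SteamWebAPIQueryRun
+#     from langchain_community.utilities.steam import SteamWebAPIWrapper
+#
+#     wrapper = SteamWebAPIWrapper()  # assumes any required Steam credentials are set
+#     tool = SteamWebAPIQueryRun(mode="get_games_details", api_wrapper=wrapper)
+#     print(tool.run("Terraria"))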
diff --git a/libs/community/langchain_community/tools/steamship_image_generation/__init__.py b/libs/community/langchain_community/tools/steamship_image_generation/__init__.py
new file mode 100644
index 00000000000..e21672c72dc
--- /dev/null
+++ b/libs/community/langchain_community/tools/steamship_image_generation/__init__.py
@@ -0,0 +1,7 @@
+"""Tool to generate an image."""
+
+from langchain_community.tools.steamship_image_generation.tool import (
+ SteamshipImageGenerationTool,
+)
+
+__all__ = ["SteamshipImageGenerationTool"]
diff --git a/libs/community/langchain_community/tools/steamship_image_generation/tool.py b/libs/community/langchain_community/tools/steamship_image_generation/tool.py
new file mode 100644
index 00000000000..145f362785a
--- /dev/null
+++ b/libs/community/langchain_community/tools/steamship_image_generation/tool.py
@@ -0,0 +1,113 @@
+"""This tool allows agents to generate images using Steamship.
+
+Steamship offers access to different third party image generation APIs
+using a single API key.
+
+Today the following models are supported:
+- Dall-E
+- Stable Diffusion
+
+To use this tool, you must first set the following environment variable:
+    STEAMSHIP_API_KEY
+"""
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.tools import BaseTool
+from langchain_community.tools.steamship_image_generation.utils import make_image_public
+
+if TYPE_CHECKING:
+ from steamship import Steamship
+
+
+class ModelName(str, Enum):
+ """Supported Image Models for generation."""
+
+ DALL_E = "dall-e"
+ STABLE_DIFFUSION = "stable-diffusion"
+
+
+SUPPORTED_IMAGE_SIZES = {
+ ModelName.DALL_E: ("256x256", "512x512", "1024x1024"),
+ ModelName.STABLE_DIFFUSION: ("512x512", "768x768"),
+}
+
+
+class SteamshipImageGenerationTool(BaseTool):
+    """Tool used to generate images from a text prompt."""
+
+ model_name: ModelName
+ size: Optional[str] = "512x512"
+ steamship: Steamship
+ return_urls: Optional[bool] = False
+
+ name: str = "GenerateImage"
+ description: str = (
+ "Useful for when you need to generate an image."
+ "Input: A detailed text-2-image prompt describing an image"
+ "Output: the UUID of a generated image"
+ )
+
+ @root_validator(pre=True)
+ def validate_size(cls, values: Dict) -> Dict:
+ if "size" in values:
+ size = values["size"]
+ model_name = values["model_name"]
+ if size not in SUPPORTED_IMAGE_SIZES[model_name]:
+ raise RuntimeError(f"size {size} is not supported by {model_name}")
+
+ return values
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ steamship_api_key = get_from_dict_or_env(
+ values, "steamship_api_key", "STEAMSHIP_API_KEY"
+ )
+
+ try:
+ from steamship import Steamship
+ except ImportError:
+ raise ImportError(
+ "steamship is not installed. "
+ "Please install it with `pip install steamship`"
+ )
+
+ steamship = Steamship(
+ api_key=steamship_api_key,
+ )
+ values["steamship"] = steamship
+ if "steamship_api_key" in values:
+ del values["steamship_api_key"]
+
+ return values
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+
+ image_generator = self.steamship.use_plugin(
+ plugin_handle=self.model_name.value, config={"n": 1, "size": self.size}
+ )
+
+ task = image_generator.generate(text=query, append_output_to_file=True)
+ task.wait()
+ blocks = task.output.blocks
+ if len(blocks) > 0:
+ if self.return_urls:
+ return make_image_public(self.steamship, blocks[0])
+ else:
+ return blocks[0].id
+
+ raise RuntimeError(f"[{self.name}] Tool unable to generate image!")
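+
+
+# Example usage (illustrative sketch, not part of the module; assumes the
+# `steamship` package is installed and STEAMSHIP_API_KEY is set, so the
+# validators above can build the Steamship client automatically):
+#
+#     from langchain_community.tools.steamship_image_generation.tool import (
+#         ModelName,
+#         SteamshipImageGenerationTool,
+#     )
+#
+#     tool = SteamshipImageGenerationTool(model_name=ModelName.DALL_E, size="512x512")
+#     image_uuid = tool.run("A watercolor painting of a lighthouse at dusk")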
diff --git a/libs/community/langchain_community/tools/steamship_image_generation/utils.py b/libs/community/langchain_community/tools/steamship_image_generation/utils.py
new file mode 100644
index 00000000000..e89014e58c2
--- /dev/null
+++ b/libs/community/langchain_community/tools/steamship_image_generation/utils.py
@@ -0,0 +1,47 @@
+"""Steamship Utils."""
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from steamship import Block, Steamship
+
+
+def make_image_public(client: Steamship, block: Block) -> str:
+ """Upload a block to a signed URL and return the public URL."""
+ try:
+ from steamship.data.workspace import SignedUrl
+ from steamship.utils.signed_urls import upload_to_signed_url
+ except ImportError:
+ raise ImportError(
+ "The make_image_public function requires the steamship"
+ " package to be installed. Please install steamship"
+ " with `pip install --upgrade steamship`"
+ )
+
+ filepath = str(uuid.uuid4())
+ signed_url = (
+ client.get_workspace()
+ .create_signed_url(
+ SignedUrl.Request(
+ bucket=SignedUrl.Bucket.PLUGIN_DATA,
+ filepath=filepath,
+ operation=SignedUrl.Operation.WRITE,
+ )
+ )
+ .signed_url
+ )
+ read_signed_url = (
+ client.get_workspace()
+ .create_signed_url(
+ SignedUrl.Request(
+ bucket=SignedUrl.Bucket.PLUGIN_DATA,
+ filepath=filepath,
+ operation=SignedUrl.Operation.READ,
+ )
+ )
+ .signed_url
+ )
+ upload_to_signed_url(signed_url, block.raw())
+ return read_signed_url
diff --git a/libs/community/langchain_community/tools/tavily_search/__init__.py b/libs/community/langchain_community/tools/tavily_search/__init__.py
new file mode 100644
index 00000000000..7c8ad700997
--- /dev/null
+++ b/libs/community/langchain_community/tools/tavily_search/__init__.py
@@ -0,0 +1,8 @@
+"""Tavily Search API toolkit."""
+
+from langchain_community.tools.tavily_search.tool import (
+ TavilyAnswer,
+ TavilySearchResults,
+)
+
+__all__ = ["TavilySearchResults", "TavilyAnswer"]
diff --git a/libs/community/langchain_community/tools/tavily_search/tool.py b/libs/community/langchain_community/tools/tavily_search/tool.py
new file mode 100644
index 00000000000..af33cacc4f7
--- /dev/null
+++ b/libs/community/langchain_community/tools/tavily_search/tool.py
@@ -0,0 +1,105 @@
+"""Tool for the Tavily search API."""
+
+from typing import Dict, List, Optional, Type, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
+
+
+class TavilyInput(BaseModel):
+ query: str = Field(description="search query to look up")
+
+
+class TavilySearchResults(BaseTool):
+ """Tool that queries the Tavily Search API and gets back json."""
+
+ name: str = "tavily_search_results_json"
+ description: str = (
+ "A search engine optimized for comprehensive, accurate, and trusted results. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query."
+ )
+ api_wrapper: TavilySearchAPIWrapper
+ max_results: int = 5
+ args_schema: Type[BaseModel] = TavilyInput
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool."""
+ try:
+ return self.api_wrapper.results(
+ query,
+ self.max_results,
+ )
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool asynchronously."""
+ try:
+ return await self.api_wrapper.results_async(
+ query,
+ self.max_results,
+ )
+ except Exception as e:
+ return repr(e)
+
+
+class TavilyAnswer(BaseTool):
+ """Tool that queries the Tavily Search API and gets back an answer."""
+
+ name: str = "tavily_answer"
+ description: str = (
+ "A search engine optimized for comprehensive, accurate, and trusted results. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query. "
+ "This returns only the answer - not the original source data."
+ )
+ api_wrapper: TavilySearchAPIWrapper
+ args_schema: Type[BaseModel] = TavilyInput
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool."""
+ try:
+ return self.api_wrapper.raw_results(
+ query,
+ max_results=5,
+ include_answer=True,
+ search_depth="basic",
+ )["answer"]
+ except Exception as e:
+ return repr(e)
+
+ async def _arun(
+ self,
+ query: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> Union[List[Dict], str]:
+ """Use the tool asynchronously."""
+ try:
+ result = await self.api_wrapper.raw_results_async(
+ query,
+ max_results=5,
+ include_answer=True,
+ search_depth="basic",
+ )
+ return result["answer"]
+ except Exception as e:
+ return repr(e)
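+
+
+# Example usage (illustrative sketch, not part of the module; assumes
+# TAVILY_API_KEY is set in the environment so TavilySearchAPIWrapper can
+# authenticate):
+#
+#     from langchain_community.tools.tavily_search import TavilySearchResults
+#     from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
+#
+#     tool = TavilySearchResults(api_wrapper=TavilySearchAPIWrapper(), max_results=3)
+#     results = tool.run("latest developments in battery technology")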
diff --git a/libs/community/langchain_community/tools/vectorstore/__init__.py b/libs/community/langchain_community/tools/vectorstore/__init__.py
new file mode 100644
index 00000000000..2bb63810195
--- /dev/null
+++ b/libs/community/langchain_community/tools/vectorstore/__init__.py
@@ -0,0 +1 @@
+"""Simple tool wrapper around VectorDBQA chain."""
diff --git a/libs/community/langchain_community/tools/vectorstore/tool.py b/libs/community/langchain_community/tools/vectorstore/tool.py
new file mode 100644
index 00000000000..e51deaeaa1d
--- /dev/null
+++ b/libs/community/langchain_community/tools/vectorstore/tool.py
@@ -0,0 +1,95 @@
+"""Tools for interacting with vectorstores."""
+
+import json
+from typing import Any, Dict, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.llms.openai import OpenAI
+
+
+class BaseVectorStoreTool(BaseModel):
+ """Base class for tools that use a VectorStore."""
+
+ vectorstore: VectorStore = Field(exclude=True)
+ llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
+
+ class Config(BaseTool.Config):
+ pass
+
+
+def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
+ values["description"] = values["template"].format(name=values["name"])
+ return values
+
+
+class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
+ """Tool for the VectorDBQA chain. To be initialized with name and chain."""
+
+ @staticmethod
+ def get_description(name: str, description: str) -> str:
+ template: str = (
+ "Useful for when you need to answer questions about {name}. "
+ "Whenever you need information about {description} "
+ "you should ALWAYS use this. "
+ "Input should be a fully formed question."
+ )
+ return template.format(name=name, description=description)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ from langchain.chains.retrieval_qa.base import RetrievalQA
+
+ chain = RetrievalQA.from_chain_type(
+ self.llm, retriever=self.vectorstore.as_retriever()
+ )
+ return chain.run(
+ query, callbacks=run_manager.get_child() if run_manager else None
+ )
+
+
+class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
+ """Tool for the VectorDBQAWithSources chain."""
+
+ @staticmethod
+ def get_description(name: str, description: str) -> str:
+ template: str = (
+ "Useful for when you need to answer questions about {name} and the sources "
+ "used to construct the answer. "
+ "Whenever you need information about {description} "
+ "you should ALWAYS use this. "
+ " Input should be a fully formed question. "
+ "Output is a json serialized dictionary with keys `answer` and `sources`. "
+ "Only use this tool if the user explicitly asks for sources."
+ )
+ return template.format(name=name, description=description)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+
+ from langchain.chains.qa_with_sources.retrieval import (
+ RetrievalQAWithSourcesChain,
+ )
+
+ chain = RetrievalQAWithSourcesChain.from_chain_type(
+ self.llm, retriever=self.vectorstore.as_retriever()
+ )
+ return json.dumps(
+ chain(
+ {chain.question_key: query},
+ return_only_outputs=True,
+ callbacks=run_manager.get_child() if run_manager else None,
+ )
+ )
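+
+
+# Example usage (illustrative sketch, not part of the module; assumes an already
+# populated `vectorstore` instance and an OpenAI API key for the default LLM):
+#
+#     description = VectorStoreQATool.get_description(
+#         "company_docs", "internal company documentation"
+#     )
+#     tool = VectorStoreQATool(
+#         name="company_docs",
+#         description=description,
+#         vectorstore=vectorstore,
+#     )
+#     answer = tool.run("What is the refund policy?")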
diff --git a/libs/community/langchain_community/tools/wikipedia/__init__.py b/libs/community/langchain_community/tools/wikipedia/__init__.py
new file mode 100644
index 00000000000..0b3edd08387
--- /dev/null
+++ b/libs/community/langchain_community/tools/wikipedia/__init__.py
@@ -0,0 +1 @@
+"""Wikipedia API toolkit."""
diff --git a/libs/community/langchain_community/tools/wikipedia/tool.py b/libs/community/langchain_community/tools/wikipedia/tool.py
new file mode 100644
index 00000000000..0ccab574f23
--- /dev/null
+++ b/libs/community/langchain_community/tools/wikipedia/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Wikipedia API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+
+
+class WikipediaQueryRun(BaseTool):
+ """Tool that searches the Wikipedia API."""
+
+ name: str = "Wikipedia"
+ description: str = (
+ "A wrapper around Wikipedia. "
+ "Useful for when you need to answer general questions about "
+ "people, places, companies, facts, historical events, or other subjects. "
+ "Input should be a search query."
+ )
+ api_wrapper: WikipediaAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Wikipedia tool."""
+ return self.api_wrapper.run(query)
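+
+
+# Example usage (illustrative sketch, not part of the module; assumes the
+# `wikipedia` package required by WikipediaAPIWrapper is installed):
+#
+#     from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
+#     from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+#
+#     tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
+#     print(tool.run("Ada Lovelace"))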
diff --git a/libs/community/langchain_community/tools/wolfram_alpha/__init__.py b/libs/community/langchain_community/tools/wolfram_alpha/__init__.py
new file mode 100644
index 00000000000..f5ec860d0ac
--- /dev/null
+++ b/libs/community/langchain_community/tools/wolfram_alpha/__init__.py
@@ -0,0 +1,8 @@
+"""Wolfram Alpha API toolkit."""
+
+
+from langchain_community.tools.wolfram_alpha.tool import WolframAlphaQueryRun
+
+__all__ = [
+ "WolframAlphaQueryRun",
+]
diff --git a/libs/community/langchain_community/tools/wolfram_alpha/tool.py b/libs/community/langchain_community/tools/wolfram_alpha/tool.py
new file mode 100644
index 00000000000..e4364e669a0
--- /dev/null
+++ b/libs/community/langchain_community/tools/wolfram_alpha/tool.py
@@ -0,0 +1,29 @@
+"""Tool for the Wolfram Alpha API."""
+
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.tools import BaseTool
+
+from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper
+
+
+class WolframAlphaQueryRun(BaseTool):
+ """Tool that queries using the Wolfram Alpha SDK."""
+
+ name: str = "wolfram_alpha"
+ description: str = (
+ "A wrapper around Wolfram Alpha. "
+ "Useful for when you need to answer questions about Math, "
+ "Science, Technology, Culture, Society and Everyday Life. "
+ "Input should be a search query."
+ )
+ api_wrapper: WolframAlphaAPIWrapper
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the WolframAlpha tool."""
+ return self.api_wrapper.run(query)
diff --git a/libs/community/langchain_community/tools/yahoo_finance_news.py b/libs/community/langchain_community/tools/yahoo_finance_news.py
new file mode 100644
index 00000000000..c470fa01d2b
--- /dev/null
+++ b/libs/community/langchain_community/tools/yahoo_finance_news.py
@@ -0,0 +1,67 @@
+from typing import Iterable, Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.documents import Document
+from langchain_core.tools import BaseTool
+from requests.exceptions import HTTPError, ReadTimeout
+from urllib3.exceptions import ConnectionError
+
+from langchain_community.document_loaders.web_base import WebBaseLoader
+
+
+class YahooFinanceNewsTool(BaseTool):
+ """Tool that searches financial news on Yahoo Finance."""
+
+ name: str = "yahoo_finance_news"
+ description: str = (
+ "Useful for when you need to find financial news "
+ "about a public company. "
+ "Input should be a company ticker. "
+ "For example, AAPL for Apple, MSFT for Microsoft."
+ )
+ top_k: int = 10
+ """The number of results to return."""
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Yahoo Finance News tool."""
+ try:
+ import yfinance
+ except ImportError:
+ raise ImportError(
+ "Could not import yfinance python package. "
+ "Please install it with `pip install yfinance`."
+ )
+ company = yfinance.Ticker(query)
+ try:
+ if company.isin is None:
+ return f"Company ticker {query} not found."
+ except (HTTPError, ReadTimeout, ConnectionError):
+ return f"Company ticker {query} not found."
+
+ links = []
+ try:
+ links = [n["link"] for n in company.news if n["type"] == "STORY"]
+ except (HTTPError, ReadTimeout, ConnectionError):
+ if not links:
+ return f"No news found for company that searched with {query} ticker."
+ if not links:
+ return f"No news found for company that searched with {query} ticker."
+ loader = WebBaseLoader(web_paths=links)
+ docs = loader.load()
+ result = self._format_results(docs, query)
+ if not result:
+ return f"No news found for company that searched with {query} ticker."
+ return result
+
+ @staticmethod
+ def _format_results(docs: Iterable[Document], query: str) -> str:
+ doc_strings = [
+ "\n".join([doc.metadata["title"], doc.metadata["description"]])
+ for doc in docs
+ if query in doc.metadata["description"] or query in doc.metadata["title"]
+ ]
+ return "\n\n".join(doc_strings)
diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/__init__.py b/libs/community/langchain_community/tools/youtube/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/__init__.py
rename to libs/community/langchain_community/tools/youtube/__init__.py
diff --git a/libs/community/langchain_community/tools/youtube/search.py b/libs/community/langchain_community/tools/youtube/search.py
new file mode 100644
index 00000000000..580815c3922
--- /dev/null
+++ b/libs/community/langchain_community/tools/youtube/search.py
@@ -0,0 +1,53 @@
+"""
+Adapted from https://github.com/venuv/langchain_yt_tools
+
+YouTubeSearchTool searches YouTube videos related to a person
+and returns a specified number of video URLs.
+Input to this tool should be a comma-separated list:
+ - the first part contains a person name
+ - the second (optional) part is the maximum number
+   of video results to return
+"""
+import json
+from typing import Optional
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+
+from langchain_community.tools import BaseTool
+
+
+class YouTubeSearchTool(BaseTool):
+ """Tool that queries YouTube."""
+
+ name: str = "youtube_search"
+ description: str = (
+ "search for youtube videos associated with a person. "
+ "the input to this tool should be a comma separated list, "
+ "the first part contains a person name and the second a "
+ "number that is the maximum number of video results "
+ "to return aka num_results. the second part is optional"
+ )
+
+ def _search(self, person: str, num_results: int) -> str:
+ from youtube_search import YoutubeSearch
+
+ results = YoutubeSearch(person, num_results).to_json()
+ data = json.loads(results)
+ url_suffix_list = [
+ "https://www.youtube.com" + video["url_suffix"] for video in data["videos"]
+ ]
+ return str(url_suffix_list)
+
+ def _run(
+ self,
+ query: str,
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the tool."""
+ values = query.split(",")
+ person = values[0]
+ if len(values) > 1:
+ num_results = int(values[1])
+ else:
+ num_results = 2
+ return self._search(person, num_results)
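+
+
+# Example usage (illustrative sketch, not part of the module; assumes the
+# `youtube_search` package is installed):
+#
+#     from langchain_community.tools.youtube.search import YouTubeSearchTool
+#
+#     tool = YouTubeSearchTool()
+#     # person name, followed by an optional maximum number of results
+#     print(tool.run("lex fridman,3"))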
diff --git a/libs/community/langchain_community/tools/zapier/__init__.py b/libs/community/langchain_community/tools/zapier/__init__.py
new file mode 100644
index 00000000000..d7f2c588844
--- /dev/null
+++ b/libs/community/langchain_community/tools/zapier/__init__.py
@@ -0,0 +1,11 @@
+"""Zapier Tool."""
+
+from langchain_community.tools.zapier.tool import (
+ ZapierNLAListActions,
+ ZapierNLARunAction,
+)
+
+__all__ = [
+ "ZapierNLARunAction",
+ "ZapierNLAListActions",
+]
diff --git a/libs/community/langchain_community/tools/zapier/prompt.py b/libs/community/langchain_community/tools/zapier/prompt.py
new file mode 100644
index 00000000000..063e3952ef2
--- /dev/null
+++ b/libs/community/langchain_community/tools/zapier/prompt.py
@@ -0,0 +1,15 @@
+# flake8: noqa
+BASE_ZAPIER_TOOL_PROMPT = (
+ "A wrapper around Zapier NLA actions. "
+ "The input to this tool is a natural language instruction, "
+ 'for example "get the latest email from my bank" or '
+ '"send a slack message to the #general channel". '
+ "Each tool will have params associated with it that are specified as a list. You MUST take into account the params when creating the instruction. "
+ "For example, if the params are ['Message_Text', 'Channel'], your instruction should be something like 'send a slack message to the #general channel with the text hello world'. "
+ "Another example: if the params are ['Calendar', 'Search_Term'], your instruction should be something like 'find the meeting in my personal calendar at 3pm'. "
+ "Do not make up params, they will be explicitly specified in the tool description. "
+ "If you do not have enough information to fill in the params, just say 'not enough information provided in the instruction, missing '. "
+ "If you get a none or null response, STOP EXECUTION, do not try to another tool!"
+ "This tool specifically used for: {zapier_description}, "
+ "and has params: {params}"
+)
diff --git a/libs/community/langchain_community/tools/zapier/tool.py b/libs/community/langchain_community/tools/zapier/tool.py
new file mode 100644
index 00000000000..3d5f3955546
--- /dev/null
+++ b/libs/community/langchain_community/tools/zapier/tool.py
@@ -0,0 +1,215 @@
+"""[DEPRECATED]
+
+## Zapier Natural Language Actions API
+
+Full docs here: https://nla.zapier.com/start/
+
+**Zapier Natural Language Actions** gives you access to the 5k+ apps, 20k+ actions
+on Zapier's platform through a natural language API interface.
+
+NLA supports apps like Gmail, Salesforce, Trello, Slack, Asana, HubSpot, Google Sheets,
+Microsoft Teams, and thousands more apps: https://zapier.com/apps
+
+Zapier NLA handles ALL the underlying API auth and translation from
+natural language --> underlying API call --> return simplified output for LLMs
+The key idea is you, or your users, expose a set of actions via an oauth-like setup
+window, which you can then query and execute via a REST API.
+
+NLA offers both API Key and OAuth for signing NLA API requests.
+
+1. Server-side (API Key): for quickly getting started, testing, and production scenarios
+ where LangChain will only use actions exposed in the developer's Zapier account
+ (and will use the developer's connected accounts on Zapier.com)
+
+2. User-facing (Oauth): for production scenarios where you are deploying an end-user
+ facing application and LangChain needs access to end-user's exposed actions and
+ connected accounts on Zapier.com
+
+This quick start will focus on the server-side use case for brevity.
+Review [full docs](https://nla.zapier.com/start/) for user-facing oauth developer
+support.
+
+Typically, you'd use SequentialChain, here's a basic example:
+
+ 1. Use NLA to find an email in Gmail
+ 2. Use LLMChain to generate a draft reply to (1)
+ 3. Use NLA to send the draft reply (2) to someone in Slack via direct message
+
+In code, below:
+
+```python
+
+import os
+
+# get from https://platform.openai.com/
+os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "")
+
+# get from https://nla.zapier.com/docs/authentication/
+os.environ["ZAPIER_NLA_API_KEY"] = os.environ.get("ZAPIER_NLA_API_KEY", "")
+
+from langchain_community.agent_toolkits import ZapierToolkit
+from langchain_community.utilities.zapier import ZapierNLAWrapper
+
+## step 0. expose gmail 'find email' and slack 'send channel message' actions
+
+# first go here, log in, expose (enable) the two actions:
+# https://nla.zapier.com/demo/start
+# -- for this example, can leave all fields "Have AI guess"
+# in an oauth scenario, you'd get your own id (instead of 'demo')
+# which you route your users through first
+
+zapier = ZapierNLAWrapper()
+## To leverage OAuth you may pass the value `nla_oauth_access_token` to
+## the ZapierNLAWrapper. If you do this there is no need to initialize
+## the ZAPIER_NLA_API_KEY env variable
+# zapier = ZapierNLAWrapper(zapier_nla_oauth_access_token="TOKEN_HERE")
+toolkit = ZapierToolkit.from_zapier_nla_wrapper(zapier)
+```
+
+"""
+from typing import Any, Dict, Optional
+
+from langchain_core._api import warn_deprecated
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForToolRun,
+ CallbackManagerForToolRun,
+)
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.tools import BaseTool
+
+from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
+from langchain_community.utilities.zapier import ZapierNLAWrapper
+
+
+class ZapierNLARunAction(BaseTool):
+ """
+ Args:
+ action_id: a specific action ID (from list actions) of the action to execute
+ (the set api_key must be associated with the action owner)
+ instructions: a natural language instruction string for using the action
+ (eg. "get the latest email from Mike Knoop" for "Gmail: find email" action)
+ params: a dict, optional. Any params provided will *override* AI guesses
+ from `instructions` (see "understanding the AI guessing flow" here:
+ https://nla.zapier.com/docs/using-the-api#ai-guessing)
+
+ """
+
+ api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
+ action_id: str
+ params: Optional[dict] = None
+ base_prompt: str = BASE_ZAPIER_TOOL_PROMPT
+ zapier_description: str
+ params_schema: Dict[str, str] = Field(default_factory=dict)
+ name: str = ""
+ description: str = ""
+
+ @root_validator
+ def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ zapier_description = values["zapier_description"]
+ params_schema = values["params_schema"]
+ if "instructions" in params_schema:
+ del params_schema["instructions"]
+
+ # Ensure base prompt (if overridden) contains necessary input fields
+ necessary_fields = {"{zapier_description}", "{params}"}
+ if not all(field in values["base_prompt"] for field in necessary_fields):
+ raise ValueError(
+ "Your custom base Zapier prompt must contain input fields for "
+ "{zapier_description} and {params}."
+ )
+
+ values["name"] = zapier_description
+ values["description"] = values["base_prompt"].format(
+ zapier_description=zapier_description,
+ params=str(list(params_schema.keys())),
+ )
+ return values
+
+ def _run(
+ self, instructions: str, run_manager: Optional[CallbackManagerForToolRun] = None
+ ) -> str:
+ """Use the Zapier NLA tool to return a list of all exposed user actions."""
+ warn_deprecated(
+ since="0.0.319",
+ message=(
+ "This tool will be deprecated on 2023-11-17. See "
+ "https://nla.zapier.com/sunset/ for details"
+ ),
+ )
+ return self.api_wrapper.run_as_str(self.action_id, instructions, self.params)
+
+ async def _arun(
+ self,
+ instructions: str,
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Zapier NLA tool to return a list of all exposed user actions."""
+ warn_deprecated(
+ since="0.0.319",
+ message=(
+ "This tool will be deprecated on 2023-11-17. See "
+ "https://nla.zapier.com/sunset/ for details"
+ ),
+ )
+ return await self.api_wrapper.arun_as_str(
+ self.action_id,
+ instructions,
+ self.params,
+ )
+
+
+ZapierNLARunAction.__doc__ = (
+ ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore
+)
+
+
+# other useful actions
+
+
+class ZapierNLAListActions(BaseTool):
+ """
+ Args:
+ None
+
+ """
+
+ name: str = "ZapierNLA_list_actions"
+ description: str = BASE_ZAPIER_TOOL_PROMPT + (
+ "This tool returns a list of the user's exposed actions."
+ )
+ api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
+
+ def _run(
+ self,
+ _: str = "",
+ run_manager: Optional[CallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Zapier NLA tool to return a list of all exposed user actions."""
+ warn_deprecated(
+ since="0.0.319",
+ message=(
+ "This tool will be deprecated on 2023-11-17. See "
+ "https://nla.zapier.com/sunset/ for details"
+ ),
+ )
+ return self.api_wrapper.list_as_str()
+
+ async def _arun(
+ self,
+ _: str = "",
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+ ) -> str:
+ """Use the Zapier NLA tool to return a list of all exposed user actions."""
+ warn_deprecated(
+ since="0.0.319",
+ message=(
+ "This tool will be deprecated on 2023-11-17. See "
+ "https://nla.zapier.com/sunset/ for details"
+ ),
+ )
+ return await self.api_wrapper.alist_as_str()
+
+
+ZapierNLAListActions.__doc__ = (
+ ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore
+)
diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py
new file mode 100644
index 00000000000..ac91e964d2d
--- /dev/null
+++ b/libs/community/langchain_community/utilities/__init__.py
@@ -0,0 +1,413 @@
+"""**Utilities** are the integrations with third-part systems and packages.
+
+Other LangChain classes use **Utilities** to interact with third-part systems
+and packages.
+"""
+from typing import Any
+
+from langchain_community.utilities.requests import (
+ Requests,
+ RequestsWrapper,
+ TextRequestsWrapper,
+)
+
+
+def _import_alpha_vantage() -> Any:
+ from langchain_community.utilities.alpha_vantage import AlphaVantageAPIWrapper
+
+ return AlphaVantageAPIWrapper
+
+
+def _import_apify() -> Any:
+ from langchain_community.utilities.apify import ApifyWrapper
+
+ return ApifyWrapper
+
+
+def _import_arcee() -> Any:
+ from langchain_community.utilities.arcee import ArceeWrapper
+
+ return ArceeWrapper
+
+
+def _import_arxiv() -> Any:
+ from langchain_community.utilities.arxiv import ArxivAPIWrapper
+
+ return ArxivAPIWrapper
+
+
+def _import_awslambda() -> Any:
+ from langchain_community.utilities.awslambda import LambdaWrapper
+
+ return LambdaWrapper
+
+
+def _import_bibtex() -> Any:
+ from langchain_community.utilities.bibtex import BibtexparserWrapper
+
+ return BibtexparserWrapper
+
+
+def _import_bing_search() -> Any:
+ from langchain_community.utilities.bing_search import BingSearchAPIWrapper
+
+ return BingSearchAPIWrapper
+
+
+def _import_brave_search() -> Any:
+ from langchain_community.utilities.brave_search import BraveSearchWrapper
+
+ return BraveSearchWrapper
+
+
+def _import_duckduckgo_search() -> Any:
+ from langchain_community.utilities.duckduckgo_search import (
+ DuckDuckGoSearchAPIWrapper,
+ )
+
+ return DuckDuckGoSearchAPIWrapper
+
+
+def _import_golden_query() -> Any:
+ from langchain_community.utilities.golden_query import GoldenQueryAPIWrapper
+
+ return GoldenQueryAPIWrapper
+
+
+def _import_google_lens() -> Any:
+ from langchain_community.utilities.google_lens import GoogleLensAPIWrapper
+
+ return GoogleLensAPIWrapper
+
+
+def _import_google_places_api() -> Any:
+ from langchain_community.utilities.google_places_api import GooglePlacesAPIWrapper
+
+ return GooglePlacesAPIWrapper
+
+
+def _import_google_jobs() -> Any:
+ from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
+
+ return GoogleJobsAPIWrapper
+
+
+def _import_google_scholar() -> Any:
+ from langchain_community.utilities.google_scholar import GoogleScholarAPIWrapper
+
+ return GoogleScholarAPIWrapper
+
+
+def _import_google_trends() -> Any:
+ from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
+
+ return GoogleTrendsAPIWrapper
+
+
+def _import_google_finance() -> Any:
+ from langchain_community.utilities.google_finance import GoogleFinanceAPIWrapper
+
+ return GoogleFinanceAPIWrapper
+
+
+def _import_google_search() -> Any:
+ from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
+
+ return GoogleSearchAPIWrapper
+
+
+def _import_google_serper() -> Any:
+ from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper
+
+ return GoogleSerperAPIWrapper
+
+
+def _import_graphql() -> Any:
+ from langchain_community.utilities.graphql import GraphQLAPIWrapper
+
+ return GraphQLAPIWrapper
+
+
+def _import_jira() -> Any:
+ from langchain_community.utilities.jira import JiraAPIWrapper
+
+ return JiraAPIWrapper
+
+
+def _import_max_compute() -> Any:
+ from langchain_community.utilities.max_compute import MaxComputeAPIWrapper
+
+ return MaxComputeAPIWrapper
+
+
+def _import_merriam_webster() -> Any:
+ from langchain_community.utilities.merriam_webster import MerriamWebsterAPIWrapper
+
+ return MerriamWebsterAPIWrapper
+
+
+def _import_metaphor_search() -> Any:
+ from langchain_community.utilities.metaphor_search import MetaphorSearchAPIWrapper
+
+ return MetaphorSearchAPIWrapper
+
+
+def _import_openweathermap() -> Any:
+ from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper
+
+ return OpenWeatherMapAPIWrapper
+
+
+def _import_outline() -> Any:
+ from langchain_community.utilities.outline import OutlineAPIWrapper
+
+ return OutlineAPIWrapper
+
+
+def _import_portkey() -> Any:
+ from langchain_community.utilities.portkey import Portkey
+
+ return Portkey
+
+
+def _import_powerbi() -> Any:
+ from langchain_community.utilities.powerbi import PowerBIDataset
+
+ return PowerBIDataset
+
+
+def _import_pubmed() -> Any:
+ from langchain_community.utilities.pubmed import PubMedAPIWrapper
+
+ return PubMedAPIWrapper
+
+
+def _import_python() -> Any:
+ from langchain_community.utilities.python import PythonREPL
+
+ return PythonREPL
+
+
+def _import_scenexplain() -> Any:
+ from langchain_community.utilities.scenexplain import SceneXplainAPIWrapper
+
+ return SceneXplainAPIWrapper
+
+
+def _import_searchapi() -> Any:
+ from langchain_community.utilities.searchapi import SearchApiAPIWrapper
+
+ return SearchApiAPIWrapper
+
+
+def _import_searx_search() -> Any:
+ from langchain_community.utilities.searx_search import SearxSearchWrapper
+
+ return SearxSearchWrapper
+
+
+def _import_serpapi() -> Any:
+ from langchain_community.utilities.serpapi import SerpAPIWrapper
+
+ return SerpAPIWrapper
+
+
+def _import_spark_sql() -> Any:
+ from langchain_community.utilities.spark_sql import SparkSQL
+
+ return SparkSQL
+
+
+def _import_sql_database() -> Any:
+ from langchain_community.utilities.sql_database import SQLDatabase
+
+ return SQLDatabase
+
+
+def _import_steam_webapi() -> Any:
+ from langchain_community.utilities.steam import SteamWebAPIWrapper
+
+ return SteamWebAPIWrapper
+
+
+def _import_stackexchange() -> Any:
+ from langchain_community.utilities.stackexchange import StackExchangeAPIWrapper
+
+ return StackExchangeAPIWrapper
+
+
+def _import_tensorflow_datasets() -> Any:
+ from langchain_community.utilities.tensorflow_datasets import TensorflowDatasets
+
+ return TensorflowDatasets
+
+
+def _import_twilio() -> Any:
+ from langchain_community.utilities.twilio import TwilioAPIWrapper
+
+ return TwilioAPIWrapper
+
+
+def _import_wikipedia() -> Any:
+ from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+
+ return WikipediaAPIWrapper
+
+
+def _import_wolfram_alpha() -> Any:
+ from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper
+
+ return WolframAlphaAPIWrapper
+
+
+def _import_zapier() -> Any:
+ from langchain_community.utilities.zapier import ZapierNLAWrapper
+
+ return ZapierNLAWrapper
+
+
+def _import_nasa() -> Any:
+ from langchain_community.utilities.nasa import NasaAPIWrapper
+
+ return NasaAPIWrapper
+
+
+def __getattr__(name: str) -> Any:
+ if name == "AlphaVantageAPIWrapper":
+ return _import_alpha_vantage()
+ elif name == "ApifyWrapper":
+ return _import_apify()
+ elif name == "ArceeWrapper":
+ return _import_arcee()
+ elif name == "ArxivAPIWrapper":
+ return _import_arxiv()
+ elif name == "LambdaWrapper":
+ return _import_awslambda()
+ elif name == "BibtexparserWrapper":
+ return _import_bibtex()
+ elif name == "BingSearchAPIWrapper":
+ return _import_bing_search()
+ elif name == "BraveSearchWrapper":
+ return _import_brave_search()
+ elif name == "DuckDuckGoSearchAPIWrapper":
+ return _import_duckduckgo_search()
+ elif name == "GoogleLensAPIWrapper":
+ return _import_google_lens()
+ elif name == "GoldenQueryAPIWrapper":
+ return _import_golden_query()
+ elif name == "GoogleJobsAPIWrapper":
+ return _import_google_jobs()
+ elif name == "GoogleScholarAPIWrapper":
+ return _import_google_scholar()
+ elif name == "GoogleFinanceAPIWrapper":
+ return _import_google_finance()
+ elif name == "GoogleTrendsAPIWrapper":
+ return _import_google_trends()
+ elif name == "GooglePlacesAPIWrapper":
+ return _import_google_places_api()
+ elif name == "GoogleSearchAPIWrapper":
+ return _import_google_search()
+ elif name == "GoogleSerperAPIWrapper":
+ return _import_google_serper()
+ elif name == "GraphQLAPIWrapper":
+ return _import_graphql()
+ elif name == "JiraAPIWrapper":
+ return _import_jira()
+ elif name == "MaxComputeAPIWrapper":
+ return _import_max_compute()
+ elif name == "MerriamWebsterAPIWrapper":
+ return _import_merriam_webster()
+ elif name == "MetaphorSearchAPIWrapper":
+ return _import_metaphor_search()
+ elif name == "NasaAPIWrapper":
+ return _import_nasa()
+ elif name == "OpenWeatherMapAPIWrapper":
+ return _import_openweathermap()
+ elif name == "OutlineAPIWrapper":
+ return _import_outline()
+ elif name == "Portkey":
+ return _import_portkey()
+ elif name == "PowerBIDataset":
+ return _import_powerbi()
+ elif name == "PubMedAPIWrapper":
+ return _import_pubmed()
+ elif name == "PythonREPL":
+ return _import_python()
+ elif name == "SceneXplainAPIWrapper":
+ return _import_scenexplain()
+ elif name == "SearchApiAPIWrapper":
+ return _import_searchapi()
+ elif name == "SearxSearchWrapper":
+ return _import_searx_search()
+ elif name == "SerpAPIWrapper":
+ return _import_serpapi()
+ elif name == "SparkSQL":
+ return _import_spark_sql()
+ elif name == "StackExchangeAPIWrapper":
+ return _import_stackexchange()
+ elif name == "SQLDatabase":
+ return _import_sql_database()
+ elif name == "SteamWebAPIWrapper":
+ return _import_steam_webapi()
+ elif name == "TensorflowDatasets":
+ return _import_tensorflow_datasets()
+ elif name == "TwilioAPIWrapper":
+ return _import_twilio()
+ elif name == "WikipediaAPIWrapper":
+ return _import_wikipedia()
+ elif name == "WolframAlphaAPIWrapper":
+ return _import_wolfram_alpha()
+ elif name == "ZapierNLAWrapper":
+ return _import_zapier()
+ else:
+ raise AttributeError(f"Could not find: {name}")
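+
+
+# A minimal sketch of how the lazy imports above behave (illustrative only):
+#
+#     import langchain_community.utilities as utils
+#
+#     # The wikipedia submodule (and its optional dependencies) is only imported
+#     # when the attribute is requested, via the module-level __getattr__ above.
+#     wrapper_cls = utils.WikipediaAPIWrapper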
+
+
+__all__ = [
+ "AlphaVantageAPIWrapper",
+ "ApifyWrapper",
+ "ArceeWrapper",
+ "ArxivAPIWrapper",
+ "BibtexparserWrapper",
+ "BingSearchAPIWrapper",
+ "BraveSearchWrapper",
+ "DuckDuckGoSearchAPIWrapper",
+ "GoldenQueryAPIWrapper",
+ "GoogleFinanceAPIWrapper",
+ "GoogleLensAPIWrapper",
+ "GoogleJobsAPIWrapper",
+ "GooglePlacesAPIWrapper",
+ "GoogleScholarAPIWrapper",
+ "GoogleTrendsAPIWrapper",
+ "GoogleSearchAPIWrapper",
+ "GoogleSerperAPIWrapper",
+ "GraphQLAPIWrapper",
+ "JiraAPIWrapper",
+ "LambdaWrapper",
+ "MaxComputeAPIWrapper",
+ "MerriamWebsterAPIWrapper",
+ "MetaphorSearchAPIWrapper",
+ "NasaAPIWrapper",
+ "OpenWeatherMapAPIWrapper",
+ "OutlineAPIWrapper",
+ "Portkey",
+ "PowerBIDataset",
+ "PubMedAPIWrapper",
+ "PythonREPL",
+ "Requests",
+ "RequestsWrapper",
+ "SteamWebAPIWrapper",
+ "SQLDatabase",
+ "SceneXplainAPIWrapper",
+ "SearchApiAPIWrapper",
+ "SearxSearchWrapper",
+ "SerpAPIWrapper",
+ "SparkSQL",
+ "StackExchangeAPIWrapper",
+ "TensorflowDatasets",
+ "TextRequestsWrapper",
+ "TwilioAPIWrapper",
+ "WikipediaAPIWrapper",
+ "WolframAlphaAPIWrapper",
+ "ZapierNLAWrapper",
+]
diff --git a/libs/community/langchain_community/utilities/alpha_vantage.py b/libs/community/langchain_community/utilities/alpha_vantage.py
new file mode 100644
index 00000000000..db51abf96da
--- /dev/null
+++ b/libs/community/langchain_community/utilities/alpha_vantage.py
@@ -0,0 +1,64 @@
+"""Util that calls AlphaVantage for Currency Exchange Rate."""
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class AlphaVantageAPIWrapper(BaseModel):
+ """Wrapper for AlphaVantage API for Currency Exchange Rate.
+
+ Docs for using:
+
+    1. Go to AlphaVantage and sign up for an API key.
+    2. Save your API key in the ALPHAVANTAGE_API_KEY env variable.
+ """
+
+ alphavantage_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ values["alphavantage_api_key"] = get_from_dict_or_env(
+ values, "alphavantage_api_key", "ALPHAVANTAGE_API_KEY"
+ )
+ return values
+
+ def _get_exchange_rate(
+ self, from_currency: str, to_currency: str
+ ) -> Dict[str, Any]:
+ """Make a request to the AlphaVantage API to get the exchange rate."""
+ response = requests.get(
+ "https://www.alphavantage.co/query/",
+ params={
+ "function": "CURRENCY_EXCHANGE_RATE",
+ "from_currency": from_currency,
+ "to_currency": to_currency,
+ "apikey": self.alphavantage_api_key,
+ },
+ )
+ response.raise_for_status()
+ data = response.json()
+
+ if "Error Message" in data:
+ raise ValueError(f"API Error: {data['Error Message']}")
+
+ return data
+
+ @property
+ def standard_currencies(self) -> List[str]:
+ return ["USD", "EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD"]
+
+ def run(self, from_currency: str, to_currency: str) -> str:
+ """Get the current exchange rate for a specified currency pair."""
+ if to_currency not in self.standard_currencies:
+ from_currency, to_currency = to_currency, from_currency
+
+ data = self._get_exchange_rate(from_currency, to_currency)
+ return data["Realtime Currency Exchange Rate"]
diff --git a/libs/community/langchain_community/utilities/anthropic.py b/libs/community/langchain_community/utilities/anthropic.py
new file mode 100644
index 00000000000..31bb4015b1d
--- /dev/null
+++ b/libs/community/langchain_community/utilities/anthropic.py
@@ -0,0 +1,27 @@
+from typing import Any, List
+
+
+def _get_anthropic_client() -> Any:
+ try:
+ import anthropic
+ except ImportError:
+ raise ImportError(
+ "Could not import anthropic python package. "
+ "This is needed in order to accurately tokenize the text "
+ "for anthropic models. Please install it with `pip install anthropic`."
+ )
+ return anthropic.Anthropic()
+
+
+def get_num_tokens_anthropic(text: str) -> int:
+ """Get the number of tokens in a string of text."""
+ client = _get_anthropic_client()
+ return client.count_tokens(text=text)
+
+
+def get_token_ids_anthropic(text: str) -> List[int]:
+ """Get the token ids for a string of text."""
+ client = _get_anthropic_client()
+ tokenizer = client.get_tokenizer()
+ encoded_text = tokenizer.encode(text)
+ return encoded_text.ids
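+
+
+# Example usage (illustrative sketch, not part of the module; assumes the
+# `anthropic` package is installed):
+#
+#     from langchain_community.utilities.anthropic import (
+#         get_num_tokens_anthropic,
+#         get_token_ids_anthropic,
+#     )
+#
+#     n_tokens = get_num_tokens_anthropic("Hello, Claude!")
+#     token_ids = get_token_ids_anthropic("Hello, Claude!")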
diff --git a/libs/community/langchain_community/utilities/apify.py b/libs/community/langchain_community/utilities/apify.py
new file mode 100644
index 00000000000..6f37f84f21e
--- /dev/null
+++ b/libs/community/langchain_community/utilities/apify.py
@@ -0,0 +1,204 @@
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+ from langchain_community.document_loaders import ApifyDatasetLoader
+
+
+class ApifyWrapper(BaseModel):
+ """Wrapper around Apify.
+ To use, you should have the ``apify-client`` python package installed,
+ and the environment variable ``APIFY_API_TOKEN`` set with your API key, or pass
+ `apify_api_token` as a named parameter to the constructor.
+ """
+
+ apify_client: Any
+ apify_client_async: Any
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate environment.
+ Validate that an Apify API token is set and the apify-client
+ Python package exists in the current environment.
+ """
+ apify_api_token = get_from_dict_or_env(
+ values, "apify_api_token", "APIFY_API_TOKEN"
+ )
+
+ try:
+ from apify_client import ApifyClient, ApifyClientAsync
+
+ values["apify_client"] = ApifyClient(apify_api_token)
+ values["apify_client_async"] = ApifyClientAsync(apify_api_token)
+ except ImportError:
+ raise ImportError(
+ "Could not import apify-client Python package. "
+ "Please install it with `pip install apify-client`."
+ )
+
+ return values
+
+ def call_actor(
+ self,
+ actor_id: str,
+ run_input: Dict,
+ dataset_mapping_function: Callable[[Dict], Document],
+ *,
+ build: Optional[str] = None,
+ memory_mbytes: Optional[int] = None,
+ timeout_secs: Optional[int] = None,
+ ) -> "ApifyDatasetLoader":
+ """Run an Actor on the Apify platform and wait for results to be ready.
+ Args:
+ actor_id (str): The ID or name of the Actor on the Apify platform.
+ run_input (Dict): The input object of the Actor that you're trying to run.
+ dataset_mapping_function (Callable): A function that takes a single
+ dictionary (an Apify dataset item) and converts it to an
+ instance of the Document class.
+ build (str, optional): Optionally specifies the actor build to run.
+ It can be either a build tag or build number.
+ memory_mbytes (int, optional): Optional memory limit for the run,
+ in megabytes.
+ timeout_secs (int, optional): Optional timeout for the run, in seconds.
+ Returns:
+ ApifyDatasetLoader: A loader that will fetch the records from the
+ Actor run's default dataset.
+ """
+ from langchain_community.document_loaders import ApifyDatasetLoader
+
+ actor_call = self.apify_client.actor(actor_id).call(
+ run_input=run_input,
+ build=build,
+ memory_mbytes=memory_mbytes,
+ timeout_secs=timeout_secs,
+ )
+
+ return ApifyDatasetLoader(
+ dataset_id=actor_call["defaultDatasetId"],
+ dataset_mapping_function=dataset_mapping_function,
+ )
+
+ async def acall_actor(
+ self,
+ actor_id: str,
+ run_input: Dict,
+ dataset_mapping_function: Callable[[Dict], Document],
+ *,
+ build: Optional[str] = None,
+ memory_mbytes: Optional[int] = None,
+ timeout_secs: Optional[int] = None,
+ ) -> "ApifyDatasetLoader":
+ """Run an Actor on the Apify platform and wait for results to be ready.
+ Args:
+ actor_id (str): The ID or name of the Actor on the Apify platform.
+ run_input (Dict): The input object of the Actor that you're trying to run.
+ dataset_mapping_function (Callable): A function that takes a single
+ dictionary (an Apify dataset item) and converts it to
+ an instance of the Document class.
+ build (str, optional): Optionally specifies the actor build to run.
+ It can be either a build tag or build number.
+ memory_mbytes (int, optional): Optional memory limit for the run,
+ in megabytes.
+ timeout_secs (int, optional): Optional timeout for the run, in seconds.
+ Returns:
+ ApifyDatasetLoader: A loader that will fetch the records from the
+ Actor run's default dataset.
+ """
+ from langchain_community.document_loaders import ApifyDatasetLoader
+
+ actor_call = await self.apify_client_async.actor(actor_id).call(
+ run_input=run_input,
+ build=build,
+ memory_mbytes=memory_mbytes,
+ timeout_secs=timeout_secs,
+ )
+
+ return ApifyDatasetLoader(
+ dataset_id=actor_call["defaultDatasetId"],
+ dataset_mapping_function=dataset_mapping_function,
+ )
+
+ def call_actor_task(
+ self,
+ task_id: str,
+ task_input: Dict,
+ dataset_mapping_function: Callable[[Dict], Document],
+ *,
+ build: Optional[str] = None,
+ memory_mbytes: Optional[int] = None,
+ timeout_secs: Optional[int] = None,
+ ) -> "ApifyDatasetLoader":
+ """Run a saved Actor task on Apify and wait for results to be ready.
+ Args:
+ task_id (str): The ID or name of the task on the Apify platform.
+ task_input (Dict): The input object of the task that you're trying to run.
+ Overrides the task's saved input.
+ dataset_mapping_function (Callable): A function that takes a single
+ dictionary (an Apify dataset item) and converts it to an
+ instance of the Document class.
+ build (str, optional): Optionally specifies the actor build to run.
+ It can be either a build tag or build number.
+ memory_mbytes (int, optional): Optional memory limit for the run,
+ in megabytes.
+ timeout_secs (int, optional): Optional timeout for the run, in seconds.
+ Returns:
+ ApifyDatasetLoader: A loader that will fetch the records from the
+ task run's default dataset.
+ """
+ from langchain_community.document_loaders import ApifyDatasetLoader
+
+ task_call = self.apify_client.task(task_id).call(
+ task_input=task_input,
+ build=build,
+ memory_mbytes=memory_mbytes,
+ timeout_secs=timeout_secs,
+ )
+
+ return ApifyDatasetLoader(
+ dataset_id=task_call["defaultDatasetId"],
+ dataset_mapping_function=dataset_mapping_function,
+ )
+
+ async def acall_actor_task(
+ self,
+ task_id: str,
+ task_input: Dict,
+ dataset_mapping_function: Callable[[Dict], Document],
+ *,
+ build: Optional[str] = None,
+ memory_mbytes: Optional[int] = None,
+ timeout_secs: Optional[int] = None,
+ ) -> "ApifyDatasetLoader":
+ """Run a saved Actor task on Apify and wait for results to be ready.
+ Args:
+ task_id (str): The ID or name of the task on the Apify platform.
+ task_input (Dict): The input object of the task that you're trying to run.
+ Overrides the task's saved input.
+ dataset_mapping_function (Callable): A function that takes a single
+ dictionary (an Apify dataset item) and converts it to an
+ instance of the Document class.
+ build (str, optional): Optionally specifies the actor build to run.
+ It can be either a build tag or build number.
+ memory_mbytes (int, optional): Optional memory limit for the run,
+ in megabytes.
+ timeout_secs (int, optional): Optional timeout for the run, in seconds.
+ Returns:
+ ApifyDatasetLoader: A loader that will fetch the records from the
+ task run's default dataset.
+ """
+ from langchain_community.document_loaders import ApifyDatasetLoader
+
+ task_call = await self.apify_client_async.task(task_id).call(
+ task_input=task_input,
+ build=build,
+ memory_mbytes=memory_mbytes,
+ timeout_secs=timeout_secs,
+ )
+
+ return ApifyDatasetLoader(
+ dataset_id=task_call["defaultDatasetId"],
+ dataset_mapping_function=dataset_mapping_function,
+ )
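+
+
+# Example usage (illustrative sketch, not part of the module; assumes
+# APIFY_API_TOKEN is set, `apify-client` is installed, and the referenced Actor
+# is available to your Apify account):
+#
+#     from langchain_core.documents import Document
+#
+#     from langchain_community.utilities.apify import ApifyWrapper
+#
+#     apify = ApifyWrapper()
+#     loader = apify.call_actor(
+#         actor_id="apify/website-content-crawler",
+#         run_input={"startUrls": [{"url": "https://python.langchain.com"}]},
+#         dataset_mapping_function=lambda item: Document(
+#             page_content=item.get("text", ""),
+#             metadata={"source": item.get("url")},
+#         ),
+#     )
+#     docs = loader.load()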
diff --git a/libs/community/langchain_community/utilities/arcee.py b/libs/community/langchain_community/utilities/arcee.py
new file mode 100644
index 00000000000..72170348583
--- /dev/null
+++ b/libs/community/langchain_community/utilities/arcee.py
@@ -0,0 +1,255 @@
+# This module contains utility classes and functions for interacting with Arcee API.
+# For more information and updates, refer to the Arcee utils page:
+# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Mapping, Optional, Union
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
+from langchain_core.retrievers import Document
+
+
+class ArceeRoute(str, Enum):
+ """Routes available for the Arcee API as enumerator."""
+
+ generate = "models/generate"
+ retrieve = "models/retrieve"
+ model_training_status = "models/status/{id_or_name}"
+
+
+class DALMFilterType(str, Enum):
+ """Filter types available for a DALM retrieval as enumerator."""
+
+ fuzzy_search = "fuzzy_search"
+ strict_search = "strict_search"
+
+
+class DALMFilter(BaseModel):
+ """Filters available for a DALM retrieval and generation.
+
+ Arguments:
+ field_name: The field to filter on. Can be 'document' or 'name' to filter
+ on your document's raw text or title. Any other field will be presumed
+ to be a metadata field you included when uploading your context data
+ filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
+ 'fuzzy_search' means a fuzzy search on the provided field is performed.
+            The exact string doesn't need to exist in the document
+            for this to find a match.
+            Very useful for scanning a document for some keyword terms.
+            'strict_search' means that the exact string must appear
+            in the provided field.
+            This is NOT an exact eq filter, i.e. a document with content
+ "the happy dog crossed the street" will match on a strict_search of
+ "dog" but won't match on "the dog".
+ Python equivalent of `return search_string in full_string`.
+ value: The actual value to search for in the context data/metadata
+ """
+
+ field_name: str
+ filter_type: DALMFilterType
+ value: str
+ _is_metadata: bool = False
+
+ @root_validator()
+ def set_meta(cls, values: Dict) -> Dict:
+ """document and name are reserved arcee keys. Anything else is metadata"""
+ values["_is_meta"] = values.get("field_name") not in ["document", "name"]
+ return values
+
+
+class ArceeDocumentSource(BaseModel):
+ """Source of an Arcee document."""
+
+ document: str
+ name: str
+ id: str
+
+
+class ArceeDocument(BaseModel):
+ """Arcee document."""
+
+ index: str
+ id: str
+ score: float
+ source: ArceeDocumentSource
+
+
+class ArceeDocumentAdapter:
+ """Adapter for Arcee documents"""
+
+ @classmethod
+ def adapt(cls, arcee_document: ArceeDocument) -> Document:
+ """Adapts an `ArceeDocument` to a langchain's `Document` object."""
+ return Document(
+ page_content=arcee_document.source.document,
+ metadata={
+ # arcee document; source metadata
+ "name": arcee_document.source.name,
+ "source_id": arcee_document.source.id,
+ # arcee document metadata
+ "index": arcee_document.index,
+ "id": arcee_document.id,
+ "score": arcee_document.score,
+ },
+ )
+
+
+class ArceeWrapper:
+ """Wrapper for Arcee API.
+
+ For more details, see: https://www.arcee.ai/
+ """
+
+ def __init__(
+ self,
+ arcee_api_key: Union[str, SecretStr],
+ arcee_api_url: str,
+ arcee_api_version: str,
+ model_kwargs: Optional[Dict[str, Any]],
+ model_name: str,
+ ):
+ """Initialize ArceeWrapper.
+
+ Arguments:
+ arcee_api_key: API key for Arcee API.
+ arcee_api_url: URL for Arcee API.
+ arcee_api_version: Version of Arcee API.
+ model_kwargs: Keyword arguments for Arcee API.
+ model_name: Name of an Arcee model.
+ """
+ if isinstance(arcee_api_key, str):
+ arcee_api_key_ = SecretStr(arcee_api_key)
+ else:
+ arcee_api_key_ = arcee_api_key
+ self.arcee_api_key: SecretStr = arcee_api_key_
+ self.model_kwargs = model_kwargs
+ self.arcee_api_url = arcee_api_url
+ self.arcee_api_version = arcee_api_version
+
+ try:
+ route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
+ response = self._make_request("get", route)
+ self.model_id = response.get("model_id")
+ self.model_training_status = response.get("status")
+ except Exception as e:
+ raise ValueError(
+ f"Error while validating model training status for '{model_name}': {e}"
+ ) from e
+
+ def validate_model_training_status(self) -> None:
+ if self.model_training_status != "training_complete":
+ raise Exception(
+ f"Model {self.model_id} is not ready. "
+ "Please wait for training to complete."
+ )
+
+ def _make_request(
+ self,
+ method: Literal["post", "get"],
+ route: Union[ArceeRoute, str],
+ body: Optional[Mapping[str, Any]] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ ) -> dict:
+ """Make a request to the Arcee API
+ Args:
+ method: The HTTP method to use
+ route: The route to call
+ body: The body of the request
+ params: The query params of the request
+ headers: The headers of the request
+ """
+ headers = self._make_request_headers(headers=headers)
+ url = self._make_request_url(route=route)
+
+ req_type = getattr(requests, method)
+
+ response = req_type(url, json=body, params=params, headers=headers)
+ if response.status_code not in (200, 201):
+ raise Exception(f"Failed to make request. Response: {response.text}")
+ return response.json()
+
+ def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
+ headers = headers or {}
+ if not isinstance(self.arcee_api_key, SecretStr):
+ raise TypeError(
+ f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
+ )
+ api_key = self.arcee_api_key.get_secret_value()
+ internal_headers = {
+ "X-Token": api_key,
+ "Content-Type": "application/json",
+ }
+ headers.update(internal_headers)
+ return headers
+
+ def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
+ return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"
+
+ def _make_request_body_for_models(
+ self, prompt: str, **kwargs: Mapping[str, Any]
+ ) -> Mapping[str, Any]:
+ """Make the request body for generate/retrieve models endpoint"""
+ _model_kwargs = self.model_kwargs or {}
+ _params = {**_model_kwargs, **kwargs}
+
+ filters = [DALMFilter(**f) for f in _params.get("filters", [])]
+ return dict(
+ model_id=self.model_id,
+ query=prompt,
+ size=_params.get("size", 3),
+ filters=filters,
+ id=self.model_id,
+ )
+
+ def generate(
+ self,
+ prompt: str,
+ **kwargs: Any,
+ ) -> str:
+ """Generate text from Arcee DALM.
+
+ Args:
+ prompt: Prompt to generate text from.
+ size: The max number of context results to retrieve. Defaults to 3.
+ (Can be less if filters are provided).
+ filters: Filters to apply to the context dataset.
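+
+        Example (illustrative; the filter dict sketches a `DALMFilter`):
+            .. code-block:: python
+
+                text = wrapper.generate(
+                    prompt="Tell me about Einstein",
+                    size=5,
+                    filters=[
+                        {
+                            "field_name": "document",
+                            "filter_type": "strict_search",
+                            "value": "Einstein",
+                        }
+                    ],
+                )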
+ """
+
+ response = self._make_request(
+ method="post",
+ route=ArceeRoute.generate.value,
+ body=self._make_request_body_for_models(
+ prompt=prompt,
+ **kwargs,
+ ),
+ )
+ return response["text"]
+
+ def retrieve(
+ self,
+ query: str,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Retrieve {size} contexts with your retriever for a given query
+
+ Args:
+ query: Query to submit to the model
+ size: The max number of context results to retrieve. Defaults to 3.
+ (Can be less if filters are provided).
+ filters: Filters to apply to the context dataset.
+ """
+
+ response = self._make_request(
+ method="post",
+ route=ArceeRoute.retrieve.value,
+ body=self._make_request_body_for_models(
+ prompt=query,
+ **kwargs,
+ ),
+ )
+ return [
+ ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
+ for doc in response["results"]
+ ]
diff --git a/libs/community/langchain_community/utilities/arxiv.py b/libs/community/langchain_community/utilities/arxiv.py
new file mode 100644
index 00000000000..5a728f5e2c8
--- /dev/null
+++ b/libs/community/langchain_community/utilities/arxiv.py
@@ -0,0 +1,238 @@
+"""Util that calls Arxiv."""
+import logging
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class ArxivAPIWrapper(BaseModel):
+ """Wrapper around ArxivAPI.
+
+ To use, you should have the ``arxiv`` python package installed.
+ https://lukasschwab.me/arxiv.py/index.html
+ This wrapper will use the Arxiv API to conduct searches and
+ fetch document summaries. By default, it will return the document summaries
+ of the top-k results.
+ If the query is in the form of arxiv identifier
+ (see https://info.arxiv.org/help/find/index.html), it will return the paper
+ corresponding to the arxiv identifier.
+ It limits the Document content by doc_content_chars_max.
+ Set doc_content_chars_max=None if you don't want to limit the content size.
+
+ Attributes:
+        top_k_results: number of the top-scored documents used for the arxiv tool
+ ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
+ load_max_docs: a limit to the number of loaded documents
+ load_all_available_meta:
+ if True: the `metadata` of the loaded Documents contains all available
+ meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
+ if False: the `metadata` contains only the published date, title,
+ authors and summary.
+ doc_content_chars_max: an optional cut limit for the length of a document's
+ content
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities.arxiv import ArxivAPIWrapper
+ arxiv = ArxivAPIWrapper(
+ top_k_results = 3,
+ ARXIV_MAX_QUERY_LENGTH = 300,
+ load_max_docs = 3,
+ load_all_available_meta = False,
+ doc_content_chars_max = 40000
+ )
+ arxiv.run("tree of thought llm)
+ """
+
+ arxiv_search: Any #: :meta private:
+ arxiv_exceptions: Any # :meta private:
+ top_k_results: int = 3
+ ARXIV_MAX_QUERY_LENGTH: int = 300
+ load_max_docs: int = 100
+ load_all_available_meta: bool = False
+ doc_content_chars_max: Optional[int] = 4000
+
+ def is_arxiv_identifier(self, query: str) -> bool:
+ """Check if a query is an arxiv identifier."""
+ arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"
+ for query_item in query[: self.ARXIV_MAX_QUERY_LENGTH].split():
+ match_result = re.match(arxiv_identifier_pattern, query_item)
+ if not match_result:
+ return False
+ assert match_result is not None
+ if not match_result.group(0) == query_item:
+ return False
+ return True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ import arxiv
+
+ values["arxiv_search"] = arxiv.Search
+ values["arxiv_exceptions"] = (
+ arxiv.ArxivError,
+ arxiv.UnexpectedEmptyPageError,
+ arxiv.HTTPError,
+ )
+ values["arxiv_result"] = arxiv.Result
+ except ImportError:
+ raise ImportError(
+ "Could not import arxiv python package. "
+ "Please install it with `pip install arxiv`."
+ )
+ return values
+
+ def get_summaries_as_docs(self, query: str) -> List[Document]:
+ """
+ Performs an arxiv search and returns list of
+ documents, with summaries as the content.
+
+        If an error occurs or no documents are found, error text
+ is returned instead. Wrapper for
+ https://lukasschwab.me/arxiv.py/index.html#Search
+
+ Args:
+ query: a plaintext search query
+ """ # noqa: E501
+ try:
+ if self.is_arxiv_identifier(query):
+ results = self.arxiv_search(
+ id_list=query.split(),
+ max_results=self.top_k_results,
+ ).results()
+ else:
+ results = self.arxiv_search( # type: ignore
+ query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
+ ).results()
+ except self.arxiv_exceptions as ex:
+ return [Document(page_content=f"Arxiv exception: {ex}")]
+ docs = [
+ Document(
+ page_content=result.summary,
+ metadata={
+ "Published": result.updated.date(),
+ "Title": result.title,
+ "Authors": ", ".join(a.name for a in result.authors),
+ },
+ )
+ for result in results
+ ]
+ return docs
+
+ def run(self, query: str) -> str:
+ """
+        Performs an arxiv search and returns a single string
+        with the published date, title, authors, and summary
+        for each article, separated by two newlines.
+
+        If an error occurs or no documents are found, error text
+ is returned instead. Wrapper for
+ https://lukasschwab.me/arxiv.py/index.html#Search
+
+ Args:
+ query: a plaintext search query
+ """ # noqa: E501
+ try:
+ if self.is_arxiv_identifier(query):
+ results = self.arxiv_search(
+ id_list=query.split(),
+ max_results=self.top_k_results,
+ ).results()
+ else:
+ results = self.arxiv_search( # type: ignore
+ query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
+ ).results()
+ except self.arxiv_exceptions as ex:
+ return f"Arxiv exception: {ex}"
+ docs = [
+ f"Published: {result.updated.date()}\n"
+ f"Title: {result.title}\n"
+ f"Authors: {', '.join(a.name for a in result.authors)}\n"
+ f"Summary: {result.summary}"
+ for result in results
+ ]
+ if docs:
+ return "\n\n".join(docs)[: self.doc_content_chars_max]
+ else:
+ return "No good Arxiv Result was found"
+
+ def load(self, query: str) -> List[Document]:
+ """
+ Run Arxiv search and get the article texts plus the article meta information.
+ See https://lukasschwab.me/arxiv.py/index.html#Search
+
+ Returns: a list of documents with the document.page_content in text format
+
+ Performs an arxiv search, downloads the top k results as PDFs, loads
+ them as Documents, and returns them in a List.
+
+ Args:
+ query: a plaintext search query
+ """ # noqa: E501
+ try:
+ import fitz
+ except ImportError:
+ raise ImportError(
+ "PyMuPDF package not found, please install it with "
+ "`pip install pymupdf`"
+ )
+
+ try:
+ # Remove the ":" and "-" from the query, as they can cause search problems
+ query = query.replace(":", "").replace("-", "")
+ if self.is_arxiv_identifier(query):
+ results = self.arxiv_search(
+ id_list=query[: self.ARXIV_MAX_QUERY_LENGTH].split(),
+ max_results=self.load_max_docs,
+ ).results()
+ else:
+ results = self.arxiv_search( # type: ignore
+ query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
+ ).results()
+ except self.arxiv_exceptions as ex:
+ logger.debug("Error on arxiv: %s", ex)
+ return []
+
+ docs: List[Document] = []
+ for result in results:
+ try:
+ doc_file_name: str = result.download_pdf()
+ with fitz.open(doc_file_name) as doc_file:
+ text: str = "".join(page.get_text() for page in doc_file)
+ except (FileNotFoundError, fitz.fitz.FileDataError) as f_ex:
+ logger.debug(f_ex)
+ continue
+ if self.load_all_available_meta:
+ extra_metadata = {
+ "entry_id": result.entry_id,
+ "published_first_time": str(result.published.date()),
+ "comment": result.comment,
+ "journal_ref": result.journal_ref,
+ "doi": result.doi,
+ "primary_category": result.primary_category,
+ "categories": result.categories,
+ "links": [link.href for link in result.links],
+ }
+ else:
+ extra_metadata = {}
+ metadata = {
+ "Published": str(result.updated.date()),
+ "Title": result.title,
+ "Authors": ", ".join(a.name for a in result.authors),
+ "Summary": result.summary,
+ **extra_metadata,
+ }
+ doc = Document(
+ page_content=text[: self.doc_content_chars_max], metadata=metadata
+ )
+ docs.append(doc)
+ os.remove(doc_file_name)
+ return docs
diff --git a/libs/community/langchain_community/utilities/awslambda.py b/libs/community/langchain_community/utilities/awslambda.py
new file mode 100644
index 00000000000..1b497dd5dd2
--- /dev/null
+++ b/libs/community/langchain_community/utilities/awslambda.py
@@ -0,0 +1,82 @@
+"""Util that calls Lambda."""
+import json
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class LambdaWrapper(BaseModel):
+ """Wrapper for AWS Lambda SDK.
+ To use, you should have the ``boto3`` package installed
+    and a lambda function built from the AWS Console or
+    CLI. Set up your AWS credentials with ``aws configure``.
+
+ Example:
+ .. code-block:: bash
+
+ pip install boto3
+
+ aws configure
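+
+    A minimal Python sketch (the function name below is a placeholder for one of
+    your deployed Lambda functions):
+
+        .. code-block:: python
+
+            from langchain_community.utilities.awslambda import LambdaWrapper
+
+            awslambda = LambdaWrapper(function_name="my-function")  # placeholder
+            awslambda.run("Hello from LangChain")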
+
+ """
+
+ lambda_client: Any #: :meta private:
+ """The configured boto3 client"""
+ function_name: Optional[str] = None
+ """The name of your lambda function"""
+ awslambda_tool_name: Optional[str] = None
+ """If passing to an agent as a tool, the tool name"""
+ awslambda_tool_description: Optional[str] = None
+ """If passing to an agent as a tool, the description"""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ import boto3
+
+ except ImportError:
+ raise ImportError(
+ "boto3 is not installed. Please install it with `pip install boto3`"
+ )
+
+ values["lambda_client"] = boto3.client("lambda")
+ values["function_name"] = values["function_name"]
+
+ return values
+
+ def run(self, query: str) -> str:
+ """
+ Invokes the lambda function and returns the
+ result.
+
+ Args:
+            query: an input to be passed to the lambda
+ function as the ``body`` of a JSON
+ object.
+ """ # noqa: E501
+ res = self.lambda_client.invoke(
+ FunctionName=self.function_name,
+ InvocationType="RequestResponse",
+ Payload=json.dumps({"body": query}),
+ )
+
+ try:
+ payload_stream = res["Payload"]
+ payload_string = payload_stream.read().decode("utf-8")
+ answer = json.loads(payload_string)["body"]
+
+        except (KeyError, json.JSONDecodeError):  # malformed Lambda response
+ return "Failed to parse response from Lambda"
+
+ if answer is None or answer == "":
+            # Don't treat an empty answer as a successful result
+ return "Request failed."
+ else:
+ return f"Result: {answer}"
diff --git a/libs/community/langchain_community/utilities/bibtex.py b/libs/community/langchain_community/utilities/bibtex.py
new file mode 100644
index 00000000000..45d83aefea3
--- /dev/null
+++ b/libs/community/langchain_community/utilities/bibtex.py
@@ -0,0 +1,87 @@
+"""Util that calls bibtexparser."""
+import logging
+from typing import Any, Dict, List, Mapping
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+logger = logging.getLogger(__name__)
+
+OPTIONAL_FIELDS = [
+ "annotate",
+ "booktitle",
+ "editor",
+ "howpublished",
+ "journal",
+ "keywords",
+ "note",
+ "organization",
+ "publisher",
+ "school",
+ "series",
+ "type",
+ "doi",
+ "issn",
+ "isbn",
+]
+
+
+class BibtexparserWrapper(BaseModel):
+ """Wrapper around bibtexparser.
+
+ To use, you should have the ``bibtexparser`` python package installed.
+ https://bibtexparser.readthedocs.io/en/master/
+
+ This wrapper will use bibtexparser to load a collection of references from
+ a bibtex file and fetch document summaries.
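+
+    Example (an illustrative sketch; the file path is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.utilities.bibtex import BibtexparserWrapper
+
+            bibtex = BibtexparserWrapper()
+            entries = bibtex.load_bibtex_entries("references.bib")  # placeholder
+            docs_meta = [bibtex.get_metadata(entry) for entry in entries]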
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ import bibtexparser # noqa
+ except ImportError:
+ raise ImportError(
+ "Could not import bibtexparser python package. "
+ "Please install it with `pip install bibtexparser`."
+ )
+
+ return values
+
+ def load_bibtex_entries(self, path: str) -> List[Dict[str, Any]]:
+ """Load bibtex entries from the bibtex file at the given path."""
+ import bibtexparser
+
+ with open(path) as file:
+ entries = bibtexparser.load(file).entries
+ return entries
+
+ def get_metadata(
+ self, entry: Mapping[str, Any], load_extra: bool = False
+ ) -> Dict[str, Any]:
+ """Get metadata for the given entry."""
+ publication = entry.get("journal") or entry.get("booktitle")
+ if "url" in entry:
+ url = entry["url"]
+ elif "doi" in entry:
+ url = f'https://doi.org/{entry["doi"]}'
+ else:
+ url = None
+ meta = {
+ "id": entry.get("ID"),
+ "published_year": entry.get("year"),
+ "title": entry.get("title"),
+ "publication": publication,
+ "authors": entry.get("author"),
+ "abstract": entry.get("abstract"),
+ "url": url,
+ }
+ if load_extra:
+ for field in OPTIONAL_FIELDS:
+ meta[field] = entry.get(field)
+ return {k: v for k, v in meta.items() if v is not None}
diff --git a/libs/community/langchain_community/utilities/bing_search.py b/libs/community/langchain_community/utilities/bing_search.py
new file mode 100644
index 00000000000..8166f1dfa0c
--- /dev/null
+++ b/libs/community/langchain_community/utilities/bing_search.py
@@ -0,0 +1,104 @@
+"""Util that calls Bing Search.
+
+In order to set this up, follow instructions at:
+https://levelup.gitconnected.com/api-tutorial-how-to-use-bing-web-search-api-in-python-4165d5592a7e
+"""
+from typing import Dict, List
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class BingSearchAPIWrapper(BaseModel):
+ """Wrapper for Bing Search API.
+
+ In order to set this up, follow instructions at:
+ https://levelup.gitconnected.com/api-tutorial-how-to-use-bing-web-search-api-in-python-4165d5592a7e
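+
+    Example (an illustrative sketch; the key is a placeholder, and both values can
+    instead come from the BING_SUBSCRIPTION_KEY and BING_SEARCH_URL env vars):
+        .. code-block:: python
+
+            from langchain_community.utilities.bing_search import BingSearchAPIWrapper
+
+            bing = BingSearchAPIWrapper(
+                bing_subscription_key="<subscription-key>",  # placeholder
+                bing_search_url="https://api.bing.microsoft.com/v7.0/search",
+                k=5,
+            )
+            bing.run("langchain")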
+ """
+
+ bing_subscription_key: str
+ bing_search_url: str
+ k: int = 10
+ search_kwargs: dict = Field(default_factory=dict)
+ """Additional keyword arguments to pass to the search request."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _bing_search_results(self, search_term: str, count: int) -> List[dict]:
+ headers = {"Ocp-Apim-Subscription-Key": self.bing_subscription_key}
+ params = {
+ "q": search_term,
+ "count": count,
+ "textDecorations": True,
+ "textFormat": "HTML",
+ **self.search_kwargs,
+ }
+ response = requests.get(
+ self.bing_search_url,
+ headers=headers,
+ params=params, # type: ignore
+ )
+ response.raise_for_status()
+ search_results = response.json()
+ return search_results["webPages"]["value"]
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ bing_subscription_key = get_from_dict_or_env(
+ values, "bing_subscription_key", "BING_SUBSCRIPTION_KEY"
+ )
+ values["bing_subscription_key"] = bing_subscription_key
+
+ bing_search_url = get_from_dict_or_env(
+ values,
+ "bing_search_url",
+ "BING_SEARCH_URL",
+ # default="https://api.bing.microsoft.com/v7.0/search",
+ )
+
+ values["bing_search_url"] = bing_search_url
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through BingSearch and parse result."""
+ snippets = []
+ results = self._bing_search_results(query, count=self.k)
+ if len(results) == 0:
+ return "No good Bing Search Result was found"
+ for result in results:
+ snippets.append(result["snippet"])
+
+ return " ".join(snippets)
+
+ def results(self, query: str, num_results: int) -> List[Dict]:
+ """Run query through BingSearch and return metadata.
+
+ Args:
+ query: The query to search for.
+ num_results: The number of results to return.
+
+ Returns:
+ A list of dictionaries with the following keys:
+ snippet - The description of the result.
+ title - The title of the result.
+ link - The link to the result.
+ """
+ metadata_results = []
+ results = self._bing_search_results(query, count=num_results)
+ if len(results) == 0:
+ return [{"Result": "No good Bing Search Result was found"}]
+ for result in results:
+ metadata_result = {
+ "snippet": result["snippet"],
+ "title": result["name"],
+ "link": result["url"],
+ }
+ metadata_results.append(metadata_result)
+
+ return metadata_results
diff --git a/libs/community/langchain_community/utilities/brave_search.py b/libs/community/langchain_community/utilities/brave_search.py
new file mode 100644
index 00000000000..8f3df0666c6
--- /dev/null
+++ b/libs/community/langchain_community/utilities/brave_search.py
@@ -0,0 +1,72 @@
+import json
+from typing import List
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+class BraveSearchWrapper(BaseModel):
+ """Wrapper around the Brave search engine."""
+
+ api_key: str
+ """The API key to use for the Brave search engine."""
+ search_kwargs: dict = Field(default_factory=dict)
+ """Additional keyword arguments to pass to the search request."""
+ base_url: str = "https://api.search.brave.com/res/v1/web/search"
+ """The base URL for the Brave search engine."""
+
+ def run(self, query: str) -> str:
+ """Query the Brave search engine and return the results as a JSON string.
+
+ Args:
+ query: The query to search for.
+
+ Returns: The results as a JSON string.
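+
+        Example (illustrative; the API key is a placeholder):
+            .. code-block:: python
+
+                brave = BraveSearchWrapper(api_key="<brave-api-key>")
+                brave.run("langchain")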
+
+ """
+ web_search_results = self._search_request(query=query)
+ final_results = [
+ {
+ "title": item.get("title"),
+ "link": item.get("url"),
+ "snippet": item.get("description"),
+ }
+ for item in web_search_results
+ ]
+ return json.dumps(final_results)
+
+ def download_documents(self, query: str) -> List[Document]:
+ """Query the Brave search engine and return the results as a list of Documents.
+
+ Args:
+ query: The query to search for.
+
+ Returns: The results as a list of Documents.
+
+ """
+ results = self._search_request(query)
+ return [
+ Document(
+ page_content=item.get("description"),
+ metadata={"title": item.get("title"), "link": item.get("url")},
+ )
+ for item in results
+ ]
+
+ def _search_request(self, query: str) -> List[dict]:
+ headers = {
+ "X-Subscription-Token": self.api_key,
+ "Accept": "application/json",
+ }
+ req = requests.PreparedRequest()
+ params = {**self.search_kwargs, **{"q": query}}
+ req.prepare_url(self.base_url, params)
+ if req.url is None:
+ raise ValueError("prepared url is None, this should not happen")
+
+ response = requests.get(req.url, headers=headers)
+ if not response.ok:
+ raise Exception(f"HTTP error {response.status_code}")
+
+ return response.json().get("web", {}).get("results", [])
diff --git a/libs/community/langchain_community/utilities/clickup.py b/libs/community/langchain_community/utilities/clickup.py
new file mode 100644
index 00000000000..ed81e7fc72b
--- /dev/null
+++ b/libs/community/langchain_community/utilities/clickup.py
@@ -0,0 +1,622 @@
+"""Util that calls clickup."""
+import json
+import warnings
+from dataclasses import asdict, dataclass, fields
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+DEFAULT_URL = "https://api.clickup.com/api/v2"
+
+
+@dataclass
+class Component:
+ """Base class for all components."""
+
+ @classmethod
+ def from_data(cls, data: Dict[str, Any]) -> "Component":
+ raise NotImplementedError()
+
+
+@dataclass
+class Task(Component):
+ """Class for a task."""
+
+ id: int
+ name: str
+ text_content: str
+ description: str
+ status: str
+ creator_id: int
+ creator_username: str
+ creator_email: str
+ assignees: List[Dict[str, Any]]
+ watchers: List[Dict[str, Any]]
+ priority: Optional[str]
+ due_date: Optional[str]
+ start_date: Optional[str]
+ points: int
+ team_id: int
+ project_id: int
+
+ @classmethod
+ def from_data(cls, data: Dict[str, Any]) -> "Task":
+ priority = None if data["priority"] is None else data["priority"]["priority"]
+ return cls(
+ id=data["id"],
+ name=data["name"],
+ text_content=data["text_content"],
+ description=data["description"],
+ status=data["status"]["status"],
+ creator_id=data["creator"]["id"],
+ creator_username=data["creator"]["username"],
+ creator_email=data["creator"]["email"],
+ assignees=data["assignees"],
+ watchers=data["watchers"],
+ priority=priority,
+ due_date=data["due_date"],
+ start_date=data["start_date"],
+ points=data["points"],
+ team_id=data["team_id"],
+ project_id=data["project"]["id"],
+ )
+
+
+@dataclass
+class CUList(Component):
+ """Component class for a list."""
+
+ folder_id: float
+ name: str
+ content: Optional[str] = None
+ due_date: Optional[int] = None
+ due_date_time: Optional[bool] = None
+ priority: Optional[int] = None
+ assignee: Optional[int] = None
+ status: Optional[str] = None
+
+ @classmethod
+ def from_data(cls, data: dict) -> "CUList":
+ return cls(
+ folder_id=data["folder_id"],
+ name=data["name"],
+ content=data.get("content"),
+ due_date=data.get("due_date"),
+ due_date_time=data.get("due_date_time"),
+ priority=data.get("priority"),
+ assignee=data.get("assignee"),
+ status=data.get("status"),
+ )
+
+
+@dataclass
+class Member(Component):
+ """Component class for a member."""
+
+ id: int
+ username: str
+ email: str
+ initials: str
+
+ @classmethod
+ def from_data(cls, data: Dict) -> "Member":
+ return cls(
+ id=data["user"]["id"],
+ username=data["user"]["username"],
+ email=data["user"]["email"],
+ initials=data["user"]["initials"],
+ )
+
+
+@dataclass
+class Team(Component):
+ """Component class for a team."""
+
+ id: int
+ name: str
+ members: List[Member]
+
+ @classmethod
+ def from_data(cls, data: Dict) -> "Team":
+ members = [Member.from_data(member_data) for member_data in data["members"]]
+ return cls(id=data["id"], name=data["name"], members=members)
+
+
+@dataclass
+class Space(Component):
+ """Component class for a space."""
+
+ id: int
+ name: str
+ private: bool
+ enabled_features: Dict[str, Any]
+
+ @classmethod
+ def from_data(cls, data: Dict[str, Any]) -> "Space":
+ space_data = data["spaces"][0]
+ enabled_features = {
+ feature: value
+ for feature, value in space_data["features"].items()
+ if value["enabled"]
+ }
+ return cls(
+ id=space_data["id"],
+ name=space_data["name"],
+ private=space_data["private"],
+ enabled_features=enabled_features,
+ )
+
+
+def parse_dict_through_component(
+ data: dict, component: Type[Component], fault_tolerant: bool = False
+) -> Dict:
+ """Parse a dictionary by creating
+ a component and then turning it back into a dictionary.
+
+    This helps with two things:
+ 1. Extract and format data from a dictionary according to schema
+ 2. Provide a central place to do this in a fault-tolerant way
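+
+    Example (illustrative; assumes ``data`` is a raw ClickUp task payload):
+        .. code-block:: python
+
+            task_dict = parse_dict_through_component(data, Task, fault_tolerant=True)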
+
+ """
+ try:
+ return asdict(component.from_data(data))
+ except Exception as e:
+ if fault_tolerant:
+ warning_str = f"""Error encountered while trying to parse
+{str(data)}: {str(e)}\n Falling back to returning input data."""
+ warnings.warn(warning_str)
+ return data
+ else:
+ raise e
+
+
+def extract_dict_elements_from_component_fields(
+ data: dict, component: Type[Component]
+) -> dict:
+ """Extract elements from a dictionary.
+
+ Args:
+ data: The dictionary to extract elements from.
+ component: The component to extract elements from.
+
+ Returns:
+ A dictionary containing the elements from the input dictionary that are also
+ in the component.
+ """
+ output = {}
+ for attribute in fields(component):
+ if attribute.name in data:
+ output[attribute.name] = data[attribute.name]
+ return output
+
+
+def load_query(
+ query: str, fault_tolerant: bool = False
+) -> Tuple[Optional[Dict], Optional[str]]:
+ """Attempts to parse a JSON string and return the parsed object.
+
+ If parsing fails, returns an error message.
+
+ :param query: The JSON string to parse.
+ :return: A tuple containing the parsed object or None and an error message or None.
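+
+    Example (illustrative):
+        .. code-block:: python
+
+            load_query('{"task_id": "abc"}', fault_tolerant=True)
+            # -> ({"task_id": "abc"}, None)
+            load_query("not json", fault_tolerant=True)
+            # -> (None, "Input must be a valid JSON. ...")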
+ """
+ try:
+ return json.loads(query), None
+ except json.JSONDecodeError as e:
+ if fault_tolerant:
+ return (
+ None,
+ f"""Input must be a valid JSON. Got the following error: {str(e)}.
+"Please reformat and try again.""",
+ )
+ else:
+ raise e
+
+
+def fetch_first_id(data: dict, key: str) -> Optional[int]:
+ """Fetch the first id from a dictionary."""
+ if key in data and len(data[key]) > 0:
+ if len(data[key]) > 1:
+ warnings.warn(f"Found multiple {key}: {data[key]}. Defaulting to first.")
+ return data[key][0]["id"]
+ return None
+
+
+def fetch_data(url: str, access_token: str, query: Optional[dict] = None) -> dict:
+ """Fetch data from a URL."""
+ headers = {"Authorization": access_token}
+ response = requests.get(url, headers=headers, params=query)
+ response.raise_for_status()
+ return response.json()
+
+
+def fetch_team_id(access_token: str) -> Optional[int]:
+ """Fetch the team id."""
+ url = f"{DEFAULT_URL}/team"
+ data = fetch_data(url, access_token)
+ return fetch_first_id(data, "teams")
+
+
+def fetch_space_id(team_id: int, access_token: str) -> Optional[int]:
+ """Fetch the space id."""
+ url = f"{DEFAULT_URL}/team/{team_id}/space"
+ data = fetch_data(url, access_token, query={"archived": "false"})
+ return fetch_first_id(data, "spaces")
+
+
+def fetch_folder_id(space_id: int, access_token: str) -> Optional[int]:
+ """Fetch the folder id."""
+ url = f"{DEFAULT_URL}/space/{space_id}/folder"
+ data = fetch_data(url, access_token, query={"archived": "false"})
+ return fetch_first_id(data, "folders")
+
+
+def fetch_list_id(space_id: int, folder_id: int, access_token: str) -> Optional[int]:
+ """Fetch the list id."""
+ if folder_id:
+ url = f"{DEFAULT_URL}/folder/{folder_id}/list"
+ else:
+ url = f"{DEFAULT_URL}/space/{space_id}/list"
+
+ data = fetch_data(url, access_token, query={"archived": "false"})
+
+    # The structure of the response differs based on whether the list is folderless
+ if folder_id and "id" in data:
+ return data["id"]
+ else:
+ return fetch_first_id(data, "lists")
+
+
+class ClickupAPIWrapper(BaseModel):
+ """Wrapper for Clickup API."""
+
+ access_token: Optional[str] = None
+ team_id: Optional[str] = None
+ space_id: Optional[str] = None
+ folder_id: Optional[str] = None
+ list_id: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @classmethod
+ def get_access_code_url(
+ cls, oauth_client_id: str, redirect_uri: str = "https://google.com"
+ ) -> str:
+ """Get the URL to get an access code."""
+ url = f"https://app.clickup.com/api?client_id={oauth_client_id}"
+ return f"{url}&redirect_uri={redirect_uri}"
+
+ @classmethod
+ def get_access_token(
+ cls, oauth_client_id: str, oauth_client_secret: str, code: str
+ ) -> Optional[str]:
+ """Get the access token."""
+ url = f"{DEFAULT_URL}/oauth/token"
+
+ params = {
+ "client_id": oauth_client_id,
+ "client_secret": oauth_client_secret,
+ "code": code,
+ }
+
+ response = requests.post(url, params=params)
+ data = response.json()
+
+ if "access_token" not in data:
+ print(f"Error: {data}")
+ if "ECODE" in data and data["ECODE"] == "OAUTH_014":
+ url = ClickupAPIWrapper.get_access_code_url(oauth_client_id)
+ print(
+ "You already used this code once. Generate a new one.",
+ f"Our best guess for the url to get a new code is:\n{url}",
+ )
+ return None
+
+ return data["access_token"]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["access_token"] = get_from_dict_or_env(
+ values, "access_token", "CLICKUP_ACCESS_TOKEN"
+ )
+ values["team_id"] = fetch_team_id(values["access_token"])
+ values["space_id"] = fetch_space_id(values["team_id"], values["access_token"])
+ values["folder_id"] = fetch_folder_id(
+ values["space_id"], values["access_token"]
+ )
+ values["list_id"] = fetch_list_id(
+ values["space_id"], values["folder_id"], values["access_token"]
+ )
+
+ return values
+
+ def attempt_parse_teams(self, input_dict: dict) -> Dict[str, List[dict]]:
+ """Parse appropriate content from the list of teams."""
+ parsed_teams: Dict[str, List[dict]] = {"teams": []}
+ for team in input_dict["teams"]:
+ try:
+ team = parse_dict_through_component(team, Team, fault_tolerant=False)
+ parsed_teams["teams"].append(team)
+ except Exception as e:
+ warnings.warn(f"Error parsing a team {e}")
+
+ return parsed_teams
+
+ def get_headers(
+ self,
+ ) -> Mapping[str, Union[str, bytes]]:
+ """Get the headers for the request."""
+ if not isinstance(self.access_token, str):
+ raise TypeError(f"Access Token: {self.access_token}, must be str.")
+
+ headers = {
+ "Authorization": str(self.access_token),
+ "Content-Type": "application/json",
+ }
+ return headers
+
+ def get_default_params(self) -> Dict:
+ return {"archived": "false"}
+
+ def get_authorized_teams(self) -> Dict[Any, Any]:
+ """Get all teams for the user."""
+ url = f"{DEFAULT_URL}/team"
+
+ response = requests.get(url, headers=self.get_headers())
+
+ data = response.json()
+ parsed_teams = self.attempt_parse_teams(data)
+
+ return parsed_teams
+
+ def get_folders(self) -> Dict:
+ """
+ Get all the folders for the team.
+ """
+ url = f"{DEFAULT_URL}/team/" + str(self.team_id) + "/space"
+ params = self.get_default_params()
+ response = requests.get(url, headers=self.get_headers(), params=params)
+ return {"response": response}
+
+ def get_task(self, query: str, fault_tolerant: bool = True) -> Dict:
+ """
+ Retrieve a specific task.
+ """
+
+ params, error = load_query(query, fault_tolerant=True)
+ if params is None:
+ return {"Error": error}
+
+ url = f"{DEFAULT_URL}/task/{params['task_id']}"
+ params = {
+ "custom_task_ids": "true",
+ "team_id": self.team_id,
+ "include_subtasks": "true",
+ }
+ response = requests.get(url, headers=self.get_headers(), params=params)
+ data = response.json()
+ parsed_task = parse_dict_through_component(
+ data, Task, fault_tolerant=fault_tolerant
+ )
+
+ return parsed_task
+
+ def get_lists(self) -> Dict:
+ """
+ Get all available lists.
+ """
+
+ url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
+ params = self.get_default_params()
+ response = requests.get(url, headers=self.get_headers(), params=params)
+ return {"response": response}
+
+ def query_tasks(self, query: str) -> Dict:
+ """
+ Query tasks that match certain fields
+ """
+ params, error = load_query(query, fault_tolerant=True)
+ if params is None:
+ return {"Error": error}
+
+ url = f"{DEFAULT_URL}/list/{params['list_id']}/task"
+
+ params = self.get_default_params()
+ response = requests.get(url, headers=self.get_headers(), params=params)
+
+ return {"response": response}
+
+ def get_spaces(self) -> Dict:
+ """
+ Get all spaces for the team.
+ """
+ url = f"{DEFAULT_URL}/team/{self.team_id}/space"
+ response = requests.get(
+ url, headers=self.get_headers(), params=self.get_default_params()
+ )
+ data = response.json()
+ parsed_spaces = parse_dict_through_component(data, Space, fault_tolerant=True)
+ return parsed_spaces
+
+ def get_task_attribute(self, query: str) -> Dict:
+ """
+        Retrieve an attribute of a specified task.
+ """
+
+ task = self.get_task(query, fault_tolerant=True)
+ params, error = load_query(query, fault_tolerant=True)
+ if not isinstance(params, dict):
+ return {"Error": error}
+
+ if params["attribute_name"] not in task:
+ return {
+ "Error": f"""attribute_name = {params['attribute_name']} was not
+found in task keys {task.keys()}. Please call again with one of the key names."""
+ }
+
+ return {params["attribute_name"]: task[params["attribute_name"]]}
+
+ def update_task(self, query: str) -> Dict:
+ """
+ Update an attribute of a specified task.
+ """
+ query_dict, error = load_query(query, fault_tolerant=True)
+ if query_dict is None:
+ return {"Error": error}
+
+ url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
+ params = {
+ "custom_task_ids": "true",
+ "team_id": self.team_id,
+ "include_subtasks": "true",
+ }
+ headers = self.get_headers()
+ payload = {query_dict["attribute_name"]: query_dict["value"]}
+
+ response = requests.put(url, headers=headers, params=params, json=payload)
+
+ return {"response": response}
+
+ def update_task_assignees(self, query: str) -> Dict:
+ """
+ Add or remove assignees of a specified task.
+ """
+ query_dict, error = load_query(query, fault_tolerant=True)
+ if query_dict is None:
+ return {"Error": error}
+
+ for user in query_dict["users"]:
+ if not isinstance(user, int):
+ return {
+ "Error": f"""All users must be integers, not strings!
+"Got user {user} if type {type(user)}"""
+ }
+
+ url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
+
+ headers = self.get_headers()
+
+ if query_dict["operation"] == "add":
+ assigne_payload = {"add": query_dict["users"], "rem": []}
+ elif query_dict["operation"] == "rem":
+ assigne_payload = {"add": [], "rem": query_dict["users"]}
+ else:
+ raise ValueError(
+ f"Invalid operation ({query_dict['operation']}). ",
+ "Valid options ['add', 'rem'].",
+ )
+
+ params = {
+ "custom_task_ids": "true",
+ "team_id": self.team_id,
+ "include_subtasks": "true",
+ }
+
+ payload = {"assignees": assigne_payload}
+ response = requests.put(url, headers=headers, params=params, json=payload)
+ return {"response": response}
+
+ def create_task(self, query: str) -> Dict:
+ """
+ Creates a new task.
+ """
+ query_dict, error = load_query(query, fault_tolerant=True)
+ if query_dict is None:
+ return {"Error": error}
+
+ list_id = self.list_id
+ url = f"{DEFAULT_URL}/list/{list_id}/task"
+ params = {"custom_task_ids": "true", "team_id": self.team_id}
+
+ payload = extract_dict_elements_from_component_fields(query_dict, Task)
+ headers = self.get_headers()
+
+ response = requests.post(url, json=payload, headers=headers, params=params)
+ data: Dict = response.json()
+ return parse_dict_through_component(data, Task, fault_tolerant=True)
+
+ def create_list(self, query: str) -> Dict:
+ """
+ Creates a new list.
+ """
+ query_dict, error = load_query(query, fault_tolerant=True)
+ if query_dict is None:
+ return {"Error": error}
+
+ # Default to using folder as location if it exists.
+ # If not, fall back to using the space.
+ location = self.folder_id if self.folder_id else self.space_id
+ url = f"{DEFAULT_URL}/folder/{location}/list"
+
+ payload = extract_dict_elements_from_component_fields(query_dict, Task)
+ headers = self.get_headers()
+
+ response = requests.post(url, json=payload, headers=headers)
+ data = response.json()
+ parsed_list = parse_dict_through_component(data, CUList, fault_tolerant=True)
+ # set list id to new list
+ if "id" in parsed_list:
+ self.list_id = parsed_list["id"]
+ return parsed_list
+
+ def create_folder(self, query: str) -> Dict:
+ """
+ Creates a new folder.
+ """
+
+ query_dict, error = load_query(query, fault_tolerant=True)
+ if query_dict is None:
+ return {"Error": error}
+
+ space_id = self.space_id
+ url = f"{DEFAULT_URL}/space/{space_id}/folder"
+ payload = {
+ "name": query_dict["name"],
+ }
+
+ headers = self.get_headers()
+
+ response = requests.post(url, json=payload, headers=headers)
+ data = response.json()
+
+ if "id" in data:
+ self.list_id = data["id"]
+ return data
+
+ def run(self, mode: str, query: str) -> str:
+ """Run the API."""
+ if mode == "get_task":
+ output = self.get_task(query)
+ elif mode == "get_task_attribute":
+ output = self.get_task_attribute(query)
+ elif mode == "get_teams":
+ output = self.get_authorized_teams()
+ elif mode == "create_task":
+ output = self.create_task(query)
+ elif mode == "create_list":
+ output = self.create_list(query)
+ elif mode == "create_folder":
+ output = self.create_folder(query)
+ elif mode == "get_lists":
+ output = self.get_lists()
+ elif mode == "get_folders":
+ output = self.get_folders()
+ elif mode == "get_spaces":
+ output = self.get_spaces()
+ elif mode == "update_task":
+ output = self.update_task(query)
+ elif mode == "update_task_assignees":
+ output = self.update_task_assignees(query)
+ else:
+ output = {"ModeError": f"Got unexpected mode {mode}."}
+
+ try:
+ return json.dumps(output)
+ except Exception:
+ return str(output)
diff --git a/libs/community/langchain_community/utilities/dalle_image_generator.py b/libs/community/langchain_community/utilities/dalle_image_generator.py
new file mode 100644
index 00000000000..6dddb710be8
--- /dev/null
+++ b/libs/community/langchain_community/utilities/dalle_image_generator.py
@@ -0,0 +1,164 @@
+"""Utility that calls OpenAI's Dall-E Image Generator."""
+import logging
+import os
+from typing import Any, Dict, Mapping, Optional, Tuple, Union
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import (
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+)
+
+from langchain_community.utils.openai import is_openai_v1
+
+logger = logging.getLogger(__name__)
+
+
+class DallEAPIWrapper(BaseModel):
+ """Wrapper for OpenAI's DALL-E Image Generator.
+
+ https://platform.openai.com/docs/guides/images/generations?context=node
+
+ Usage instructions:
+
+ 1. `pip install openai`
+ 2. save your OPENAI_API_KEY in an environment variable
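+
+    Example (an illustrative sketch; assumes ``OPENAI_API_KEY`` is set in the
+    environment):
+        .. code-block:: python
+
+            dalle = DallEAPIWrapper(model="dall-e-3", n=1, size="1024x1024")
+            image_url = dalle.run("a painting of a corgi wearing sunglasses")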
+ """
+
+ client: Any #: :meta private:
+ async_client: Any = Field(default=None, exclude=True) #: :meta private:
+ model_name: str = Field(default="dall-e-2", alias="model")
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+ """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+ openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+ """Base URL path for API requests, leave blank if not using a proxy or service
+ emulator."""
+ openai_organization: Optional[str] = Field(default=None, alias="organization")
+ """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+ # to support explicit proxy for OpenAI
+ openai_proxy: Optional[str] = None
+ request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
+ default=None, alias="timeout"
+ )
+ n: int = 1
+ """Number of images to generate"""
+ size: str = "1024x1024"
+ """Size of image to generate"""
+ separator: str = "\n"
+ """Separator to use when multiple URLs are returned."""
+ quality: Optional[str] = "standard"
+ """Quality of the image that will be generated"""
+ max_retries: int = 2
+ """Maximum number of retries to make when generating."""
+ default_headers: Union[Mapping[str, str], None] = None
+ default_query: Union[Mapping[str, object], None] = None
+ # Configure a custom httpx client. See the
+ # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
+ http_client: Union[Any, None] = None
+ """Optional httpx.Client."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["openai_api_key"] = get_from_dict_or_env(
+ values, "openai_api_key", "OPENAI_API_KEY"
+ )
+ # Check OPENAI_ORGANIZATION for backwards compatibility.
+ values["openai_organization"] = (
+ values["openai_organization"]
+ or os.getenv("OPENAI_ORG_ID")
+ or os.getenv("OPENAI_ORGANIZATION")
+ or None
+ )
+ values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+ "OPENAI_API_BASE"
+ )
+ values["openai_proxy"] = get_from_dict_or_env(
+ values,
+ "openai_proxy",
+ "OPENAI_PROXY",
+ default="",
+ )
+
+ try:
+ import openai
+
+ except ImportError:
+ raise ImportError(
+ "Could not import openai python package. "
+ "Please install it with `pip install openai`."
+ )
+
+ if is_openai_v1():
+ client_params = {
+ "api_key": values["openai_api_key"],
+ "organization": values["openai_organization"],
+ "base_url": values["openai_api_base"],
+ "timeout": values["request_timeout"],
+ "max_retries": values["max_retries"],
+ "default_headers": values["default_headers"],
+ "default_query": values["default_query"],
+ "http_client": values["http_client"],
+ }
+
+ if not values.get("client"):
+ values["client"] = openai.OpenAI(**client_params).images
+ if not values.get("async_client"):
+ values["async_client"] = openai.AsyncOpenAI(**client_params).images
+ elif not values.get("client"):
+ values["client"] = openai.Image
+ else:
+ pass
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through OpenAI and parse result."""
+
+ if is_openai_v1():
+ response = self.client.generate(
+ prompt=query,
+ n=self.n,
+ size=self.size,
+ model=self.model_name,
+ quality=self.quality,
+ )
+ image_urls = self.separator.join([item.url for item in response.data])
+ else:
+ response = self.client.create(
+ prompt=query, n=self.n, size=self.size, model=self.model_name
+ )
+ image_urls = self.separator.join([item["url"] for item in response["data"]])
+
+ return image_urls if image_urls else "No image was generated"
diff --git a/libs/community/langchain_community/utilities/dataforseo_api_search.py b/libs/community/langchain_community/utilities/dataforseo_api_search.py
new file mode 100644
index 00000000000..16c90fc4ff6
--- /dev/null
+++ b/libs/community/langchain_community/utilities/dataforseo_api_search.py
@@ -0,0 +1,195 @@
+import base64
+from typing import Dict, Optional
+from urllib.parse import quote
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class DataForSeoAPIWrapper(BaseModel):
+ """Wrapper around the DataForSeo API."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ default_params: dict = Field(
+ default={
+ "location_name": "United States",
+ "language_code": "en",
+ "depth": 10,
+ "se_name": "google",
+ "se_type": "organic",
+ }
+ )
+ """Default parameters to use for the DataForSEO SERP API."""
+ params: dict = Field(default={})
+ """Additional parameters to pass to the DataForSEO SERP API."""
+ api_login: Optional[str] = None
+ """The API login to use for the DataForSEO SERP API."""
+ api_password: Optional[str] = None
+ """The API password to use for the DataForSEO SERP API."""
+ json_result_types: Optional[list] = None
+ """The JSON result types."""
+ json_result_fields: Optional[list] = None
+ """The JSON result fields."""
+ top_count: Optional[int] = None
+ """The number of top results to return."""
+ aiosession: Optional[aiohttp.ClientSession] = None
+ """The aiohttp session to use for the DataForSEO SERP API."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that login and password exists in environment."""
+ login = get_from_dict_or_env(values, "api_login", "DATAFORSEO_LOGIN")
+ password = get_from_dict_or_env(values, "api_password", "DATAFORSEO_PASSWORD")
+ values["api_login"] = login
+ values["api_password"] = password
+ return values
+
+ async def arun(self, url: str) -> str:
+ """Run request to DataForSEO SERP API and parse result async."""
+ return self._process_response(await self._aresponse_json(url))
+
+ def run(self, url: str) -> str:
+ """Run request to DataForSEO SERP API and parse result async."""
+ return self._process_response(self._response_json(url))
+
+ def results(self, url: str) -> list:
+ res = self._response_json(url)
+ return self._filter_results(res)
+
+ async def aresults(self, url: str) -> list:
+ res = await self._aresponse_json(url)
+ return self._filter_results(res)
+
+ def _prepare_request(self, keyword: str) -> dict:
+ """Prepare the request details for the DataForSEO SERP API."""
+ if self.api_login is None or self.api_password is None:
+ raise ValueError("api_login or api_password is not provided")
+ cred = base64.b64encode(
+ f"{self.api_login}:{self.api_password}".encode("utf-8")
+ ).decode("utf-8")
+ headers = {"Authorization": f"Basic {cred}", "Content-Type": "application/json"}
+ obj = {"keyword": quote(keyword)}
+ obj = {**obj, **self.default_params, **self.params}
+ data = [obj]
+ _url = (
+ f"https://api.dataforseo.com/v3/serp/{obj['se_name']}"
+ f"/{obj['se_type']}/live/advanced"
+ )
+ return {
+ "url": _url,
+ "headers": headers,
+ "data": data,
+ }
+
+ def _check_response(self, response: dict) -> dict:
+ """Check the response from the DataForSEO SERP API for errors."""
+ if response.get("status_code") != 20000:
+ raise ValueError(
+ f"Got error from DataForSEO SERP API: {response.get('status_message')}"
+ )
+ return response
+
+ def _response_json(self, url: str) -> dict:
+ """Use requests to run request to DataForSEO SERP API and return results."""
+ request_details = self._prepare_request(url)
+ response = requests.post(
+ request_details["url"],
+ headers=request_details["headers"],
+ json=request_details["data"],
+ )
+ response.raise_for_status()
+ return self._check_response(response.json())
+
+ async def _aresponse_json(self, url: str) -> dict:
+ """Use aiohttp to request DataForSEO SERP API and return results async."""
+ request_details = self._prepare_request(url)
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ request_details["url"],
+ headers=request_details["headers"],
+ json=request_details["data"],
+ ) as response:
+ res = await response.json()
+ else:
+ async with self.aiosession.post(
+ request_details["url"],
+ headers=request_details["headers"],
+ json=request_details["data"],
+ ) as response:
+ res = await response.json()
+ return self._check_response(res)
+
+ def _filter_results(self, res: dict) -> list:
+ output = []
+ types = self.json_result_types if self.json_result_types is not None else []
+ for task in res.get("tasks", []):
+ for result in task.get("result", []):
+ for item in result.get("items", []):
+ if len(types) == 0 or item.get("type", "") in types:
+ self._cleanup_unnecessary_items(item)
+ if len(item) != 0:
+ output.append(item)
+ if self.top_count is not None and len(output) >= self.top_count:
+ break
+ return output
+
+ def _cleanup_unnecessary_items(self, d: dict) -> dict:
+ fields = self.json_result_fields if self.json_result_fields is not None else []
+ if len(fields) > 0:
+ for k, v in list(d.items()):
+ if isinstance(v, dict):
+ self._cleanup_unnecessary_items(v)
+ if len(v) == 0:
+ del d[k]
+ elif k not in fields:
+ del d[k]
+
+ if "xpath" in d:
+ del d["xpath"]
+ if "position" in d:
+ del d["position"]
+ if "rectangle" in d:
+ del d["rectangle"]
+ for k, v in list(d.items()):
+ if isinstance(v, dict):
+ self._cleanup_unnecessary_items(v)
+ return d
+
+ def _process_response(self, res: dict) -> str:
+ """Process response from DataForSEO SERP API."""
+ toret = "No good search result found"
+ for task in res.get("tasks", []):
+ for result in task.get("result", []):
+ item_types = result.get("item_types")
+ items = result.get("items", [])
+ if "answer_box" in item_types:
+ toret = next(
+ item for item in items if item.get("type") == "answer_box"
+ ).get("text")
+ elif "knowledge_graph" in item_types:
+ toret = next(
+ item for item in items if item.get("type") == "knowledge_graph"
+ ).get("description")
+ elif "featured_snippet" in item_types:
+ toret = next(
+ item for item in items if item.get("type") == "featured_snippet"
+ ).get("description")
+ elif "shopping" in item_types:
+ toret = next(
+ item for item in items if item.get("type") == "shopping"
+ ).get("price")
+ elif "organic" in item_types:
+ toret = next(
+ item for item in items if item.get("type") == "organic"
+ ).get("description")
+ if toret:
+ break
+ return toret
diff --git a/libs/community/langchain_community/utilities/duckduckgo_search.py b/libs/community/langchain_community/utilities/duckduckgo_search.py
new file mode 100644
index 00000000000..d258726896a
--- /dev/null
+++ b/libs/community/langchain_community/utilities/duckduckgo_search.py
@@ -0,0 +1,130 @@
+"""Util that calls DuckDuckGo Search.
+
+No setup required. Free.
+https://pypi.org/project/duckduckgo-search/
+"""
+from typing import Dict, List, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class DuckDuckGoSearchAPIWrapper(BaseModel):
+ """Wrapper for DuckDuckGo Search API.
+
+ Free and does not require any setup.
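+
+    Example (illustrative):
+        .. code-block:: python
+
+            from langchain_community.utilities.duckduckgo_search import (
+                DuckDuckGoSearchAPIWrapper,
+            )
+
+            search = DuckDuckGoSearchAPIWrapper(max_results=3, source="news")
+            search.run("Python 3.12 release")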
+ """
+
+ region: Optional[str] = "wt-wt"
+ safesearch: str = "moderate"
+ time: Optional[str] = "y"
+ max_results: int = 5
+ backend: str = "api" # which backend to use in DDGS.text() (api, html, lite)
+ source: str = "text" # which function to use in DDGS (DDGS.text() or DDGS.news())
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+ try:
+ from duckduckgo_search import DDGS # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import duckduckgo-search python package. "
+ "Please install it with `pip install -U duckduckgo-search`."
+ )
+ return values
+
+ def _ddgs_text(
+ self, query: str, max_results: Optional[int] = None
+ ) -> List[Dict[str, str]]:
+ """Run query through DuckDuckGo text search and return results."""
+ from duckduckgo_search import DDGS
+
+ with DDGS() as ddgs:
+ ddgs_gen = ddgs.text(
+ query,
+ region=self.region,
+ safesearch=self.safesearch,
+ timelimit=self.time,
+ max_results=max_results or self.max_results,
+ backend=self.backend,
+ )
+ if ddgs_gen:
+ return [r for r in ddgs_gen]
+ return []
+
+ def _ddgs_news(
+ self, query: str, max_results: Optional[int] = None
+ ) -> List[Dict[str, str]]:
+ """Run query through DuckDuckGo news search and return results."""
+ from duckduckgo_search import DDGS
+
+ with DDGS() as ddgs:
+ ddgs_gen = ddgs.news(
+ query,
+ region=self.region,
+ safesearch=self.safesearch,
+ timelimit=self.time,
+ max_results=max_results or self.max_results,
+ )
+ if ddgs_gen:
+ return [r for r in ddgs_gen]
+ return []
+
+ def run(self, query: str) -> str:
+ """Run query through DuckDuckGo and return concatenated results."""
+ if self.source == "text":
+ results = self._ddgs_text(query)
+ elif self.source == "news":
+ results = self._ddgs_news(query)
+ else:
+ results = []
+
+ if not results:
+ return "No good DuckDuckGo Search Result was found"
+ return " ".join(r["body"] for r in results)
+
+ def results(
+ self, query: str, max_results: int, source: Optional[str] = None
+ ) -> List[Dict[str, str]]:
+ """Run query through DuckDuckGo and return metadata.
+
+ Args:
+ query: The query to search for.
+ max_results: The number of results to return.
+ source: The source to look from.
+
+ Returns:
+ A list of dictionaries with the following keys:
+ snippet - The description of the result.
+ title - The title of the result.
+ link - The link to the result.
+ """
+ source = source or self.source
+ if source == "text":
+ results = [
+ {"snippet": r["body"], "title": r["title"], "link": r["href"]}
+ for r in self._ddgs_text(query, max_results=max_results)
+ ]
+ elif source == "news":
+ results = [
+ {
+ "snippet": r["body"],
+ "title": r["title"],
+ "link": r["url"],
+ "date": r["date"],
+ "source": r["source"],
+ }
+ for r in self._ddgs_news(query, max_results=max_results)
+ ]
+ else:
+ results = []
+
+ if results is None:
+ results = [{"Result": "No good DuckDuckGo Search Result was found"}]
+
+ return results
diff --git a/libs/community/langchain_community/utilities/github.py b/libs/community/langchain_community/utilities/github.py
new file mode 100644
index 00000000000..d6e1a0833e7
--- /dev/null
+++ b/libs/community/langchain_community/utilities/github.py
@@ -0,0 +1,839 @@
+"""Util that calls GitHub."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+ from github.Issue import Issue
+ from github.PullRequest import PullRequest
+
+
+def _import_tiktoken() -> Any:
+ """Import tiktoken."""
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "tiktoken is not installed. "
+ "Please install it with `pip install tiktoken`"
+ )
+ return tiktoken
+
+
+class GitHubAPIWrapper(BaseModel):
+ """Wrapper for GitHub API."""
+
+ github: Any #: :meta private:
+ github_repo_instance: Any #: :meta private:
+ github_repository: Optional[str] = None
+ github_app_id: Optional[str] = None
+ github_app_private_key: Optional[str] = None
+ active_branch: Optional[str] = None
+ github_base_branch: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ github_repository = get_from_dict_or_env(
+ values, "github_repository", "GITHUB_REPOSITORY"
+ )
+
+ github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID")
+
+ github_app_private_key = get_from_dict_or_env(
+ values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY"
+ )
+
+ try:
+ from github import Auth, GithubIntegration
+
+ except ImportError:
+ raise ImportError(
+ "PyGithub is not installed. "
+ "Please install it with `pip install PyGithub`"
+ )
+
+ try:
+ # interpret the key as a file path
+ # fallback to interpreting as the key itself
+ with open(github_app_private_key, "r") as f:
+ private_key = f.read()
+ except Exception:
+ private_key = github_app_private_key
+
+ auth = Auth.AppAuth(
+ github_app_id,
+ private_key,
+ )
+ gi = GithubIntegration(auth=auth)
+ installation = gi.get_installations()[0]
+
+ # create a GitHub instance:
+ g = installation.get_github_for_installation()
+ repo = g.get_repo(github_repository)
+
+ github_base_branch = get_from_dict_or_env(
+ values,
+ "github_base_branch",
+ "GITHUB_BASE_BRANCH",
+ default=repo.default_branch,
+ )
+
+ active_branch = get_from_dict_or_env(
+ values,
+ "active_branch",
+ "ACTIVE_BRANCH",
+ default=repo.default_branch,
+ )
+
+ values["github"] = g
+ values["github_repo_instance"] = repo
+ values["github_repository"] = github_repository
+ values["github_app_id"] = github_app_id
+ values["github_app_private_key"] = github_app_private_key
+ values["active_branch"] = active_branch
+ values["github_base_branch"] = github_base_branch
+
+ return values
+
+ def parse_issues(self, issues: List[Issue]) -> List[dict]:
+ """
+ Extracts title and number from each Issue and puts them in a dictionary
+ Parameters:
+ issues(List[Issue]): A list of Github Issue objects
+ Returns:
+            List[dict]: A list of dictionaries, each with an issue's title and number
+ """
+ parsed = []
+ for issue in issues:
+ title = issue.title
+ number = issue.number
+ opened_by = issue.user.login if issue.user else None
+ issue_dict = {"title": title, "number": number}
+ if opened_by is not None:
+ issue_dict["opened_by"] = opened_by
+ parsed.append(issue_dict)
+ return parsed
+
+ def parse_pull_requests(self, pull_requests: List[PullRequest]) -> List[dict]:
+ """
+        Extracts title and number from each PullRequest and puts them in a dictionary
+        Parameters:
+            pull_requests(List[PullRequest]): A list of Github PullRequest objects
+        Returns:
+            List[dict]: A list of dictionaries of pull request titles and numbers
+ """
+ parsed = []
+ for pr in pull_requests:
+ parsed.append(
+ {
+ "title": pr.title,
+ "number": pr.number,
+ "commits": str(pr.commits),
+ "comments": str(pr.comments),
+ }
+ )
+ return parsed
+
+ def get_issues(self) -> str:
+ """
+ Fetches all open issues from the repo excluding pull requests
+
+ Returns:
+ str: A plaintext report containing the number of issues
+ and each issue's title and number.
+ """
+ issues = self.github_repo_instance.get_issues(state="open")
+ # Filter out pull requests (part of GH issues object)
+ issues = [issue for issue in issues if not issue.pull_request]
+ if issues:
+ parsed_issues = self.parse_issues(issues)
+ parsed_issues_str = (
+ "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
+ )
+ return parsed_issues_str
+ else:
+ return "No open issues available"
+
+ def list_open_pull_requests(self) -> str:
+ """
+ Fetches all open PRs from the repo
+
+ Returns:
+ str: A plaintext report containing the number of PRs
+ and each PR's title and number.
+ """
+ # issues = self.github_repo_instance.get_issues(state="open")
+ pull_requests = self.github_repo_instance.get_pulls(state="open")
+ if pull_requests.totalCount > 0:
+ parsed_prs = self.parse_pull_requests(pull_requests)
+ parsed_prs_str = (
+ "Found " + str(len(parsed_prs)) + " pull requests:\n" + str(parsed_prs)
+ )
+ return parsed_prs_str
+ else:
+ return "No open pull requests available"
+
+ def list_files_in_main_branch(self) -> str:
+ """
+ Fetches all files in the main branch of the repo.
+
+ Returns:
+ str: A plaintext report containing the paths and names of the files.
+ """
+ files: List[str] = []
+ try:
+ contents = self.github_repo_instance.get_contents(
+ "", ref=self.github_base_branch
+ )
+ for content in contents:
+ if content.type == "dir":
+ files.extend(self.get_files_from_directory(content.path))
+ else:
+ files.append(content.path)
+
+ if files:
+ files_str = "\n".join(files)
+ return f"Found {len(files)} files in the main branch:\n{files_str}"
+ else:
+ return "No files found in the main branch"
+ except Exception as e:
+ return str(e)
+
+ def set_active_branch(self, branch_name: str) -> str:
+ """Equivalent to `git checkout branch_name` for this Agent.
+ Clones formatting from Github.
+
+ Returns an Error (as a string) if branch doesn't exist.
+ """
+ curr_branches = [
+ branch.name for branch in self.github_repo_instance.get_branches()
+ ]
+ if branch_name in curr_branches:
+ self.active_branch = branch_name
+ return f"Switched to branch `{branch_name}`"
+ else:
+ return (
+ f"Error {branch_name} does not exist,"
+ f"in repo with current branches: {str(curr_branches)}"
+ )
+
+ def list_branches_in_repo(self) -> str:
+ """
+ Fetches a list of all branches in the repository.
+
+ Returns:
+ str: A plaintext report containing the names of the branches.
+ """
+ try:
+ branches = [
+ branch.name for branch in self.github_repo_instance.get_branches()
+ ]
+ if branches:
+ branches_str = "\n".join(branches)
+ return (
+ f"Found {len(branches)} branches in the repository:"
+ f"\n{branches_str}"
+ )
+ else:
+ return "No branches found in the repository"
+ except Exception as e:
+ return str(e)
+
+ def create_branch(self, proposed_branch_name: str) -> str:
+ """
+ Create a new branch, and set it as the active bot branch.
+ Equivalent to `git switch -c proposed_branch_name`
+ If the proposed branch already exists, we append _v1 then _v2...
+ until a unique name is found.
+
+ Returns:
+ str: A plaintext success message.
+ """
+ from github import GithubException
+
+ i = 0
+ new_branch_name = proposed_branch_name
+ base_branch = self.github_repo_instance.get_branch(
+ self.github_repo_instance.default_branch
+ )
+ for i in range(1000):
+ try:
+ self.github_repo_instance.create_git_ref(
+ ref=f"refs/heads/{new_branch_name}", sha=base_branch.commit.sha
+ )
+ self.active_branch = new_branch_name
+ return (
+ f"Branch '{new_branch_name}' "
+ "created successfully, and set as current active branch."
+ )
+ except GithubException as e:
+ if e.status == 422 and "Reference already exists" in e.data["message"]:
+ i += 1
+ new_branch_name = f"{proposed_branch_name}_v{i}"
+ else:
+ # Handle any other exceptions
+ print(f"Failed to create branch. Error: {e}")
+ raise Exception(
+ "Unable to create branch name from proposed_branch_name: "
+ f"{proposed_branch_name}"
+ )
+ return (
+ "Unable to create branch. "
+ "At least 1000 branches exist with named derived from "
+ f"proposed_branch_name: `{proposed_branch_name}`"
+ )
+
+ def list_files_in_bot_branch(self) -> str:
+ """
+ Fetches all files in the active branch of the repo,
+ the branch the bot uses to make changes.
+
+ Returns:
+            str: A plaintext list containing the filepaths in the branch.
+ """
+ files: List[str] = []
+ try:
+ contents = self.github_repo_instance.get_contents(
+ "", ref=self.active_branch
+ )
+ for content in contents:
+ if content.type == "dir":
+ files.extend(self.get_files_from_directory(content.path))
+ else:
+ files.append(content.path)
+
+ if files:
+ files_str = "\n".join(files)
+ return (
+ f"Found {len(files)} files in branch `{self.active_branch}`:\n"
+ f"{files_str}"
+ )
+ else:
+ return f"No files found in branch: `{self.active_branch}`"
+ except Exception as e:
+ return f"Error: {e}"
+
+ def get_files_from_directory(self, directory_path: str) -> str:
+ """
+ Recursively fetches files from a directory in the repo.
+
+ Parameters:
+ directory_path (str): Path to the directory
+
+ Returns:
+ str: List of file paths, or an error message.
+ """
+ from github import GithubException
+
+ files: List[str] = []
+ try:
+ contents = self.github_repo_instance.get_contents(
+ directory_path, ref=self.active_branch
+ )
+ except GithubException as e:
+ return f"Error: status code {e.status}, {e.message}"
+
+ for content in contents:
+ if content.type == "dir":
+ files.extend(self.get_files_from_directory(content.path))
+ else:
+ files.append(content.path)
+ return str(files)
+
+ def get_issue(self, issue_number: int) -> Dict[str, Any]:
+ """
+ Fetches a specific issue and its first 10 comments
+ Parameters:
+ issue_number(int): The number for the github issue
+ Returns:
+ dict: A dictionary containing the issue's title,
+ body, comments as a string, and the username of the user
+ who opened the issue
+ """
+ issue = self.github_repo_instance.get_issue(number=issue_number)
+ page = 0
+ comments: List[dict] = []
+ while len(comments) <= 10:
+ comments_page = issue.get_comments().get_page(page)
+ if len(comments_page) == 0:
+ break
+ for comment in comments_page:
+ comments.append({"body": comment.body, "user": comment.user.login})
+ page += 1
+
+ opened_by = None
+ if issue.user and issue.user.login:
+ opened_by = issue.user.login
+
+ return {
+ "number": issue_number,
+ "title": issue.title,
+ "body": issue.body,
+ "comments": str(comments),
+ "opened_by": str(opened_by),
+ }
+
+ def list_pull_request_files(self, pr_number: int) -> List[Dict[str, Any]]:
+ """Fetches the full text of all files in a PR. Truncates after first 3k tokens.
+ # TODO: Enhancement to summarize files with ctags if they're getting long.
+
+ Args:
+ pr_number(int): The number of the pull request on Github
+
+ Returns:
+            List[Dict[str, Any]]: A list of dictionaries, one per file, containing
+            the filename, contents, additions, and deletions
+ """
+ tiktoken = _import_tiktoken()
+ MAX_TOKENS_FOR_FILES = 3_000
+ pr_files = []
+ pr = self.github_repo_instance.get_pull(number=int(pr_number))
+ total_tokens = 0
+ page = 0
+ while True: # or while (total_tokens + tiktoken()) < MAX_TOKENS_FOR_FILES:
+ files_page = pr.get_files().get_page(page)
+ if len(files_page) == 0:
+ break
+ for file in files_page:
+ try:
+ file_metadata_response = requests.get(file.contents_url)
+ if file_metadata_response.status_code == 200:
+ download_url = json.loads(file_metadata_response.text)[
+ "download_url"
+ ]
+ else:
+ print(f"Failed to download file: {file.contents_url}, skipping")
+ continue
+
+ file_content_response = requests.get(download_url)
+ if file_content_response.status_code == 200:
+ # Save the content as a UTF-8 string
+ file_content = file_content_response.text
+ else:
+ print(
+ "Failed downloading file content "
+ f"(Error {file_content_response.status_code}). Skipping"
+ )
+ continue
+
+ file_tokens = len(
+ tiktoken.get_encoding("cl100k_base").encode(
+ file_content + file.filename + "file_name file_contents"
+ )
+ )
+ if (total_tokens + file_tokens) < MAX_TOKENS_FOR_FILES:
+ pr_files.append(
+ {
+ "filename": file.filename,
+ "contents": file_content,
+ "additions": file.additions,
+ "deletions": file.deletions,
+ }
+ )
+ total_tokens += file_tokens
+ except Exception as e:
+ print(f"Error when reading files from a PR on github. {e}")
+ page += 1
+ return pr_files
+
+ def get_pull_request(self, pr_number: int) -> Dict[str, Any]:
+ """
+ Fetches a specific pull request and its first 10 comments,
+        limited to a maximum of 2,000 tokens.
+
+ Parameters:
+ pr_number(int): The number for the Github pull
+ Returns:
+            dict: A dictionary containing the pull request's title, number,
+            body, comments, and commits as strings
+ """
+ max_tokens = 2_000
+ pull = self.github_repo_instance.get_pull(number=pr_number)
+ total_tokens = 0
+
+ def get_tokens(text: str) -> int:
+ tiktoken = _import_tiktoken()
+ return len(tiktoken.get_encoding("cl100k_base").encode(text))
+
+ def add_to_dict(data_dict: Dict[str, Any], key: str, value: str) -> None:
+ nonlocal total_tokens # Declare total_tokens as nonlocal
+ tokens = get_tokens(value)
+ if total_tokens + tokens <= max_tokens:
+ data_dict[key] = value
+ total_tokens += tokens # Now this will modify the outer variable
+
+ response_dict: Dict[str, str] = {}
+ add_to_dict(response_dict, "title", pull.title)
+ add_to_dict(response_dict, "number", str(pr_number))
+ add_to_dict(response_dict, "body", pull.body)
+
+ comments: List[str] = []
+ page = 0
+ while len(comments) <= 10:
+ comments_page = pull.get_issue_comments().get_page(page)
+ if len(comments_page) == 0:
+ break
+ for comment in comments_page:
+ comment_str = str({"body": comment.body, "user": comment.user.login})
+ if total_tokens + get_tokens(comment_str) > max_tokens:
+ break
+ comments.append(comment_str)
+ total_tokens += get_tokens(comment_str)
+ page += 1
+ add_to_dict(response_dict, "comments", str(comments))
+
+ commits: List[str] = []
+ page = 0
+ while len(commits) <= 10:
+ commits_page = pull.get_commits().get_page(page)
+ if len(commits_page) == 0:
+ break
+ for commit in commits_page:
+ commit_str = str({"message": commit.commit.message})
+ if total_tokens + get_tokens(commit_str) > max_tokens:
+ break
+ commits.append(commit_str)
+ total_tokens += get_tokens(commit_str)
+ page += 1
+ add_to_dict(response_dict, "commits", str(commits))
+ return response_dict
+
+ def create_pull_request(self, pr_query: str) -> str:
+ """
+ Makes a pull request from the bot's branch to the base branch
+ Parameters:
+ pr_query(str): a string which contains the PR title
+ and the PR body. The title is the first line
+                in the string, and the body is the rest of the string.
+ For example, "Updated README\nmade changes to add info"
+ Returns:
+ str: A success or failure message
+ """
+ if self.github_base_branch == self.active_branch:
+ return """Cannot make a pull request because
+ commits are already in the main or master branch."""
+ else:
+ try:
+ title = pr_query.split("\n")[0]
+ body = pr_query[len(title) + 2 :]
+ pr = self.github_repo_instance.create_pull(
+ title=title,
+ body=body,
+ head=self.active_branch,
+ base=self.github_base_branch,
+ )
+ return f"Successfully created PR number {pr.number}"
+ except Exception as e:
+ return "Unable to make pull request due to error:\n" + str(e)
+
+ def comment_on_issue(self, comment_query: str) -> str:
+ """
+ Adds a comment to a github issue
+ Parameters:
+ comment_query(str): a string which contains the issue number,
+ two newlines, and the comment.
+ for example: "1\n\nWorking on it now"
+ adds the comment "working on it now" to issue 1
+ Returns:
+ str: A success or failure message
+ """
+ issue_number = int(comment_query.split("\n\n")[0])
+ comment = comment_query[len(str(issue_number)) + 2 :]
+ try:
+ issue = self.github_repo_instance.get_issue(number=issue_number)
+ issue.create_comment(comment)
+ return "Commented on issue " + str(issue_number)
+ except Exception as e:
+ return "Unable to make comment due to error:\n" + str(e)
+
+ def create_file(self, file_query: str) -> str:
+ """
+ Creates a new file on the Github repo
+ Parameters:
+ file_query(str): a string which contains the file path
+ and the file contents. The file path is the first line
+ in the string, and the contents are the rest of the string.
+ For example, "hello_world.md\n# Hello World!"
+ Returns:
+ str: A success or failure message
+ """
+ if self.active_branch == self.github_base_branch:
+ return (
+ "You're attempting to commit to the directly to the"
+ f"{self.github_base_branch} branch, which is protected. "
+ "Please create a new branch and try again."
+ )
+
+ file_path = file_query.split("\n")[0]
+ file_contents = file_query[len(file_path) + 2 :]
+
+ try:
+ try:
+ file = self.github_repo_instance.get_contents(
+ file_path, ref=self.active_branch
+ )
+ if file:
+ return (
+ f"File already exists at `{file_path}` "
+ f"on branch `{self.active_branch}`. You must use "
+ "`update_file` to modify it."
+ )
+ except Exception:
+ # expected behavior, file shouldn't exist yet
+ pass
+
+ self.github_repo_instance.create_file(
+ path=file_path,
+ message="Create " + file_path,
+ content=file_contents,
+ branch=self.active_branch,
+ )
+ return "Created file " + file_path
+ except Exception as e:
+ return "Unable to make file due to error:\n" + str(e)
+
+ def read_file(self, file_path: str) -> str:
+ """
+ Read a file from this agent's branch, defined by self.active_branch,
+ which supports PR branches.
+ Parameters:
+ file_path(str): the file path
+ Returns:
+ str: The file decoded as a string, or an error message if not found
+ """
+ try:
+ file = self.github_repo_instance.get_contents(
+ file_path, ref=self.active_branch
+ )
+ return file.decoded_content.decode("utf-8")
+ except Exception as e:
+ return (
+ f"File not found `{file_path}` on branch"
+ f"`{self.active_branch}`. Error: {str(e)}"
+ )
+
+ def update_file(self, file_query: str) -> str:
+ """
+ Updates a file with new content.
+ Parameters:
+ file_query(str): Contains the file path and the file contents.
+ The old file contents is wrapped in OLD <<<< and >>>> OLD
+ The new file contents is wrapped in NEW <<<< and >>>> NEW
+ For example:
+ /test/hello.txt
+ OLD <<<<
+ Hello Earth!
+ >>>> OLD
+ NEW <<<<
+ Hello Mars!
+ >>>> NEW
+ Returns:
+ A success or failure message
+ """
+ if self.active_branch == self.github_base_branch:
+ return (
+ "You're attempting to commit to the directly"
+ f"to the {self.github_base_branch} branch, which is protected. "
+ "Please create a new branch and try again."
+ )
+ try:
+ file_path: str = file_query.split("\n")[0]
+ old_file_contents = (
+ file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
+ )
+ new_file_contents = (
+ file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
+ )
+
+ file_content = self.read_file(file_path)
+ updated_file_content = file_content.replace(
+ old_file_contents, new_file_contents
+ )
+
+ if file_content == updated_file_content:
+ return (
+ "File content was not updated because old content was not found."
+ "It may be helpful to use the read_file action to get "
+ "the current file contents."
+ )
+
+ self.github_repo_instance.update_file(
+ path=file_path,
+ message="Update " + str(file_path),
+ content=updated_file_content,
+ branch=self.active_branch,
+ sha=self.github_repo_instance.get_contents(
+ file_path, ref=self.active_branch
+ ).sha,
+ )
+ return "Updated file " + str(file_path)
+ except Exception as e:
+ return "Unable to update file due to error:\n" + str(e)
+
+ def delete_file(self, file_path: str) -> str:
+ """
+ Deletes a file from the repo
+ Parameters:
+ file_path(str): Where the file is
+ Returns:
+ str: Success or failure message
+ """
+ if self.active_branch == self.github_base_branch:
+ return (
+ "You're attempting to commit to the directly"
+ f"to the {self.github_base_branch} branch, which is protected. "
+ "Please create a new branch and try again."
+ )
+ try:
+ self.github_repo_instance.delete_file(
+ path=file_path,
+ message="Delete " + file_path,
+ branch=self.active_branch,
+ sha=self.github_repo_instance.get_contents(
+ file_path, ref=self.active_branch
+ ).sha,
+ )
+ return "Deleted file " + file_path
+ except Exception as e:
+ return "Unable to delete file due to error:\n" + str(e)
+
+ def search_issues_and_prs(self, query: str) -> str:
+ """
+ Searches issues and pull requests in the repository.
+
+ Parameters:
+ query(str): The search query
+
+ Returns:
+ str: A string containing the first 5 issues and pull requests
+ """
+ search_result = self.github.search_issues(query, repo=self.github_repository)
+ max_items = min(5, len(search_result))
+ results = [f"Top {max_items} results:"]
+ for issue in search_result[:max_items]:
+ results.append(
+ f"Title: {issue.title}, Number: {issue.number}, State: {issue.state}"
+ )
+ return "\n".join(results)
+
+ def search_code(self, query: str) -> str:
+ """
+ Searches code in the repository.
+ # Todo: limit total tokens returned...
+
+ Parameters:
+ query(str): The search query
+
+ Returns:
+ str: A string containing, at most, the top 5 search results
+ """
+ search_result = self.github.search_code(
+ query=query, repo=self.github_repository
+ )
+ if search_result.totalCount == 0:
+ return "0 results found."
+ max_results = min(5, search_result.totalCount)
+ results = [f"Showing top {max_results} of {search_result.totalCount} results:"]
+ count = 0
+ for code in search_result:
+ if count >= max_results:
+ break
+ # Get the file content using the PyGithub get_contents method
+ file_content = self.github_repo_instance.get_contents(
+ code.path, ref=self.active_branch
+ ).decoded_content.decode()
+ results.append(
+ f"Filepath: `{code.path}`\nFile contents: "
+ f"{file_content}\n"
+ )
+ count += 1
+ return "\n".join(results)
+
+ def create_review_request(self, reviewer_username: str) -> str:
+ """
+ Creates a review request on *THE* open pull request
+ that matches the current active_branch.
+
+ Parameters:
+ reviewer_username(str): The username of the person who is being requested
+
+ Returns:
+ str: A message confirming the creation of the review request
+ """
+ pull_requests = self.github_repo_instance.get_pulls(
+ state="open", sort="created"
+ )
+ # find PR against active_branch
+ pr = next(
+ (pr for pr in pull_requests if pr.head.ref == self.active_branch), None
+ )
+ if pr is None:
+ return (
+ "No open pull request found for the "
+ f"current branch `{self.active_branch}`"
+ )
+
+ try:
+ pr.create_review_request(reviewers=[reviewer_username])
+ return (
+ f"Review request created for user {reviewer_username} "
+ f"on PR #{pr.number}"
+ )
+ except Exception as e:
+ return f"Failed to create a review request with error {e}"
+
+ def run(self, mode: str, query: str) -> str:
+ if mode == "get_issue":
+ return json.dumps(self.get_issue(int(query)))
+ elif mode == "get_pull_request":
+ return json.dumps(self.get_pull_request(int(query)))
+ elif mode == "list_pull_request_files":
+ return json.dumps(self.list_pull_request_files(int(query)))
+ elif mode == "get_issues":
+ return self.get_issues()
+ elif mode == "comment_on_issue":
+ return self.comment_on_issue(query)
+ elif mode == "create_file":
+ return self.create_file(query)
+ elif mode == "create_pull_request":
+ return self.create_pull_request(query)
+ elif mode == "read_file":
+ return self.read_file(query)
+ elif mode == "update_file":
+ return self.update_file(query)
+ elif mode == "delete_file":
+ return self.delete_file(query)
+ elif mode == "list_open_pull_requests":
+ return self.list_open_pull_requests()
+ elif mode == "list_files_in_main_branch":
+ return self.list_files_in_main_branch()
+ elif mode == "list_files_in_bot_branch":
+ return self.list_files_in_bot_branch()
+ elif mode == "list_branches_in_repo":
+ return self.list_branches_in_repo()
+ elif mode == "set_active_branch":
+ return self.set_active_branch(query)
+ elif mode == "create_branch":
+ return self.create_branch(query)
+ elif mode == "get_files_from_directory":
+ return self.get_files_from_directory(query)
+ elif mode == "search_issues_and_prs":
+ return self.search_issues_and_prs(query)
+ elif mode == "search_code":
+ return self.search_code(query)
+ elif mode == "create_review_request":
+ return self.create_review_request(query)
+ else:
+ raise ValueError("Invalid mode" + mode)
diff --git a/libs/community/langchain_community/utilities/gitlab.py b/libs/community/langchain_community/utilities/gitlab.py
new file mode 100644
index 00000000000..0eb1b40a7dd
--- /dev/null
+++ b/libs/community/langchain_community/utilities/gitlab.py
@@ -0,0 +1,327 @@
+"""Util that calls gitlab."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+ from gitlab.v4.objects import Issue
+
+
+class GitLabAPIWrapper(BaseModel):
+ """Wrapper for GitLab API."""
+
+ gitlab: Any #: :meta private:
+ gitlab_repo_instance: Any #: :meta private:
+ gitlab_repository: Optional[str] = None
+ """The name of the GitLab repository, in the form {username}/{repo-name}."""
+ gitlab_personal_access_token: Optional[str] = None
+ """Personal access token for the GitLab service, used for authentication."""
+ gitlab_branch: Optional[str] = None
+ """The specific branch in the GitLab repository where the bot will make
+ its commits. Defaults to 'main'.
+ """
+ gitlab_base_branch: Optional[str] = None
+ """The base branch in the GitLab repository, used for comparisons.
+ Usually 'main' or 'master'. Defaults to 'main'.
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+
+ gitlab_url = get_from_dict_or_env(
+ values, "gitlab_url", "GITLAB_URL", default=None
+ )
+ gitlab_repository = get_from_dict_or_env(
+ values, "gitlab_repository", "GITLAB_REPOSITORY"
+ )
+
+ gitlab_personal_access_token = get_from_dict_or_env(
+ values, "gitlab_personal_access_token", "GITLAB_PERSONAL_ACCESS_TOKEN"
+ )
+
+ gitlab_branch = get_from_dict_or_env(
+ values, "gitlab_branch", "GITLAB_BRANCH", default="main"
+ )
+ gitlab_base_branch = get_from_dict_or_env(
+ values, "gitlab_base_branch", "GITLAB_BASE_BRANCH", default="main"
+ )
+
+ try:
+ import gitlab
+
+ except ImportError:
+ raise ImportError(
+ "python-gitlab is not installed. "
+ "Please install it with `pip install python-gitlab`"
+ )
+
+ g = gitlab.Gitlab(
+ url=gitlab_url,
+ private_token=gitlab_personal_access_token,
+ keep_base_url=True,
+ )
+
+ g.auth()
+
+ values["gitlab"] = g
+ values["gitlab_repo_instance"] = g.projects.get(gitlab_repository)
+ values["gitlab_repository"] = gitlab_repository
+ values["gitlab_personal_access_token"] = gitlab_personal_access_token
+ values["gitlab_branch"] = gitlab_branch
+ values["gitlab_base_branch"] = gitlab_base_branch
+
+ return values
+
+ def parse_issues(self, issues: List[Issue]) -> List[dict]:
+ """
+ Extracts title and number from each Issue and puts them in a dictionary
+ Parameters:
+ issues(List[Issue]): A list of gitlab Issue objects
+ Returns:
+            List[dict]: A list of dictionaries, each with an issue's title and number
+ """
+ parsed = []
+ for issue in issues:
+ title = issue.title
+ number = issue.iid
+ parsed.append({"title": title, "number": number})
+ return parsed
+
+ def get_issues(self) -> str:
+ """
+ Fetches all open issues from the repo
+
+ Returns:
+ str: A plaintext report containing the number of issues
+ and each issue's title and number.
+ """
+ issues = self.gitlab_repo_instance.issues.list(state="opened")
+ if len(issues) > 0:
+ parsed_issues = self.parse_issues(issues)
+ parsed_issues_str = (
+ "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
+ )
+ return parsed_issues_str
+ else:
+ return "No open issues available"
+
+ def get_issue(self, issue_number: int) -> Dict[str, Any]:
+ """
+ Fetches a specific issue and its first 10 comments
+ Parameters:
+ issue_number(int): The number for the gitlab issue
+ Returns:
+ dict: A dictionary containing the issue's title,
+ body, and comments as a string
+ """
+ issue = self.gitlab_repo_instance.issues.get(issue_number)
+ page = 0
+ comments: List[dict] = []
+ while len(comments) <= 10:
+ comments_page = issue.notes.list(page=page)
+ if len(comments_page) == 0:
+ break
+ for comment in comments_page:
+ comment = issue.notes.get(comment.id)
+ comments.append(
+ {"body": comment.body, "user": comment.author["username"]}
+ )
+ page += 1
+
+ return {
+ "title": issue.title,
+ "body": issue.description,
+ "comments": str(comments),
+ }
+
+ def create_pull_request(self, pr_query: str) -> str:
+ """
+ Makes a pull request from the bot's branch to the base branch
+ Parameters:
+ pr_query(str): a string which contains the PR title
+ and the PR body. The title is the first line
+ in the string, and the body are the rest of the string.
+ For example, "Updated README\nmade changes to add info"
+ Returns:
+ str: A success or failure message
+ """
+ if self.gitlab_base_branch == self.gitlab_branch:
+ return """Cannot make a pull request because
+ commits are already in the master branch"""
+ else:
+ try:
+ title = pr_query.split("\n")[0]
+ body = pr_query[len(title) + 2 :]
+ pr = self.gitlab_repo_instance.mergerequests.create(
+ {
+ "source_branch": self.gitlab_branch,
+ "target_branch": self.gitlab_base_branch,
+ "title": title,
+ "description": body,
+ "labels": ["created-by-agent"],
+ }
+ )
+ return f"Successfully created PR number {pr.iid}"
+ except Exception as e:
+ return "Unable to make pull request due to error:\n" + str(e)
+
+ def comment_on_issue(self, comment_query: str) -> str:
+ """
+ Adds a comment to a gitlab issue
+ Parameters:
+ comment_query(str): a string which contains the issue number,
+ two newlines, and the comment.
+ for example: "1\n\nWorking on it now"
+ adds the comment "working on it now" to issue 1
+ Returns:
+ str: A success or failure message
+ """
+ issue_number = int(comment_query.split("\n\n")[0])
+ comment = comment_query[len(str(issue_number)) + 2 :]
+ try:
+ issue = self.gitlab_repo_instance.issues.get(issue_number)
+ issue.notes.create({"body": comment})
+ return "Commented on issue " + str(issue_number)
+ except Exception as e:
+ return "Unable to make comment due to error:\n" + str(e)
+
+ def create_file(self, file_query: str) -> str:
+ """
+ Creates a new file on the gitlab repo
+ Parameters:
+ file_query(str): a string which contains the file path
+ and the file contents. The file path is the first line
+ in the string, and the contents are the rest of the string.
+ For example, "hello_world.md\n# Hello World!"
+ Returns:
+ str: A success or failure message
+ """
+ file_path = file_query.split("\n")[0]
+ file_contents = file_query[len(file_path) + 2 :]
+ try:
+ self.gitlab_repo_instance.files.get(file_path, self.gitlab_branch)
+ return f"File already exists at {file_path}. Use update_file instead"
+ except Exception:
+ data = {
+ "branch": self.gitlab_branch,
+ "commit_message": "Create " + file_path,
+ "file_path": file_path,
+ "content": file_contents,
+ }
+
+ self.gitlab_repo_instance.files.create(data)
+
+ return "Created file " + file_path
+
+ def read_file(self, file_path: str) -> str:
+ """
+ Reads a file from the gitlab repo
+ Parameters:
+ file_path(str): the file path
+ Returns:
+ str: The file decoded as a string
+ """
+ file = self.gitlab_repo_instance.files.get(file_path, self.gitlab_branch)
+ return file.decode().decode("utf-8")
+
+ def update_file(self, file_query: str) -> str:
+ """
+ Updates a file with new content.
+ Parameters:
+ file_query(str): Contains the file path and the file contents.
+ The old file contents is wrapped in OLD <<<< and >>>> OLD
+ The new file contents is wrapped in NEW <<<< and >>>> NEW
+ For example:
+ test/hello.txt
+ OLD <<<<
+ Hello Earth!
+ >>>> OLD
+ NEW <<<<
+ Hello Mars!
+ >>>> NEW
+ Returns:
+ A success or failure message
+ """
+ try:
+ file_path = file_query.split("\n")[0]
+ old_file_contents = (
+ file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
+ )
+ new_file_contents = (
+ file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
+ )
+
+ file_content = self.read_file(file_path)
+ updated_file_content = file_content.replace(
+ old_file_contents, new_file_contents
+ )
+
+ if file_content == updated_file_content:
+ return (
+ "File content was not updated because old content was not found."
+ "It may be helpful to use the read_file action to get "
+ "the current file contents."
+ )
+
+ commit = {
+ "branch": self.gitlab_branch,
+ "commit_message": "Create " + file_path,
+ "actions": [
+ {
+ "action": "update",
+ "file_path": file_path,
+ "content": updated_file_content,
+ }
+ ],
+ }
+
+ self.gitlab_repo_instance.commits.create(commit)
+ return "Updated file " + file_path
+ except Exception as e:
+ return "Unable to update file due to error:\n" + str(e)
+
+ def delete_file(self, file_path: str) -> str:
+ """
+ Deletes a file from the repo
+ Parameters:
+ file_path(str): Where the file is
+ Returns:
+ str: Success or failure message
+ """
+ try:
+ self.gitlab_repo_instance.files.delete(
+ file_path, self.gitlab_branch, "Delete " + file_path
+ )
+ return "Deleted file " + file_path
+ except Exception as e:
+ return "Unable to delete file due to error:\n" + str(e)
+
+ def run(self, mode: str, query: str) -> str:
+ if mode == "get_issues":
+ return self.get_issues()
+ elif mode == "get_issue":
+ return json.dumps(self.get_issue(int(query)))
+ elif mode == "comment_on_issue":
+ return self.comment_on_issue(query)
+ elif mode == "create_file":
+ return self.create_file(query)
+ elif mode == "create_pull_request":
+ return self.create_pull_request(query)
+ elif mode == "read_file":
+ return self.read_file(query)
+ elif mode == "update_file":
+ return self.update_file(query)
+ elif mode == "delete_file":
+ return self.delete_file(query)
+ else:
+ raise ValueError("Invalid mode" + mode)
diff --git a/libs/community/langchain_community/utilities/golden_query.py b/libs/community/langchain_community/utilities/golden_query.py
new file mode 100644
index 00000000000..baeb3ddf96a
--- /dev/null
+++ b/libs/community/langchain_community/utilities/golden_query.py
@@ -0,0 +1,66 @@
+"""Util that calls Golden."""
+import json
+from typing import Dict, Optional
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+GOLDEN_BASE_URL = "https://golden.com"
+GOLDEN_TIMEOUT = 5000
+
+
+class GoldenQueryAPIWrapper(BaseModel):
+ """Wrapper for Golden.
+
+ Docs for using:
+
+ 1. Go to https://golden.com and sign up for an account
+ 2. Get your API Key from https://golden.com/settings/api
+ 3. Save your API Key into GOLDEN_API_KEY env variable
+
+ """
+
+ golden_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ golden_api_key = get_from_dict_or_env(
+ values, "golden_api_key", "GOLDEN_API_KEY"
+ )
+ values["golden_api_key"] = golden_api_key
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through Golden Query API and return the JSON raw result."""
+
+ headers = {"apikey": self.golden_api_key or ""}
+
+ response = requests.post(
+ f"{GOLDEN_BASE_URL}/api/v2/public/queries/",
+ json={"prompt": query},
+ headers=headers,
+ timeout=GOLDEN_TIMEOUT,
+ )
+ if response.status_code != 201:
+ return response.text
+
+ content = json.loads(response.content)
+ query_id = content["id"]
+
+ response = requests.get(
+ (
+ f"{GOLDEN_BASE_URL}/api/v2/public/queries/{query_id}/results/"
+ "?pageSize=10"
+ ),
+ headers=headers,
+ timeout=GOLDEN_TIMEOUT,
+ )
+ return response.text
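+
+
+# Illustrative usage sketch (assumes GOLDEN_API_KEY is set in the environment;
+# the query below is hypothetical):
+#
+#     golden = GoldenQueryAPIWrapper()
+#     print(golden.run("companies in nanotech"))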
diff --git a/libs/community/langchain_community/utilities/google_finance.py b/libs/community/langchain_community/utilities/google_finance.py
new file mode 100644
index 00000000000..7b6e3c58edf
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_finance.py
@@ -0,0 +1,97 @@
+"""Util that calls Google Finance Search."""
+from typing import Any, Dict, Optional, cast
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class GoogleFinanceAPIWrapper(BaseModel):
+ """Wrapper for SerpApi's Google Finance API
+ You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
+ The wrapper uses the SerpApi.com python package:
+ https://serpapi.com/integrations/python
+ To use, you should have the environment variable ``SERPAPI_API_KEY``
+ set with your API key, or pass `serp_api_key` as a named parameter
+ to the constructor.
+ Example:
+ .. code-block:: python
+ from langchain_community.utilities import GoogleFinanceAPIWrapper
+ google_Finance = GoogleFinanceAPIWrapper()
+ google_Finance.run('langchain')
+ """
+
+ serp_search_engine: Any
+ serp_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["serp_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "serp_api_key", "SERPAPI_API_KEY")
+ )
+
+ try:
+ from serpapi import SerpApiClient
+
+ except ImportError:
+ raise ImportError(
+ "google-search-results is not installed. "
+ "Please install it with `pip install google-search-results"
+ ">=2.4.2`"
+ )
+ serp_search_engine = SerpApiClient
+ values["serp_search_engine"] = serp_search_engine
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through Google Finance with Serpapi"""
+ serpapi_api_key = cast(SecretStr, self.serp_api_key)
+ params = {
+ "engine": "google_finance",
+ "api_key": serpapi_api_key.get_secret_value(),
+ "q": query,
+ }
+
+ total_results = {}
+ client = self.serp_search_engine(params)
+ total_results = client.get_dict()
+
+ if not total_results:
+ return "Nothing was found from the query: " + query
+
+ markets = total_results.get("markets", {})
+ res = "\nQuery: " + query + "\n"
+
+ if "futures_chain" in total_results:
+ futures_chain = total_results.get("futures_chain", [])[0]
+ stock = futures_chain["stock"]
+ price = futures_chain["price"]
+ temp = futures_chain["price_movement"]
+ percentage = temp["percentage"]
+ movement = temp["movement"]
+ res += (
+ f"stock: {stock}\n"
+ + f"price: {price}\n"
+ + f"percentage: {percentage}\n"
+ + f"movement: {movement}\n"
+ )
+
+ else:
+ res += "No summary information\n"
+
+ for key in markets:
+ if (key == "us") or (key == "asia") or (key == "europe"):
+ res += key
+ res += ": price = "
+ res += str(markets[key][0]["price"])
+ res += ", movement = "
+ res += markets[key][0]["price_movement"]["movement"]
+ res += "\n"
+
+ return res
diff --git a/libs/community/langchain_community/utilities/google_jobs.py b/libs/community/langchain_community/utilities/google_jobs.py
new file mode 100644
index 00000000000..2de550186da
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_jobs.py
@@ -0,0 +1,80 @@
+"""Util that calls Google Scholar Search."""
+from typing import Any, Dict, Optional, cast
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class GoogleJobsAPIWrapper(BaseModel):
+ """Wrapper for SerpApi's Google Scholar API
+ You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
+ The wrapper uses the SerpApi.com python package:
+ https://serpapi.com/integrations/python
+ To use, you should have the environment variable ``SERPAPI_API_KEY``
+ set with your API key, or pass `serp_api_key` as a named parameter
+ to the constructor.
+ Example:
+ .. code-block:: python
+ from langchain_community.utilities import GoogleJobsAPIWrapper
+ google_Jobs = GoogleJobsAPIWrapper()
+ google_Jobs.run('langchain')
+ """
+
+ serp_search_engine: Any
+ serp_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["serp_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "serp_api_key", "SERPAPI_API_KEY")
+ )
+
+ try:
+ from serpapi import SerpApiClient
+
+ except ImportError:
+ raise ImportError(
+ "google-search-results is not installed. "
+ "Please install it with `pip install google-search-results"
+ ">=2.4.2`"
+ )
+ serp_search_engine = SerpApiClient
+ values["serp_search_engine"] = serp_search_engine
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through Google Trends with Serpapi"""
+
+ # set up query
+ serpapi_api_key = cast(SecretStr, self.serp_api_key)
+ params = {
+ "engine": "google_jobs",
+ "api_key": serpapi_api_key.get_secret_value(),
+ "q": query,
+ }
+
+ total_results = []
+ client = self.serp_search_engine(params)
+ total_results = client.get_dict()["jobs_results"]
+
+ # extract 1 job info:
+ res_str = ""
+ for i in range(1):
+ job = total_results[i]
+ res_str += (
+ "\n_______________________________________________"
+ + f"\nJob Title: {job['title']}\n"
+ + f"Company Name: {job['company_name']}\n"
+ + f"Location: {job['location']}\n"
+ + f"Description: {job['description']}"
+ + "\n_______________________________________________\n"
+ )
+
+ return res_str + "\n"
diff --git a/libs/community/langchain_community/utilities/google_lens.py b/libs/community/langchain_community/utilities/google_lens.py
new file mode 100644
index 00000000000..f8419bdb758
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_lens.py
@@ -0,0 +1,84 @@
+"""Util that calls Google Lens Search."""
+from typing import Any, Dict, Optional, cast
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class GoogleLensAPIWrapper(BaseModel):
+ """Wrapper for SerpApi's Google Lens API
+
+ You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
+
+ The wrapper uses the SerpApi.com python package:
+ https://serpapi.com/integrations/python
+
+ To use, you should have the environment variable ``SERPAPI_API_KEY``
+ set with your API key, or pass `serp_api_key` as a named parameter
+ to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import GoogleLensAPIWrapper
+ google_lens = GoogleLensAPIWrapper()
+ google_lens.run('langchain')
+ """
+
+ serp_search_engine: Any
+ serp_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["serp_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "serp_api_key", "SERPAPI_API_KEY")
+ )
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through Google Trends with Serpapi"""
+ serpapi_api_key = cast(SecretStr, self.serp_api_key)
+
+ params = {
+ "engine": "google_lens",
+ "api_key": serpapi_api_key.get_secret_value(),
+ "url": query,
+ }
+ queryURL = f"https://serpapi.com/search?engine={params['engine']}&api_key={params['api_key']}&url={params['url']}"
+ response = requests.get(queryURL)
+
+ if response.status_code != 200:
+ return "Google Lens search failed"
+
+ responseValue = response.json()
+
+ if responseValue["search_metadata"]["status"] != "Success":
+ return "Google Lens search failed"
+
+ xs = ""
+ if len(responseValue["knowledge_graph"]) > 0:
+ subject = responseValue["knowledge_graph"][0]
+ xs += f"Subject:{subject['title']}({subject['subtitle']})\n"
+ xs += f"Link to subject:{subject['link']}\n\n"
+ xs += "Related Images:\n\n"
+ for image in responseValue["visual_matches"]:
+ xs += f"Title: {image['title']}\n"
+ xs += f"Source({image['source']}): {image['link']}\n"
+ xs += f"Image: {image['thumbnail']}\n\n"
+ xs += (
+ "Reverse Image Search"
+ + f"Link: {responseValue['reverse_image_search']['link']}\n"
+ )
+ print(xs)
+
+ docs = [xs]
+
+ return "\n\n".join(docs)
diff --git a/libs/community/langchain_community/utilities/google_places_api.py b/libs/community/langchain_community/utilities/google_places_api.py
new file mode 100644
index 00000000000..eb7ff148b6a
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_places_api.py
@@ -0,0 +1,114 @@
+"""Chain that calls Google Places API.
+"""
+
+import logging
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class GooglePlacesAPIWrapper(BaseModel):
+ """Wrapper around Google Places API.
+
+ To use, you should have the ``googlemaps`` python package installed,
+ **an API key for the google maps platform**,
+    and the environment variable ``GPLACES_API_KEY``
+    set with your API key, or pass `gplaces_api_key`
+ as a named parameter to the constructor.
+
+    By default, this will return all the results for the input query.
+ You can use the top_k_results argument to limit the number of results.
+
+ Example:
+ .. code-block:: python
+
+
+ from langchain_community.utilities import GooglePlacesAPIWrapper
+ gplaceapi = GooglePlacesAPIWrapper()
+ """
+
+ gplaces_api_key: Optional[str] = None
+ google_map_client: Any #: :meta private:
+ top_k_results: Optional[int] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key is in your environment variable."""
+ gplaces_api_key = get_from_dict_or_env(
+ values, "gplaces_api_key", "GPLACES_API_KEY"
+ )
+ values["gplaces_api_key"] = gplaces_api_key
+ try:
+ import googlemaps
+
+ values["google_map_client"] = googlemaps.Client(gplaces_api_key)
+ except ImportError:
+ raise ImportError(
+ "Could not import googlemaps python package. "
+ "Please install it with `pip install googlemaps`."
+ )
+ return values
+
+ def run(self, query: str) -> str:
+ """Run Places search and get k number of places that exists that match."""
+ search_results = self.google_map_client.places(query)["results"]
+ num_to_return = len(search_results)
+
+ places = []
+
+ if num_to_return == 0:
+ return "Google Places did not find any places that match the description"
+
+ num_to_return = (
+ num_to_return
+ if self.top_k_results is None
+ else min(num_to_return, self.top_k_results)
+ )
+
+ for i in range(num_to_return):
+ result = search_results[i]
+ details = self.fetch_place_details(result["place_id"])
+
+ if details is not None:
+ places.append(details)
+
+ return "\n".join([f"{i+1}. {item}" for i, item in enumerate(places)])
+
+ def fetch_place_details(self, place_id: str) -> Optional[str]:
+ try:
+ place_details = self.google_map_client.place(place_id)
+ place_details["place_id"] = place_id
+ formatted_details = self.format_place_details(place_details)
+ return formatted_details
+ except Exception as e:
+ logging.error(f"An Error occurred while fetching place details: {e}")
+ return None
+
+ def format_place_details(self, place_details: Dict[str, Any]) -> Optional[str]:
+ try:
+ name = place_details.get("result", {}).get("name", "Unknown")
+ address = place_details.get("result", {}).get(
+ "formatted_address", "Unknown"
+ )
+ phone_number = place_details.get("result", {}).get(
+ "formatted_phone_number", "Unknown"
+ )
+ website = place_details.get("result", {}).get("website", "Unknown")
+ place_id = place_details.get("result", {}).get("place_id", "Unknown")
+
+ formatted_details = (
+ f"{name}\nAddress: {address}\n"
+ f"Google place ID: {place_id}\n"
+ f"Phone: {phone_number}\nWebsite: {website}\n\n"
+ )
+ return formatted_details
+ except Exception as e:
+ logging.error(f"An error occurred while formatting place details: {e}")
+ return None
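+
+
+# Illustrative usage sketch (assumes GPLACES_API_KEY is set in the environment
+# and the googlemaps package is installed; the query below is hypothetical):
+#
+#     places = GooglePlacesAPIWrapper(top_k_results=3)
+#     print(places.run("coffee shops near Union Square, San Francisco"))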
diff --git a/libs/community/langchain_community/utilities/google_scholar.py b/libs/community/langchain_community/utilities/google_scholar.py
new file mode 100644
index 00000000000..49b2677d7f7
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_scholar.py
@@ -0,0 +1,129 @@
+"""Util that calls Google Scholar Search."""
+from typing import Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class GoogleScholarAPIWrapper(BaseModel):
+ """Wrapper for Google Scholar API
+
+ You can create serpapi key by signing up at: https://serpapi.com/users/sign_up.
+
+ The wrapper uses the serpapi python package:
+ https://serpapi.com/integrations/python#search-google-scholar
+
+ To use, you should have the environment variable ``SERP_API_KEY``
+ set with your API key, or pass `serp_api_key` as a named parameter
+ to the constructor.
+
+ Attributes:
+ top_k_results: number of results to return from google-scholar query search.
+ By default it returns top 10 results.
+ hl: attribute defines the language to use for the Google Scholar search.
+ It's a two-letter language code.
+ (e.g., en for English, es for Spanish, or fr for French). Head to the
+ Google languages page for a full list of supported Google languages:
+ https://serpapi.com/google-languages
+
+ lr: attribute defines one or multiple languages to limit the search to.
+ It uses lang_{two-letter language code} to specify languages
+ and | as a delimiter. (e.g., lang_fr|lang_de will only search French
+ and German pages). Head to the Google lr languages for a full
+ list of supported languages: https://serpapi.com/google-lr-languages
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import GoogleScholarAPIWrapper
+ google_scholar = GoogleScholarAPIWrapper()
+ google_scholar.run('langchain')
+ """
+
+ top_k_results: int = 10
+ hl: str = "en"
+ lr: str = "lang_en"
+ serp_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ serp_api_key = get_from_dict_or_env(values, "serp_api_key", "SERP_API_KEY")
+ values["SERP_API_KEY"] = serp_api_key
+
+ try:
+ from serpapi import GoogleScholarSearch
+
+ except ImportError:
+ raise ImportError(
+ "google-search-results is not installed. "
+ "Please install it with `pip install google-search-results"
+ ">=2.4.2`"
+ )
+ GoogleScholarSearch.SERP_API_KEY = serp_api_key
+ values["google_scholar_engine"] = GoogleScholarSearch
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through GoogleSearchScholar and parse result"""
+ total_results = []
+ page = 0
+ while page < max((self.top_k_results - 20), 1):
+ # We are getting 20 results from every page
+ # which is the max in order to reduce the number of API CALLS.
+ # 0 is the first page of results, 20 is the 2nd page of results,
+ # 40 is the 3rd page of results, etc.
+ results = (
+ self.google_scholar_engine( # type: ignore
+ {
+ "q": query,
+ "start": page,
+ "hl": self.hl,
+ "num": min(
+ self.top_k_results, 20
+ ), # if top_k_result is less than 20.
+ "lr": self.lr,
+ }
+ )
+ .get_dict()
+ .get("organic_results", [])
+ )
+ total_results.extend(results)
+ if not results: # No need to search for more pages if current page
+ # has returned no results
+ break
+ page += 20
+ if (
+ self.top_k_results % 20 != 0 and page > 20 and total_results
+ ): # From the last page we would only need top_k_results%20 results
+ # if k is not divisible by 20.
+ results = (
+ self.google_scholar_engine( # type: ignore
+ {
+ "q": query,
+ "start": page,
+ "num": self.top_k_results % 20,
+ "hl": self.hl,
+ "lr": self.lr,
+ }
+ )
+ .get_dict()
+ .get("organic_results", [])
+ )
+ total_results.extend(results)
+ if not total_results:
+ return "No good Google Scholar Result was found"
+ docs = [
+ f"Title: {result.get('title','')}\n"
+ f"Authors: {','.join([author.get('name') for author in result.get('publication_info',{}).get('authors',[])])}\n" # noqa: E501
+ f"Summary: {result.get('publication_info',{}).get('summary','')}\n"
+ f"Total-Citations: {result.get('inline_links',{}).get('cited_by',{}).get('total','')}" # noqa: E501
+ for result in total_results
+ ]
+ return "\n\n".join(docs)
diff --git a/libs/community/langchain_community/utilities/google_search.py b/libs/community/langchain_community/utilities/google_search.py
new file mode 100644
index 00000000000..5229f59cb3a
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_search.py
@@ -0,0 +1,137 @@
+"""Util that calls Google Search."""
+from typing import Any, Dict, List, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class GoogleSearchAPIWrapper(BaseModel):
+ """Wrapper for Google Search API.
+
+    Instructions adapted from https://stackoverflow.com/questions/
+ 37083058/
+ programmatically-searching-google-in-python-using-custom-search
+
+ TODO: DOCS for using it
+ 1. Install google-api-python-client
+ - If you don't already have a Google account, sign up.
+ - If you have never created a Google APIs Console project,
+ read the Managing Projects page and create a project in the Google API Console.
+ - Install the library using pip install google-api-python-client
+
+ 2. Enable the Custom Search API
+ - Navigate to the APIs & ServicesβDashboard panel in Cloud Console.
+ - Click Enable APIs and Services.
+ - Search for Custom Search API and click on it.
+ - Click Enable.
+    URL for it:
+    https://console.cloud.google.com/apis/library/customsearch.googleapis.com
+
+ 3. To create an API key:
+ - Navigate to the APIs & Services β Credentials panel in Cloud Console.
+ - Select Create credentials, then select API key from the drop-down menu.
+ - The API key created dialog box displays your newly created key.
+ - You now have an API_KEY
+
+ Alternatively, you can just generate an API key here:
+ https://developers.google.com/custom-search/docs/paid_element#api_key
+
+    4. Set up a Custom Search Engine so you can search the entire web
+    - Create a custom search engine here: https://programmablesearchengine.google.com/.
+    - In `What to search`, pick the `Search the entire Web` option.
+    After the search engine is created, you can click on it and find `Search engine ID`
+ on the Overview page.
+
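+    Example (illustrative; assumes ``GOOGLE_API_KEY`` and ``GOOGLE_CSE_ID``
+    are set in the environment and ``google-api-python-client`` is installed):
+        .. code-block:: python
+
+            from langchain_community.utilities import GoogleSearchAPIWrapper
+            google_search = GoogleSearchAPIWrapper()
+            google_search.run('langchain')
+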
+ """
+
+ search_engine: Any #: :meta private:
+ google_api_key: Optional[str] = None
+ google_cse_id: Optional[str] = None
+ k: int = 10
+ siterestrict: bool = False
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _google_search_results(self, search_term: str, **kwargs: Any) -> List[dict]:
+ cse = self.search_engine.cse()
+ if self.siterestrict:
+ cse = cse.siterestrict()
+ res = cse.list(q=search_term, cx=self.google_cse_id, **kwargs).execute()
+ return res.get("items", [])
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ google_api_key = get_from_dict_or_env(
+ values, "google_api_key", "GOOGLE_API_KEY"
+ )
+ values["google_api_key"] = google_api_key
+
+ google_cse_id = get_from_dict_or_env(values, "google_cse_id", "GOOGLE_CSE_ID")
+ values["google_cse_id"] = google_cse_id
+
+ try:
+ from googleapiclient.discovery import build
+
+ except ImportError:
+ raise ImportError(
+ "google-api-python-client is not installed. "
+ "Please install it with `pip install google-api-python-client"
+ ">=2.100.0`"
+ )
+
+ service = build("customsearch", "v1", developerKey=google_api_key)
+ values["search_engine"] = service
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through GoogleSearch and parse result."""
+ snippets = []
+ results = self._google_search_results(query, num=self.k)
+ if len(results) == 0:
+ return "No good Google Search Result was found"
+ for result in results:
+ if "snippet" in result:
+ snippets.append(result["snippet"])
+
+ return " ".join(snippets)
+
+ def results(
+ self,
+ query: str,
+ num_results: int,
+ search_params: Optional[Dict[str, str]] = None,
+ ) -> List[Dict]:
+ """Run query through GoogleSearch and return metadata.
+
+ Args:
+ query: The query to search for.
+ num_results: The number of results to return.
+ search_params: Parameters to be passed on search
+
+ Returns:
+ A list of dictionaries with the following keys:
+ snippet - The description of the result.
+ title - The title of the result.
+ link - The link to the result.
+ """
+ metadata_results = []
+ results = self._google_search_results(
+ query, num=num_results, **(search_params or {})
+ )
+ if len(results) == 0:
+ return [{"Result": "No good Google Search Result was found"}]
+ for result in results:
+ metadata_result = {
+ "title": result["title"],
+ "link": result["link"],
+ }
+ if "snippet" in result:
+ metadata_result["snippet"] = result["snippet"]
+ metadata_results.append(metadata_result)
+
+ return metadata_results
diff --git a/libs/community/langchain_community/utilities/google_serper.py b/libs/community/langchain_community/utilities/google_serper.py
new file mode 100644
index 00000000000..6701b56e584
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_serper.py
@@ -0,0 +1,192 @@
+"""Util that calls Google Search using the Serper.dev API."""
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from typing_extensions import Literal
+
+
+class GoogleSerperAPIWrapper(BaseModel):
+ """Wrapper around the Serper.dev Google Search API.
+
+ You can create a free API key at https://serper.dev.
+
+ To use, you should have the environment variable ``SERPER_API_KEY``
+ set with your API key, or pass `serper_api_key` as a named parameter
+ to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import GoogleSerperAPIWrapper
+ google_serper = GoogleSerperAPIWrapper()
+ """
+
+ k: int = 10
+ gl: str = "us"
+ hl: str = "en"
+ # "places" and "images" is available from Serper but not implemented in the
+ # parser of run(). They can be used in results()
+ type: Literal["news", "search", "places", "images"] = "search"
+ result_key_for_type = {
+ "news": "news",
+ "places": "places",
+ "images": "images",
+ "search": "organic",
+ }
+
+ tbs: Optional[str] = None
+ serper_api_key: Optional[str] = None
+ aiosession: Optional[aiohttp.ClientSession] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ serper_api_key = get_from_dict_or_env(
+ values, "serper_api_key", "SERPER_API_KEY"
+ )
+ values["serper_api_key"] = serper_api_key
+
+ return values
+
+ def results(self, query: str, **kwargs: Any) -> Dict:
+        """Run query through the Serper.dev API and return the raw results."""
+ return self._google_serper_api_results(
+ query,
+ gl=self.gl,
+ hl=self.hl,
+ num=self.k,
+ tbs=self.tbs,
+ search_type=self.type,
+ **kwargs,
+ )
+
+ def run(self, query: str, **kwargs: Any) -> str:
+        """Run query through the Serper.dev API and parse the result."""
+ results = self._google_serper_api_results(
+ query,
+ gl=self.gl,
+ hl=self.hl,
+ num=self.k,
+ tbs=self.tbs,
+ search_type=self.type,
+ **kwargs,
+ )
+
+ return self._parse_results(results)
+
+ async def aresults(self, query: str, **kwargs: Any) -> Dict:
+        """Run query through the Serper.dev API asynchronously."""
+ results = await self._async_google_serper_search_results(
+ query,
+ gl=self.gl,
+ hl=self.hl,
+ num=self.k,
+ search_type=self.type,
+ tbs=self.tbs,
+ **kwargs,
+ )
+ return results
+
+ async def arun(self, query: str, **kwargs: Any) -> str:
+        """Run query through the Serper.dev API and parse result asynchronously."""
+ results = await self._async_google_serper_search_results(
+ query,
+ gl=self.gl,
+ hl=self.hl,
+ num=self.k,
+ search_type=self.type,
+ tbs=self.tbs,
+ **kwargs,
+ )
+
+ return self._parse_results(results)
+
+ def _parse_snippets(self, results: dict) -> List[str]:
+ snippets = []
+
+ if results.get("answerBox"):
+ answer_box = results.get("answerBox", {})
+ if answer_box.get("answer"):
+ return [answer_box.get("answer")]
+ elif answer_box.get("snippet"):
+ return [answer_box.get("snippet").replace("\n", " ")]
+ elif answer_box.get("snippetHighlighted"):
+ return answer_box.get("snippetHighlighted")
+
+ if results.get("knowledgeGraph"):
+ kg = results.get("knowledgeGraph", {})
+ title = kg.get("title")
+ entity_type = kg.get("type")
+ if entity_type:
+ snippets.append(f"{title}: {entity_type}.")
+ description = kg.get("description")
+ if description:
+ snippets.append(description)
+ for attribute, value in kg.get("attributes", {}).items():
+ snippets.append(f"{title} {attribute}: {value}.")
+
+ for result in results[self.result_key_for_type[self.type]][: self.k]:
+ if "snippet" in result:
+ snippets.append(result["snippet"])
+ for attribute, value in result.get("attributes", {}).items():
+ snippets.append(f"{attribute}: {value}.")
+
+ if len(snippets) == 0:
+ return ["No good Google Search Result was found"]
+ return snippets
+
+ def _parse_results(self, results: dict) -> str:
+ return " ".join(self._parse_snippets(results))
+
+ def _google_serper_api_results(
+ self, search_term: str, search_type: str = "search", **kwargs: Any
+ ) -> dict:
+ headers = {
+ "X-API-KEY": self.serper_api_key or "",
+ "Content-Type": "application/json",
+ }
+ params = {
+ "q": search_term,
+ **{key: value for key, value in kwargs.items() if value is not None},
+ }
+ response = requests.post(
+ f"https://google.serper.dev/{search_type}", headers=headers, params=params
+ )
+ response.raise_for_status()
+ search_results = response.json()
+ return search_results
+
+ async def _async_google_serper_search_results(
+ self, search_term: str, search_type: str = "search", **kwargs: Any
+ ) -> dict:
+ headers = {
+ "X-API-KEY": self.serper_api_key or "",
+ "Content-Type": "application/json",
+ }
+ url = f"https://google.serper.dev/{search_type}"
+ params = {
+ "q": search_term,
+ **{key: value for key, value in kwargs.items() if value is not None},
+ }
+
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ url, params=params, headers=headers, raise_for_status=False
+ ) as response:
+ search_results = await response.json()
+ else:
+ async with self.aiosession.post(
+ url, params=params, headers=headers, raise_for_status=True
+ ) as response:
+ search_results = await response.json()
+
+ return search_results
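To complement the docstring example, here is a short sketch of how the sync and async entry points above fit together. The query text is illustrative and `SERPER_API_KEY` is assumed to be set:

```python
import asyncio

from langchain_community.utilities import GoogleSerperAPIWrapper

serper = GoogleSerperAPIWrapper(type="news", k=5)  # key read from SERPER_API_KEY
print(serper.run("langchain"))  # snippets parsed from the "news" results

raw = serper.results("langchain")  # raw Serper JSON, keyed per result_key_for_type
print(asyncio.run(serper.arun("langchain")))  # async variant of run()
```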
diff --git a/libs/community/langchain_community/utilities/google_trends.py b/libs/community/langchain_community/utilities/google_trends.py
new file mode 100644
index 00000000000..f0f15000c8a
--- /dev/null
+++ b/libs/community/langchain_community/utilities/google_trends.py
@@ -0,0 +1,116 @@
+"""Util that calls Google Scholar Search."""
+from typing import Any, Dict, Optional, cast
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+
+class GoogleTrendsAPIWrapper(BaseModel):
+ """Wrapper for SerpApi's Google Scholar API
+
+ You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
+
+ The wrapper uses the SerpApi.com python package:
+ https://serpapi.com/integrations/python
+
+ To use, you should have the environment variable ``SERPAPI_API_KEY``
+ set with your API key, or pass `serp_api_key` as a named parameter
+ to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import GoogleTrendsAPIWrapper
+ google_trends = GoogleTrendsAPIWrapper()
+ google_trends.run('langchain')
+ """
+
+ serp_search_engine: Any
+ serp_api_key: Optional[SecretStr] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ values["serp_api_key"] = convert_to_secret_str(
+ get_from_dict_or_env(values, "serp_api_key", "SERPAPI_API_KEY")
+ )
+
+ try:
+ from serpapi import SerpApiClient
+
+ except ImportError:
+ raise ImportError(
+ "google-search-results is not installed. "
+ "Please install it with `pip install google-search-results"
+ ">=2.4.2`"
+ )
+ serp_search_engine = SerpApiClient
+ values["serp_search_engine"] = serp_search_engine
+
+ return values
+
+ def run(self, query: str) -> str:
+        """Run query through Google Trends with SerpApi."""
+ serpapi_api_key = cast(SecretStr, self.serp_api_key)
+ params = {
+ "engine": "google_trends",
+ "api_key": serpapi_api_key.get_secret_value(),
+ "q": query,
+ }
+
+ total_results = []
+ client = self.serp_search_engine(params)
+ total_results = client.get_dict()["interest_over_time"]["timeline_data"]
+
+ if not total_results:
+ return "No good Trend Result was found"
+
+ start_date = total_results[0]["date"].split()
+ end_date = total_results[-1]["date"].split()
+ values = [
+ results.get("values")[0].get("extracted_value") for results in total_results
+ ]
+ min_value = min(values)
+ max_value = max(values)
+ avg_value = sum(values) / len(values)
+ percentage_change = (
+ (values[-1] - values[0])
+ / (values[0] if values[0] != 0 else 1)
+ * (100 if values[0] != 0 else 1)
+ )
+
+ params = {
+ "engine": "google_trends",
+ "api_key": serpapi_api_key.get_secret_value(),
+ "data_type": "RELATED_QUERIES",
+ "q": query,
+ }
+
+ total_results2 = {}
+ client = self.serp_search_engine(params)
+ total_results2 = client.get_dict().get("related_queries", {})
+ rising = []
+ top = []
+
+ rising = [results.get("query") for results in total_results2.get("rising", [])]
+ top = [results.get("query") for results in total_results2.get("top", [])]
+
+ doc = [
+ f"Query: {query}\n"
+ f"Date From: {start_date[0]} {start_date[1]}, {start_date[-1]}\n"
+ f"Date To: {end_date[0]} {end_date[3]} {end_date[-1]}\n"
+ f"Min Value: {min_value}\n"
+ f"Max Value: {max_value}\n"
+ f"Average Value: {avg_value}\n"
+ f"Percent Change: {str(percentage_change) + '%'}\n"
+ f"Trend values: {', '.join([str(x) for x in values])}\n"
+ f"Rising Related Queries: {', '.join(rising)}\n"
+ f"Top Related Queries: {', '.join(top)}"
+ ]
+
+ return "\n\n".join(doc)
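A brief usage sketch of the wrapper above. The query is illustrative, and `SERPAPI_API_KEY` is assumed to be set in the environment (or passed as `serp_api_key`):

```python
from langchain_community.utilities import GoogleTrendsAPIWrapper

trends = GoogleTrendsAPIWrapper()  # picks up SERPAPI_API_KEY from the environment

# Returns a single formatted block: date range, min/max/average values,
# percent change, and rising/top related queries.
print(trends.run("langchain"))
```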
diff --git a/libs/community/langchain_community/utilities/graphql.py b/libs/community/langchain_community/utilities/graphql.py
new file mode 100644
index 00000000000..87be94d09c3
--- /dev/null
+++ b/libs/community/langchain_community/utilities/graphql.py
@@ -0,0 +1,54 @@
+import json
+from typing import Any, Callable, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class GraphQLAPIWrapper(BaseModel):
+ """Wrapper around GraphQL API.
+
+ To use, you should have the ``gql`` python package installed.
+ This wrapper will use the GraphQL API to conduct queries.
+ """
+
+ custom_headers: Optional[Dict[str, str]] = None
+ graphql_endpoint: str
+ gql_client: Any #: :meta private:
+ gql_function: Callable[[str], Any] #: :meta private:
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in the environment."""
+ try:
+ from gql import Client, gql
+ from gql.transport.requests import RequestsHTTPTransport
+ except ImportError as e:
+ raise ImportError(
+ "Could not import gql python package. "
+ f"Try installing it with `pip install gql`. Received error: {e}"
+ )
+ headers = values.get("custom_headers")
+ transport = RequestsHTTPTransport(
+ url=values["graphql_endpoint"],
+ headers=headers,
+ )
+ client = Client(transport=transport, fetch_schema_from_transport=True)
+ values["gql_client"] = client
+ values["gql_function"] = gql
+ return values
+
+ def run(self, query: str) -> str:
+ """Run a GraphQL query and get the results."""
+ result = self._execute_query(query)
+ return json.dumps(result, indent=2)
+
+ def _execute_query(self, query: str) -> Dict[str, Any]:
+ """Execute a GraphQL query and return the results."""
+ document_node = self.gql_function(query)
+ result = self.gql_client.execute(document_node)
+ return result
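A minimal sketch of the GraphQL wrapper above; the endpoint URL and header values are placeholders, and `gql` must be installed:

```python
from langchain_community.utilities.graphql import GraphQLAPIWrapper

# The pre-validation hook builds the gql client, so only the endpoint (and any
# custom headers) needs to be supplied.
wrapper = GraphQLAPIWrapper(
    graphql_endpoint="https://example.com/graphql",  # placeholder endpoint
    custom_headers={"Authorization": "Bearer <token>"},  # placeholder token
)
print(wrapper.run("{ __schema { queryType { name } } }"))  # pretty-printed JSON
```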
diff --git a/libs/community/langchain_community/utilities/jira.py b/libs/community/langchain_community/utilities/jira.py
new file mode 100644
index 00000000000..4d6522d7fd9
--- /dev/null
+++ b/libs/community/langchain_community/utilities/jira.py
@@ -0,0 +1,174 @@
+"""Util that calls Jira."""
+from typing import Any, Dict, List, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+# TODO: think about error handling, more specific api specs, and jql/project limits
+class JiraAPIWrapper(BaseModel):
+ """Wrapper for Jira API."""
+
+ jira: Any #: :meta private:
+ confluence: Any
+ jira_username: Optional[str] = None
+ jira_api_token: Optional[str] = None
+ jira_instance_url: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ jira_username = get_from_dict_or_env(values, "jira_username", "JIRA_USERNAME")
+ values["jira_username"] = jira_username
+
+ jira_api_token = get_from_dict_or_env(
+ values, "jira_api_token", "JIRA_API_TOKEN"
+ )
+ values["jira_api_token"] = jira_api_token
+
+ jira_instance_url = get_from_dict_or_env(
+ values, "jira_instance_url", "JIRA_INSTANCE_URL"
+ )
+ values["jira_instance_url"] = jira_instance_url
+
+ try:
+ from atlassian import Confluence, Jira
+ except ImportError:
+ raise ImportError(
+ "atlassian-python-api is not installed. "
+ "Please install it with `pip install atlassian-python-api`"
+ )
+
+ jira = Jira(
+ url=jira_instance_url,
+ username=jira_username,
+ password=jira_api_token,
+ cloud=True,
+ )
+
+ confluence = Confluence(
+ url=jira_instance_url,
+ username=jira_username,
+ password=jira_api_token,
+ cloud=True,
+ )
+
+ values["jira"] = jira
+ values["confluence"] = confluence
+
+ return values
+
+ def parse_issues(self, issues: Dict) -> List[dict]:
+ parsed = []
+ for issue in issues["issues"]:
+ key = issue["key"]
+ summary = issue["fields"]["summary"]
+ created = issue["fields"]["created"][0:10]
+ priority = issue["fields"]["priority"]["name"]
+ status = issue["fields"]["status"]["name"]
+ try:
+ assignee = issue["fields"]["assignee"]["displayName"]
+ except Exception:
+ assignee = "None"
+ rel_issues = {}
+ for related_issue in issue["fields"]["issuelinks"]:
+ if "inwardIssue" in related_issue.keys():
+ rel_type = related_issue["type"]["inward"]
+ rel_key = related_issue["inwardIssue"]["key"]
+ rel_summary = related_issue["inwardIssue"]["fields"]["summary"]
+ if "outwardIssue" in related_issue.keys():
+ rel_type = related_issue["type"]["outward"]
+ rel_key = related_issue["outwardIssue"]["key"]
+ rel_summary = related_issue["outwardIssue"]["fields"]["summary"]
+ rel_issues = {"type": rel_type, "key": rel_key, "summary": rel_summary}
+ parsed.append(
+ {
+ "key": key,
+ "summary": summary,
+ "created": created,
+ "assignee": assignee,
+ "priority": priority,
+ "status": status,
+ "related_issues": rel_issues,
+ }
+ )
+ return parsed
+
+ def parse_projects(self, projects: List[dict]) -> List[dict]:
+ parsed = []
+ for project in projects:
+ id = project["id"]
+ key = project["key"]
+ name = project["name"]
+ type = project["projectTypeKey"]
+ style = project["style"]
+ parsed.append(
+ {"id": id, "key": key, "name": name, "type": type, "style": style}
+ )
+ return parsed
+
+ def search(self, query: str) -> str:
+ issues = self.jira.jql(query)
+ parsed_issues = self.parse_issues(issues)
+ parsed_issues_str = (
+ "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
+ )
+ return parsed_issues_str
+
+ def project(self) -> str:
+ projects = self.jira.projects()
+ parsed_projects = self.parse_projects(projects)
+ parsed_projects_str = (
+ "Found " + str(len(parsed_projects)) + " projects:\n" + str(parsed_projects)
+ )
+ return parsed_projects_str
+
+ def issue_create(self, query: str) -> str:
+ try:
+ import json
+ except ImportError:
+ raise ImportError(
+ "json is not installed. Please install it with `pip install json`"
+ )
+ params = json.loads(query)
+ return self.jira.issue_create(fields=dict(params))
+
+ def page_create(self, query: str) -> str:
+ try:
+ import json
+ except ImportError:
+ raise ImportError(
+ "json is not installed. Please install it with `pip install json`"
+ )
+ params = json.loads(query)
+ return self.confluence.create_page(**dict(params))
+
+ def other(self, query: str) -> str:
+ try:
+ import json
+ except ImportError:
+ raise ImportError(
+ "json is not installed. Please install it with `pip install json`"
+ )
+ params = json.loads(query)
+ jira_function = getattr(self.jira, params["function"])
+ return jira_function(*params.get("args", []), **params.get("kwargs", {}))
+
+ def run(self, mode: str, query: str) -> str:
+ if mode == "jql":
+ return self.search(query)
+ elif mode == "get_projects":
+ return self.project()
+ elif mode == "create_issue":
+ return self.issue_create(query)
+ elif mode == "other":
+ return self.other(query)
+ elif mode == "create_page":
+ return self.page_create(query)
+ else:
+ raise ValueError(f"Got unexpected mode {mode}")
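A usage sketch for the Jira wrapper above; the project key and issue fields are placeholders, and `JIRA_USERNAME`, `JIRA_API_TOKEN`, and `JIRA_INSTANCE_URL` are assumed to be set:

```python
import json

from langchain_community.utilities.jira import JiraAPIWrapper

jira = JiraAPIWrapper()  # credentials and instance URL resolved by the validator

print(jira.run("jql", 'project = "TEST" AND status = "In Progress"'))
print(jira.run("get_projects", ""))

# create_issue expects a JSON string of Jira issue fields (placeholder values).
fields = {
    "summary": "Example issue",
    "project": {"key": "TEST"},
    "issuetype": {"name": "Task"},
}
print(jira.run("create_issue", json.dumps(fields)))
```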
diff --git a/libs/community/langchain_community/utilities/max_compute.py b/libs/community/langchain_community/utilities/max_compute.py
new file mode 100644
index 00000000000..3f6441803e0
--- /dev/null
+++ b/libs/community/langchain_community/utilities/max_compute.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterator, List, Optional
+
+from langchain_core.utils import get_from_env
+
+if TYPE_CHECKING:
+ from odps import ODPS
+
+
+class MaxComputeAPIWrapper:
+ """Interface for querying Alibaba Cloud MaxCompute tables."""
+
+ def __init__(self, client: ODPS):
+        """Initialize the MaxCompute API wrapper.
+
+ Args:
+ client: odps.ODPS MaxCompute client object.
+ """
+ self.client = client
+
+ @classmethod
+ def from_params(
+ cls,
+ endpoint: str,
+ project: str,
+ *,
+ access_id: Optional[str] = None,
+ secret_access_key: Optional[str] = None,
+ ) -> MaxComputeAPIWrapper:
+        """Convenience constructor that builds the odps.ODPS MaxCompute client from
+ given parameters.
+
+ Args:
+ endpoint: MaxCompute endpoint.
+ project: A project is a basic organizational unit of MaxCompute, which is
+ similar to a database.
+ access_id: MaxCompute access ID. Should be passed in directly or set as the
+ environment variable `MAX_COMPUTE_ACCESS_ID`.
+ secret_access_key: MaxCompute secret access key. Should be passed in
+ directly or set as the environment variable
+ `MAX_COMPUTE_SECRET_ACCESS_KEY`.
+ """
+ try:
+ from odps import ODPS
+ except ImportError as ex:
+ raise ImportError(
+ "Could not import pyodps python package. "
+ "Please install it with `pip install pyodps` or refer to "
+ "https://pyodps.readthedocs.io/."
+ ) from ex
+ access_id = access_id or get_from_env("access_id", "MAX_COMPUTE_ACCESS_ID")
+ secret_access_key = secret_access_key or get_from_env(
+ "secret_access_key", "MAX_COMPUTE_SECRET_ACCESS_KEY"
+ )
+ client = ODPS(
+ access_id=access_id,
+ secret_access_key=secret_access_key,
+ project=project,
+ endpoint=endpoint,
+ )
+ if not client.exist_project(project):
+ raise ValueError(f'The project "{project}" does not exist.')
+
+ return cls(client)
+
+ def lazy_query(self, query: str) -> Iterator[dict]:
+ # Execute SQL query.
+ with self.client.execute_sql(query).open_reader() as reader:
+ if reader.count == 0:
+ raise ValueError("Table contains no data.")
+ for record in reader:
+ yield {k: v for k, v in record}
+
+ def query(self, query: str) -> List[dict]:
+ return list(self.lazy_query(query))
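A sketch of the MaxCompute wrapper above; the endpoint, project, and table names are placeholders, and credentials fall back to the `MAX_COMPUTE_*` environment variables when not passed explicitly:

```python
from langchain_community.utilities.max_compute import MaxComputeAPIWrapper

wrapper = MaxComputeAPIWrapper.from_params(
    endpoint="http://service.cn-hangzhou.maxcompute.aliyun.com/api",  # placeholder
    project="my_project",  # placeholder
)

rows = wrapper.query("SELECT * FROM my_table LIMIT 10")  # list of row dicts
for row in wrapper.lazy_query("SELECT * FROM my_table"):  # streamed, row by row
    print(row)
```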
diff --git a/libs/community/langchain_community/utilities/merriam_webster.py b/libs/community/langchain_community/utilities/merriam_webster.py
new file mode 100644
index 00000000000..4904d01faa2
--- /dev/null
+++ b/libs/community/langchain_community/utilities/merriam_webster.py
@@ -0,0 +1,107 @@
+"""Util that calls Merriam-Webster."""
+import json
+from typing import Dict, Iterator, List, Optional
+from urllib.parse import quote
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+MERRIAM_WEBSTER_API_URL = (
+ "https://www.dictionaryapi.com/api/v3/references/collegiate/json"
+)
+MERRIAM_WEBSTER_TIMEOUT = 5000
+
+
+class MerriamWebsterAPIWrapper(BaseModel):
+ """Wrapper for Merriam-Webster.
+
+ Docs for using:
+
+    1. Go to https://www.dictionaryapi.com/register/index and register a
+       developer account with a key for the Collegiate Dictionary
+ 2. Get your API Key from https://www.dictionaryapi.com/account/my-keys
+ 3. Save your API Key into MERRIAM_WEBSTER_API_KEY env variable
+
+ """
+
+ merriam_webster_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ merriam_webster_api_key = get_from_dict_or_env(
+ values, "merriam_webster_api_key", "MERRIAM_WEBSTER_API_KEY"
+ )
+ values["merriam_webster_api_key"] = merriam_webster_api_key
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through Merriam-Webster API and return a formatted result."""
+ quoted_query = quote(query)
+
+ request_url = (
+ f"{MERRIAM_WEBSTER_API_URL}/{quoted_query}"
+ f"?key={self.merriam_webster_api_key}"
+ )
+
+ response = requests.get(request_url, timeout=MERRIAM_WEBSTER_TIMEOUT)
+
+ if response.status_code != 200:
+ return response.text
+
+ return self._format_response(query, response)
+
+ def _format_response(self, query: str, response: requests.Response) -> str:
+ content = json.loads(response.content)
+
+ if not content:
+ return f"No Merriam-Webster definition was found for query '{query}'."
+
+ if isinstance(content[0], str):
+ result = f"No Merriam-Webster definition was found for query '{query}'.\n"
+ if len(content) > 1:
+ alternatives = [f"{i + 1}. {content[i]}" for i in range(len(content))]
+ result += "You can try one of the following alternative queries:\n\n"
+ result += "\n".join(alternatives)
+ else:
+ result += f"Did you mean '{content[0]}'?"
+ else:
+ result = self._format_definitions(query, content)
+
+ return result
+
+ def _format_definitions(self, query: str, definitions: List[Dict]) -> str:
+ formatted_definitions: List[str] = []
+ for definition in definitions:
+ formatted_definitions.extend(self._format_definition(definition))
+
+ if len(formatted_definitions) == 1:
+ return f"Definition of '{query}':\n" f"{formatted_definitions[0]}"
+
+ result = f"Definitions of '{query}':\n\n"
+ for i, formatted_definition in enumerate(formatted_definitions, 1):
+ result += f"{i}. {formatted_definition}\n"
+
+ return result
+
+ def _format_definition(self, definition: Dict) -> Iterator[str]:
+ if "hwi" in definition:
+ headword = definition["hwi"]["hw"].replace("*", "-")
+ else:
+ headword = definition["meta"]["id"].split(":")[0]
+
+        functional_label = definition.get("fl")
+
+ if "shortdef" in definition:
+ for short_def in definition["shortdef"]:
+ yield f"{headword}, {functional_label}: {short_def}"
+ else:
+ yield f"{headword}, {functional_label}"
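A short sketch of the dictionary wrapper above (the query is illustrative, and `MERRIAM_WEBSTER_API_KEY` is assumed to be set):

```python
from langchain_community.utilities.merriam_webster import MerriamWebsterAPIWrapper

dictionary = MerriamWebsterAPIWrapper()  # key read from MERRIAM_WEBSTER_API_KEY

# Returns numbered short definitions, or suggested alternative queries when the
# word is not found.
print(dictionary.run("recursion"))
```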
diff --git a/libs/community/langchain_community/utilities/metaphor_search.py b/libs/community/langchain_community/utilities/metaphor_search.py
new file mode 100644
index 00000000000..270ce8a5e43
--- /dev/null
+++ b/libs/community/langchain_community/utilities/metaphor_search.py
@@ -0,0 +1,171 @@
+"""Util that calls Metaphor Search API.
+
+In order to set this up, follow instructions at:
+"""
+import json
+from typing import Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+METAPHOR_API_URL = "https://api.metaphor.systems"
+
+
+class MetaphorSearchAPIWrapper(BaseModel):
+ """Wrapper for Metaphor Search API."""
+
+ metaphor_api_key: str
+ k: int = 10
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _metaphor_search_results(
+ self,
+ query: str,
+ num_results: int,
+ include_domains: Optional[List[str]] = None,
+ exclude_domains: Optional[List[str]] = None,
+ start_crawl_date: Optional[str] = None,
+ end_crawl_date: Optional[str] = None,
+ start_published_date: Optional[str] = None,
+ end_published_date: Optional[str] = None,
+ use_autoprompt: Optional[bool] = None,
+ ) -> List[dict]:
+ headers = {"X-Api-Key": self.metaphor_api_key}
+ params = {
+ "numResults": num_results,
+ "query": query,
+ "includeDomains": include_domains,
+ "excludeDomains": exclude_domains,
+ "startCrawlDate": start_crawl_date,
+ "endCrawlDate": end_crawl_date,
+ "startPublishedDate": start_published_date,
+ "endPublishedDate": end_published_date,
+ "useAutoprompt": use_autoprompt,
+ }
+ response = requests.post(
+ # type: ignore
+ f"{METAPHOR_API_URL}/search",
+ headers=headers,
+ json=params,
+ )
+
+ response.raise_for_status()
+ search_results = response.json()
+ return search_results["results"]
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ metaphor_api_key = get_from_dict_or_env(
+ values, "metaphor_api_key", "METAPHOR_API_KEY"
+ )
+ values["metaphor_api_key"] = metaphor_api_key
+
+ return values
+
+ def results(
+ self,
+ query: str,
+ num_results: int,
+ include_domains: Optional[List[str]] = None,
+ exclude_domains: Optional[List[str]] = None,
+ start_crawl_date: Optional[str] = None,
+ end_crawl_date: Optional[str] = None,
+ start_published_date: Optional[str] = None,
+ end_published_date: Optional[str] = None,
+ use_autoprompt: Optional[bool] = None,
+ ) -> List[Dict]:
+ """Run query through Metaphor Search and return metadata.
+
+ Args:
+ query: The query to search for.
+ num_results: The number of results to return.
+ include_domains: A list of domains to include in the search. Only one of include_domains and exclude_domains should be defined.
+ exclude_domains: A list of domains to exclude from the search. Only one of include_domains and exclude_domains should be defined.
+ start_crawl_date: If specified, only pages we crawled after start_crawl_date will be returned.
+ end_crawl_date: If specified, only pages we crawled before end_crawl_date will be returned.
+ start_published_date: If specified, only pages published after start_published_date will be returned.
+ end_published_date: If specified, only pages published before end_published_date will be returned.
+ use_autoprompt: If true, we turn your query into a more Metaphor-friendly query. Adds latency.
+
+ Returns:
+ A list of dictionaries with the following keys:
+ title - The title of the page
+ url - The url
+ author - Author of the content, if applicable. Otherwise, None.
+ published_date - Estimated date published
+ in YYYY-MM-DD format. Otherwise, None.
+ """ # noqa: E501
+ raw_search_results = self._metaphor_search_results(
+ query,
+ num_results=num_results,
+ include_domains=include_domains,
+ exclude_domains=exclude_domains,
+ start_crawl_date=start_crawl_date,
+ end_crawl_date=end_crawl_date,
+ start_published_date=start_published_date,
+ end_published_date=end_published_date,
+ use_autoprompt=use_autoprompt,
+ )
+ return self._clean_results(raw_search_results)
+
+ async def results_async(
+ self,
+ query: str,
+ num_results: int,
+ include_domains: Optional[List[str]] = None,
+ exclude_domains: Optional[List[str]] = None,
+ start_crawl_date: Optional[str] = None,
+ end_crawl_date: Optional[str] = None,
+ start_published_date: Optional[str] = None,
+ end_published_date: Optional[str] = None,
+ use_autoprompt: Optional[bool] = None,
+ ) -> List[Dict]:
+ """Get results from the Metaphor Search API asynchronously."""
+
+ # Function to perform the API call
+ async def fetch() -> str:
+ headers = {"X-Api-Key": self.metaphor_api_key}
+ params = {
+ "numResults": num_results,
+ "query": query,
+ "includeDomains": include_domains,
+ "excludeDomains": exclude_domains,
+ "startCrawlDate": start_crawl_date,
+ "endCrawlDate": end_crawl_date,
+ "startPublishedDate": start_published_date,
+ "endPublishedDate": end_published_date,
+ "useAutoprompt": use_autoprompt,
+ }
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ f"{METAPHOR_API_URL}/search", json=params, headers=headers
+ ) as res:
+ if res.status == 200:
+ data = await res.text()
+ return data
+ else:
+ raise Exception(f"Error {res.status}: {res.reason}")
+
+ results_json_str = await fetch()
+ results_json = json.loads(results_json_str)
+ return self._clean_results(results_json["results"])
+
+ def _clean_results(self, raw_search_results: List[Dict]) -> List[Dict]:
+ cleaned_results = []
+ for result in raw_search_results:
+ cleaned_results.append(
+ {
+ "title": result.get("title", "Unknown Title"),
+ "url": result.get("url", "Unknown URL"),
+ "author": result.get("author", "Unknown Author"),
+ "published_date": result.get("publishedDate", "Unknown Date"),
+ }
+ )
+ return cleaned_results
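A sketch of the Metaphor wrapper above. The key is resolved from `METAPHOR_API_KEY` by the pre-validator; the query and filters are illustrative:

```python
from langchain_community.utilities.metaphor_search import MetaphorSearchAPIWrapper

metaphor = MetaphorSearchAPIWrapper()  # metaphor_api_key filled from METAPHOR_API_KEY

hits = metaphor.results(
    "interesting articles about AI agents",
    num_results=5,
    use_autoprompt=True,
)
for hit in hits:
    print(hit["title"], hit["url"], hit["published_date"])
```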
diff --git a/libs/community/langchain_community/utilities/nasa.py b/libs/community/langchain_community/utilities/nasa.py
new file mode 100644
index 00000000000..b58889ca0de
--- /dev/null
+++ b/libs/community/langchain_community/utilities/nasa.py
@@ -0,0 +1,51 @@
+"""Util that calls several NASA APIs."""
+import json
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel
+
+IMAGE_AND_VIDEO_LIBRARY_URL = "https://images-api.nasa.gov"
+
+
+class NasaAPIWrapper(BaseModel):
+ def get_media(self, query: str) -> str:
+ params = json.loads(query)
+ if params.get("q"):
+ queryText = params["q"]
+ params.pop("q")
+ else:
+ queryText = ""
+ response = requests.get(
+ IMAGE_AND_VIDEO_LIBRARY_URL + "/search?q=" + queryText, params=params
+ )
+ data = response.json()
+ return data
+
+ def get_media_metadata_manifest(self, query: str) -> str:
+ response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/asset/" + query)
+ return response.json()
+
+ def get_media_metadata_location(self, query: str) -> str:
+ response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/metadata/" + query)
+ return response.json()
+
+ def get_video_captions_location(self, query: str) -> str:
+ response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/captions/" + query)
+ return response.json()
+
+ def run(self, mode: str, query: str) -> str:
+ if mode == "search_media":
+ output = self.get_media(query)
+ elif mode == "get_media_metadata_manifest":
+ output = self.get_media_metadata_manifest(query)
+ elif mode == "get_media_metadata_location":
+ output = self.get_media_metadata_location(query)
+ elif mode == "get_video_captions_location":
+ output = self.get_video_captions_location(query)
+ else:
+ output = f"ModeError: Got unexpected mode {mode}."
+
+ try:
+ return json.dumps(output)
+ except Exception:
+ return str(output)
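A usage sketch for the NASA wrapper above; the image/video library endpoints need no API key, and the query values and asset id below are illustrative:

```python
import json

from langchain_community.utilities.nasa import NasaAPIWrapper

nasa = NasaAPIWrapper()

# search_media expects a JSON string of query parameters for images-api.nasa.gov.
print(nasa.run("search_media", json.dumps({"q": "apollo 11", "media_type": "image"})))

# The other modes take a NASA asset id (placeholder id shown here).
print(nasa.run("get_media_metadata_manifest", "as11-40-5874"))
```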
diff --git a/libs/community/langchain_community/utilities/opaqueprompts.py b/libs/community/langchain_community/utilities/opaqueprompts.py
new file mode 100644
index 00000000000..9473fd9e102
--- /dev/null
+++ b/libs/community/langchain_community/utilities/opaqueprompts.py
@@ -0,0 +1,102 @@
+from typing import Dict, Union
+
+
+def sanitize(
+ input: Union[str, Dict[str, str]],
+) -> Dict[str, Union[str, Dict[str, str]]]:
+ """
+ Sanitize input string or dict of strings by replacing sensitive data with
+ placeholders.
+
+ It returns the sanitized input string or dict of strings and the secure
+ context as a dict following the format:
+ {
+        "sanitized_input": <sanitized input string or dict of strings>,
+        "secure_context": <secure context bytes>
+ }
+
+ The secure context is a bytes object that is needed to de-sanitize the response
+ from the LLM.
+
+ Args:
+ input: Input string or dict of strings.
+
+ Returns:
+ Sanitized input string or dict of strings and the secure context
+ as a dict following the format:
+ {
+            "sanitized_input": <sanitized input string or dict of strings>,
+            "secure_context": <secure context bytes>
+ }
+
+ The `secure_context` needs to be passed to the `desanitize` function.
+
+ Raises:
+ ValueError: If the input is not a string or dict of strings.
+ ImportError: If the `opaqueprompts` Python package is not installed.
+ """
+ try:
+ import opaqueprompts as op
+ except ImportError:
+ raise ImportError(
+ "Could not import the `opaqueprompts` Python package, "
+ "please install it with `pip install opaqueprompts`."
+ )
+
+ if isinstance(input, str):
+ # the input could be a string, so we sanitize the string
+ sanitize_response: op.SanitizeResponse = op.sanitize([input])
+ return {
+ "sanitized_input": sanitize_response.sanitized_texts[0],
+ "secure_context": sanitize_response.secure_context,
+ }
+
+ if isinstance(input, dict):
+ # the input could be a dict[string, string], so we sanitize the values
+ values = list()
+
+ # get the values from the dict
+ for key in input:
+ values.append(input[key])
+
+ # sanitize the values
+ sanitize_values_response: op.SanitizeResponse = op.sanitize(values)
+
+ # reconstruct the dict with the sanitized values
+ sanitized_input_values = sanitize_values_response.sanitized_texts
+ idx = 0
+ sanitized_input = dict()
+ for key in input:
+ sanitized_input[key] = sanitized_input_values[idx]
+ idx += 1
+
+ return {
+ "sanitized_input": sanitized_input,
+ "secure_context": sanitize_values_response.secure_context,
+ }
+
+ raise ValueError(f"Unexpected input type {type(input)}")
+
+
+def desanitize(sanitized_text: str, secure_context: bytes) -> str:
+ """
+ Restore the original sensitive data from the sanitized text.
+
+ Args:
+ sanitized_text: Sanitized text.
+ secure_context: Secure context returned by the `sanitize` function.
+
+ Returns:
+ De-sanitized text.
+ """
+ try:
+ import opaqueprompts as op
+ except ImportError:
+ raise ImportError(
+ "Could not import the `opaqueprompts` Python package, "
+ "please install it with `pip install opaqueprompts`."
+ )
+ desanitize_response: op.DesanitizeResponse = op.desanitize(
+ sanitized_text, secure_context
+ )
+ return desanitize_response.desanitized_text
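A sketch of the round trip through the two helpers above (requires `pip install opaqueprompts`; the input text is illustrative):

```python
from langchain_community.utilities.opaqueprompts import desanitize, sanitize

result = sanitize("My name is Jane Doe and my email is jane@example.com")
safe_text = result["sanitized_input"]  # PII replaced with placeholders
ctx = result["secure_context"]  # needed later to restore the original values

# ... send safe_text to an LLM, then de-sanitize the response (or, as here,
# the sanitized text itself) with the stored context:
print(desanitize(safe_text, ctx))
```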
diff --git a/libs/community/langchain_community/utilities/openapi.py b/libs/community/langchain_community/utilities/openapi.py
new file mode 100644
index 00000000000..71263c369d5
--- /dev/null
+++ b/libs/community/langchain_community/utilities/openapi.py
@@ -0,0 +1,314 @@
+"""Utility functions for parsing an OpenAPI spec."""
+from __future__ import annotations
+
+import copy
+import json
+import logging
+import re
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import requests
+import yaml
+from langchain_core.pydantic_v1 import ValidationError
+
+logger = logging.getLogger(__name__)
+
+
+class HTTPVerb(str, Enum):
+ """Enumerator of the HTTP verbs."""
+
+ GET = "get"
+ PUT = "put"
+ POST = "post"
+ DELETE = "delete"
+ OPTIONS = "options"
+ HEAD = "head"
+ PATCH = "patch"
+ TRACE = "trace"
+
+ @classmethod
+ def from_str(cls, verb: str) -> HTTPVerb:
+ """Parse an HTTP verb."""
+ try:
+ return cls(verb)
+ except ValueError:
+ raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
+
+
+if TYPE_CHECKING:
+ from openapi_pydantic import (
+ Components,
+ Operation,
+ Parameter,
+ PathItem,
+ Paths,
+ Reference,
+ RequestBody,
+ Schema,
+ )
+
+try:
+ from openapi_pydantic import OpenAPI
+except ImportError:
+ OpenAPI = object # type: ignore
+
+
+class OpenAPISpec(OpenAPI):
+ """OpenAPI Model that removes mis-formatted parts of the spec."""
+
+ openapi: str = "3.1.0" # overriding overly restrictive type from parent class
+
+ @property
+ def _paths_strict(self) -> Paths:
+ if not self.paths:
+ raise ValueError("No paths found in spec")
+ return self.paths
+
+ def _get_path_strict(self, path: str) -> PathItem:
+ path_item = self._paths_strict.get(path)
+ if not path_item:
+ raise ValueError(f"No path found for {path}")
+ return path_item
+
+ @property
+ def _components_strict(self) -> Components:
+ """Get components or err."""
+ if self.components is None:
+ raise ValueError("No components found in spec. ")
+ return self.components
+
+ @property
+ def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
+ """Get parameters or err."""
+ parameters = self._components_strict.parameters
+ if parameters is None:
+ raise ValueError("No parameters found in spec. ")
+ return parameters
+
+ @property
+ def _schemas_strict(self) -> Dict[str, Schema]:
+ """Get the dictionary of schemas or err."""
+ schemas = self._components_strict.schemas
+ if schemas is None:
+ raise ValueError("No schemas found in spec. ")
+ return schemas
+
+ @property
+ def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
+ """Get the request body or err."""
+ request_bodies = self._components_strict.requestBodies
+ if request_bodies is None:
+ raise ValueError("No request body found in spec. ")
+ return request_bodies
+
+ def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
+ """Get a parameter (or nested reference) or err."""
+ ref_name = ref.ref.split("/")[-1]
+ parameters = self._parameters_strict
+ if ref_name not in parameters:
+ raise ValueError(f"No parameter found for {ref_name}")
+ return parameters[ref_name]
+
+ def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
+ """Get the root reference or err."""
+ from openapi_pydantic import Reference
+
+ parameter = self._get_referenced_parameter(ref)
+ while isinstance(parameter, Reference):
+ parameter = self._get_referenced_parameter(parameter)
+ return parameter
+
+ def get_referenced_schema(self, ref: Reference) -> Schema:
+ """Get a schema (or nested reference) or err."""
+ ref_name = ref.ref.split("/")[-1]
+ schemas = self._schemas_strict
+ if ref_name not in schemas:
+ raise ValueError(f"No schema found for {ref_name}")
+ return schemas[ref_name]
+
+ def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
+ from openapi_pydantic import Reference
+
+ if isinstance(schema, Reference):
+ return self.get_referenced_schema(schema)
+ return schema
+
+ def _get_root_referenced_schema(self, ref: Reference) -> Schema:
+ """Get the root reference or err."""
+ from openapi_pydantic import Reference
+
+ schema = self.get_referenced_schema(ref)
+ while isinstance(schema, Reference):
+ schema = self.get_referenced_schema(schema)
+ return schema
+
+ def _get_referenced_request_body(
+ self, ref: Reference
+ ) -> Optional[Union[Reference, RequestBody]]:
+ """Get a request body (or nested reference) or err."""
+ ref_name = ref.ref.split("/")[-1]
+ request_bodies = self._request_bodies_strict
+ if ref_name not in request_bodies:
+ raise ValueError(f"No request body found for {ref_name}")
+ return request_bodies[ref_name]
+
+ def _get_root_referenced_request_body(
+ self, ref: Reference
+ ) -> Optional[RequestBody]:
+ """Get the root request Body or err."""
+ from openapi_pydantic import Reference
+
+ request_body = self._get_referenced_request_body(ref)
+ while isinstance(request_body, Reference):
+ request_body = self._get_referenced_request_body(request_body)
+ return request_body
+
+ @staticmethod
+ def _alert_unsupported_spec(obj: dict) -> None:
+ """Alert if the spec is not supported."""
+ warning_message = (
+ " This may result in degraded performance."
+ + " Convert your OpenAPI spec to 3.1.* spec"
+ + " for better support."
+ )
+ swagger_version = obj.get("swagger")
+ openapi_version = obj.get("openapi")
+ if isinstance(openapi_version, str):
+ if openapi_version != "3.1.0":
+ logger.warning(
+ f"Attempting to load an OpenAPI {openapi_version}"
+ f" spec. {warning_message}"
+ )
+ else:
+ pass
+ elif isinstance(swagger_version, str):
+ logger.warning(
+ f"Attempting to load a Swagger {swagger_version}"
+ f" spec. {warning_message}"
+ )
+ else:
+ raise ValueError(
+ "Attempting to load an unsupported spec:"
+ f"\n\n{obj}\n{warning_message}"
+ )
+
+ @classmethod
+ def parse_obj(cls, obj: dict) -> OpenAPISpec:
+ try:
+ cls._alert_unsupported_spec(obj)
+ return super().parse_obj(obj)
+ except ValidationError as e:
+ # We are handling possibly misconfigured specs and
+ # want to do a best-effort job to get a reasonable interface out of it.
+ new_obj = copy.deepcopy(obj)
+ for error in e.errors():
+ keys = error["loc"]
+ item = new_obj
+ for key in keys[:-1]:
+ item = item[key]
+ item.pop(keys[-1], None)
+ return cls.parse_obj(new_obj)
+
+ @classmethod
+ def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
+ """Get an OpenAPI spec from a dict."""
+ return cls.parse_obj(spec_dict)
+
+ @classmethod
+ def from_text(cls, text: str) -> OpenAPISpec:
+ """Get an OpenAPI spec from a text."""
+ try:
+ spec_dict = json.loads(text)
+ except json.JSONDecodeError:
+ spec_dict = yaml.safe_load(text)
+ return cls.from_spec_dict(spec_dict)
+
+ @classmethod
+ def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
+ """Get an OpenAPI spec from a file path."""
+ path_ = path if isinstance(path, Path) else Path(path)
+ if not path_.exists():
+ raise FileNotFoundError(f"{path} does not exist")
+ with path_.open("r") as f:
+ return cls.from_text(f.read())
+
+ @classmethod
+ def from_url(cls, url: str) -> OpenAPISpec:
+ """Get an OpenAPI spec from a URL."""
+ response = requests.get(url)
+ return cls.from_text(response.text)
+
+ @property
+ def base_url(self) -> str:
+ """Get the base url."""
+ return self.servers[0].url
+
+ def get_methods_for_path(self, path: str) -> List[str]:
+ """Return a list of valid methods for the specified path."""
+ from openapi_pydantic import Operation
+
+ path_item = self._get_path_strict(path)
+ results = []
+ for method in HTTPVerb:
+ operation = getattr(path_item, method.value, None)
+ if isinstance(operation, Operation):
+ results.append(method.value)
+ return results
+
+ def get_parameters_for_path(self, path: str) -> List[Parameter]:
+ from openapi_pydantic import Reference
+
+ path_item = self._get_path_strict(path)
+ parameters = []
+ if not path_item.parameters:
+ return []
+ for parameter in path_item.parameters:
+ if isinstance(parameter, Reference):
+ parameter = self._get_root_referenced_parameter(parameter)
+ parameters.append(parameter)
+ return parameters
+
+ def get_operation(self, path: str, method: str) -> Operation:
+ """Get the operation object for a given path and HTTP method."""
+ from openapi_pydantic import Operation
+
+ path_item = self._get_path_strict(path)
+ operation_obj = getattr(path_item, method, None)
+ if not isinstance(operation_obj, Operation):
+ raise ValueError(f"No {method} method found for {path}")
+ return operation_obj
+
+ def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
+ """Get the components for a given operation."""
+ from openapi_pydantic import Reference
+
+ parameters = []
+ if operation.parameters:
+ for parameter in operation.parameters:
+ if isinstance(parameter, Reference):
+ parameter = self._get_root_referenced_parameter(parameter)
+ parameters.append(parameter)
+ return parameters
+
+ def get_request_body_for_operation(
+ self, operation: Operation
+ ) -> Optional[RequestBody]:
+ """Get the request body for a given operation."""
+ from openapi_pydantic import Reference
+
+ request_body = operation.requestBody
+ if isinstance(request_body, Reference):
+ request_body = self._get_root_referenced_request_body(request_body)
+ return request_body
+
+ @staticmethod
+ def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
+ """Get a cleaned operation id from an operation id."""
+ operation_id = operation.operationId
+ if operation_id is None:
+ # Replace all punctuation of any kind with underscore
+            # Replace all non-alphanumeric characters with underscores
+ operation_id = f"{path}_{method}"
+ return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
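A sketch of how the spec helper above is typically used (requires `pip install openapi-pydantic`; the Petstore URL and path are just a convenient public example):

```python
from langchain_community.utilities.openapi import OpenAPISpec

spec = OpenAPISpec.from_url("https://petstore3.swagger.io/api/v3/openapi.json")

path = "/pet/{petId}"
print(spec.get_methods_for_path(path))  # e.g. ["get", "post", "delete"]

operation = spec.get_operation(path, "get")
print(spec.get_parameters_for_operation(operation))
print(spec.get_cleaned_operation_id(operation, path, "get"))
```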
diff --git a/libs/community/langchain_community/utilities/openweathermap.py b/libs/community/langchain_community/utilities/openweathermap.py
new file mode 100644
index 00000000000..f79149a1be4
--- /dev/null
+++ b/libs/community/langchain_community/utilities/openweathermap.py
@@ -0,0 +1,76 @@
+"""Util that calls OpenWeatherMap using PyOWM."""
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class OpenWeatherMapAPIWrapper(BaseModel):
+ """Wrapper for OpenWeatherMap API using PyOWM.
+
+ Docs for using:
+
+ 1. Go to OpenWeatherMap and sign up for an API key
+ 2. Save your API KEY into OPENWEATHERMAP_API_KEY env variable
+ 3. pip install pyowm
+ """
+
+ owm: Any
+ openweathermap_api_key: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ openweathermap_api_key = get_from_dict_or_env(
+ values, "openweathermap_api_key", "OPENWEATHERMAP_API_KEY"
+ )
+
+ try:
+ import pyowm
+
+ except ImportError:
+ raise ImportError(
+ "pyowm is not installed. Please install it with `pip install pyowm`"
+ )
+
+ owm = pyowm.OWM(openweathermap_api_key)
+ values["owm"] = owm
+
+ return values
+
+ def _format_weather_info(self, location: str, w: Any) -> str:
+ detailed_status = w.detailed_status
+ wind = w.wind()
+ humidity = w.humidity
+ temperature = w.temperature("celsius")
+ rain = w.rain
+ heat_index = w.heat_index
+ clouds = w.clouds
+
+ return (
+ f"In {location}, the current weather is as follows:\n"
+ f"Detailed status: {detailed_status}\n"
+ f"Wind speed: {wind['speed']} m/s, direction: {wind['deg']}Β°\n"
+ f"Humidity: {humidity}%\n"
+ f"Temperature: \n"
+ f" - Current: {temperature['temp']}Β°C\n"
+ f" - High: {temperature['temp_max']}Β°C\n"
+ f" - Low: {temperature['temp_min']}Β°C\n"
+ f" - Feels like: {temperature['feels_like']}Β°C\n"
+ f"Rain: {rain}\n"
+ f"Heat index: {heat_index}\n"
+ f"Cloud cover: {clouds}%"
+ )
+
+ def run(self, location: str) -> str:
+ """Get the current weather information for a specified location."""
+ mgr = self.owm.weather_manager()
+ observation = mgr.weather_at_place(location)
+ w = observation.weather
+
+ return self._format_weather_info(location, w)
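A short sketch of the weather wrapper above (requires `pip install pyowm` and `OPENWEATHERMAP_API_KEY`; the location string is illustrative):

```python
from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper

weather = OpenWeatherMapAPIWrapper()  # key read from OPENWEATHERMAP_API_KEY

# Returns the formatted block built by _format_weather_info().
print(weather.run("London,GB"))
```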
diff --git a/libs/community/langchain_community/utilities/outline.py b/libs/community/langchain_community/utilities/outline.py
new file mode 100644
index 00000000000..b9a4e3cad4d
--- /dev/null
+++ b/libs/community/langchain_community/utilities/outline.py
@@ -0,0 +1,95 @@
+"""Util that calls Outline."""
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+OUTLINE_MAX_QUERY_LENGTH = 300
+
+
+class OutlineAPIWrapper(BaseModel):
+ """Wrapper around OutlineAPI.
+
+ This wrapper will use the Outline API to query the documents of your instance.
+ By default it will return the document content of the top-k results.
+ It limits the document content by doc_content_chars_max.
+ """
+
+ top_k_results: int = 3
+ load_all_available_meta: bool = False
+ doc_content_chars_max: int = 4000
+ outline_instance_url: Optional[str] = None
+ outline_api_key: Optional[str] = None
+ outline_search_endpoint: str = "/api/documents.search"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that instance url and api key exists in environment."""
+ outline_instance_url = get_from_dict_or_env(
+ values, "outline_instance_url", "OUTLINE_INSTANCE_URL"
+ )
+ values["outline_instance_url"] = outline_instance_url
+
+ outline_api_key = get_from_dict_or_env(
+ values, "outline_api_key", "OUTLINE_API_KEY"
+ )
+ values["outline_api_key"] = outline_api_key
+
+ return values
+
+ def _result_to_document(self, outline_res: Any) -> Document:
+ main_meta = {
+ "title": outline_res["document"]["title"],
+ "source": self.outline_instance_url + outline_res["document"]["url"],
+ }
+ add_meta = (
+ {
+ "id": outline_res["document"]["id"],
+ "ranking": outline_res["ranking"],
+ "collection_id": outline_res["document"]["collectionId"],
+ "parent_document_id": outline_res["document"]["parentDocumentId"],
+ "revision": outline_res["document"]["revision"],
+ "created_by": outline_res["document"]["createdBy"]["name"],
+ }
+ if self.load_all_available_meta
+ else {}
+ )
+ doc = Document(
+ page_content=outline_res["document"]["text"][: self.doc_content_chars_max],
+ metadata={
+ **main_meta,
+ **add_meta,
+ },
+ )
+ return doc
+
+ def _outline_api_query(self, query: str) -> List:
+ raw_result = requests.post(
+ f"{self.outline_instance_url}{self.outline_search_endpoint}",
+ data={"query": query, "limit": self.top_k_results},
+ headers={"Authorization": f"Bearer {self.outline_api_key}"},
+ )
+
+ if not raw_result.ok:
+ raise ValueError("Outline API returned an error: ", raw_result.text)
+
+ return raw_result.json()["data"]
+
+ def run(self, query: str) -> List[Document]:
+ """
+ Run Outline search and get the document content plus the meta information.
+
+ Returns: a list of documents.
+
+ """
+ results = self._outline_api_query(query[:OUTLINE_MAX_QUERY_LENGTH])
+ docs = []
+ for result in results[: self.top_k_results]:
+ if doc := self._result_to_document(result):
+ docs.append(doc)
+ return docs
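A sketch of the Outline wrapper above (the query is illustrative; `OUTLINE_INSTANCE_URL` and `OUTLINE_API_KEY` are assumed to be set):

```python
from langchain_community.utilities.outline import OutlineAPIWrapper

outline = OutlineAPIWrapper(top_k_results=3, doc_content_chars_max=2000)

for doc in outline.run("onboarding checklist"):
    print(doc.metadata["title"], doc.metadata["source"])
    print(doc.page_content[:200])
```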
diff --git a/libs/community/langchain_community/utilities/portkey.py b/libs/community/langchain_community/utilities/portkey.py
new file mode 100644
index 00000000000..bf9044c4f15
--- /dev/null
+++ b/libs/community/langchain_community/utilities/portkey.py
@@ -0,0 +1,75 @@
+import json
+import os
+from typing import Dict, Optional
+
+
+class Portkey:
+ """Portkey configuration.
+
+ Attributes:
+ base: The base URL for the Portkey API.
+ Default: "https://api.portkey.ai/v1/proxy"
+ """
+
+ base = "https://api.portkey.ai/v1/proxy"
+
+ @staticmethod
+ def Config(
+ api_key: str,
+ trace_id: Optional[str] = None,
+ environment: Optional[str] = None,
+ user: Optional[str] = None,
+ organisation: Optional[str] = None,
+ prompt: Optional[str] = None,
+ retry_count: Optional[int] = None,
+ cache: Optional[str] = None,
+ cache_force_refresh: Optional[str] = None,
+ cache_age: Optional[int] = None,
+ ) -> Dict[str, str]:
+ assert retry_count is None or retry_count in range(
+ 1, 6
+ ), "retry_count must be an integer and in range [1, 2, 3, 4, 5]"
+ assert cache is None or cache in [
+ "simple",
+ "semantic",
+ ], "cache must be 'simple' or 'semantic'"
+ assert cache_force_refresh is None or (
+ isinstance(cache_force_refresh, str)
+ and cache_force_refresh in ["True", "False"]
+ ), "cache_force_refresh must be 'True' or 'False'"
+ assert cache_age is None or isinstance(
+ cache_age, int
+ ), "cache_age must be an integer"
+
+ os.environ["OPENAI_API_BASE"] = Portkey.base
+
+ headers = {
+ "x-portkey-api-key": api_key,
+ "x-portkey-mode": "proxy openai",
+ }
+
+ if trace_id:
+ headers["x-portkey-trace-id"] = trace_id
+ if retry_count:
+ headers["x-portkey-retry-count"] = str(retry_count)
+ if cache:
+ headers["x-portkey-cache"] = cache
+ if cache_force_refresh:
+ headers["x-portkey-cache-force-refresh"] = cache_force_refresh
+ if cache_age:
+ headers["Cache-Control"] = f"max-age:{str(cache_age)}"
+
+ metadata = {}
+ if environment:
+ metadata["_environment"] = environment
+ if user:
+ metadata["_user"] = user
+ if organisation:
+ metadata["_organisation"] = organisation
+ if prompt:
+ metadata["_prompt"] = prompt
+
+ if metadata:
+ headers.update({"x-portkey-metadata": json.dumps(metadata)})
+
+ return headers
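A sketch of building proxy headers with the helper above; all values are placeholders, and note that `Config()` also points `OPENAI_API_BASE` at the Portkey proxy:

```python
from langchain_community.utilities.portkey import Portkey

headers = Portkey.Config(
    api_key="<portkey-api-key>",  # placeholder
    trace_id="checkout-flow-42",  # placeholder trace id
    cache="semantic",
    retry_count=3,
)

# The returned dict can then be passed as default headers to an
# OpenAI-compatible client, which now talks to the Portkey proxy.
print(headers)
```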
diff --git a/libs/community/langchain_community/utilities/powerbi.py b/libs/community/langchain_community/utilities/powerbi.py
new file mode 100644
index 00000000000..88219936260
--- /dev/null
+++ b/libs/community/langchain_community/utilities/powerbi.py
@@ -0,0 +1,276 @@
+"""Wrapper around a Power BI endpoint."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
+
+import aiohttp
+import requests
+from aiohttp import ServerTimeoutError
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
+from requests.exceptions import Timeout
+
+logger = logging.getLogger(__name__)
+
+BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
+
+if TYPE_CHECKING:
+ from azure.core.credentials import TokenCredential
+
+
+class PowerBIDataset(BaseModel):
+ """Create PowerBI engine from dataset ID and credential or token.
+
+ Use either the credential or a supplied token to authenticate.
+ If both are supplied the credential is used to generate a token.
+ The impersonated_user_name is the UPN of a user to be impersonated.
+ If the model is not RLS enabled, this will be ignored.
+ """
+
+ dataset_id: str
+ table_names: List[str]
+ group_id: Optional[str] = None
+ credential: Optional[TokenCredential] = None
+ token: Optional[str] = None
+ impersonated_user_name: Optional[str] = None
+ sample_rows_in_table_info: int = Field(default=1, gt=0, le=10)
+ schemas: Dict[str, str] = Field(default_factory=dict)
+ aiosession: Optional[aiohttp.ClientSession] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @validator("table_names", allow_reuse=True)
+ def fix_table_names(cls, table_names: List[str]) -> List[str]:
+ """Fix the table names."""
+ return [fix_table_name(table) for table in table_names]
+
+ @root_validator(pre=True, allow_reuse=True)
+ def token_or_credential_present(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate that at least one of token and credentials is present."""
+ if "token" in values or "credential" in values:
+ return values
+ raise ValueError("Please provide either a credential or a token.")
+
+ @property
+ def request_url(self) -> str:
+ """Get the request url."""
+ if self.group_id:
+ return f"{BASE_URL}/groups/{self.group_id}/datasets/{self.dataset_id}/executeQueries" # noqa: E501 # pylint: disable=C0301
+ return f"{BASE_URL}/datasets/{self.dataset_id}/executeQueries" # noqa: E501 # pylint: disable=C0301
+
+ @property
+ def headers(self) -> Dict[str, str]:
+ """Get the token."""
+ if self.token:
+ return {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer " + self.token,
+ }
+ from azure.core.exceptions import (
+ ClientAuthenticationError, # pylint: disable=import-outside-toplevel
+ )
+
+ if self.credential:
+ try:
+ token = self.credential.get_token(
+ "https://analysis.windows.net/powerbi/api/.default"
+ ).token
+ return {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer " + token,
+ }
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ raise ClientAuthenticationError(
+ "Could not get a token from the supplied credentials."
+ ) from exc
+ raise ClientAuthenticationError("No credential or token supplied.")
+
+ def get_table_names(self) -> Iterable[str]:
+ """Get names of tables available."""
+ return self.table_names
+
+ def get_schemas(self) -> str:
+        """Get the available schemas."""
+ if self.schemas:
+ return ", ".join([f"{key}: {value}" for key, value in self.schemas.items()])
+        return "No known schemas yet. Use the schema_powerbi tool first."
+
+ @property
+ def table_info(self) -> str:
+ """Information about all tables in the database."""
+ return self.get_table_info()
+
+ def _get_tables_to_query(
+ self, table_names: Optional[Union[List[str], str]] = None
+ ) -> Optional[List[str]]:
+ """Get the tables names that need to be queried, after checking they exist."""
+ if table_names is not None:
+ if (
+ isinstance(table_names, list)
+ and len(table_names) > 0
+ and table_names[0] != ""
+ ):
+ fixed_tables = [fix_table_name(table) for table in table_names]
+ non_existing_tables = [
+ table for table in fixed_tables if table not in self.table_names
+ ]
+ if non_existing_tables:
+ logger.warning(
+ "Table(s) %s not found in dataset.",
+ ", ".join(non_existing_tables),
+ )
+ tables = [
+ table for table in fixed_tables if table not in non_existing_tables
+ ]
+ return tables if tables else None
+ if isinstance(table_names, str) and table_names != "":
+ if table_names not in self.table_names:
+ logger.warning("Table %s not found in dataset.", table_names)
+ return None
+ return [fix_table_name(table_names)]
+ return self.table_names
+
+ def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
+ """Get the tables that still need to be queried."""
+ return [table for table in tables_todo if table not in self.schemas]
+
+ def _get_schema_for_tables(self, table_names: List[str]) -> str:
+ """Create a string of the table schemas for the supplied tables."""
+ schemas = [
+ schema for table, schema in self.schemas.items() if table in table_names
+ ]
+ return ", ".join(schemas)
+
+ def get_table_info(
+ self, table_names: Optional[Union[List[str], str]] = None
+ ) -> str:
+ """Get information about specified tables."""
+ tables_requested = self._get_tables_to_query(table_names)
+ if tables_requested is None:
+ return "No (valid) tables requested."
+ tables_todo = self._get_tables_todo(tables_requested)
+ for table in tables_todo:
+ self._get_schema(table)
+ return self._get_schema_for_tables(tables_requested)
+
+ async def aget_table_info(
+ self, table_names: Optional[Union[List[str], str]] = None
+ ) -> str:
+ """Get information about specified tables."""
+ tables_requested = self._get_tables_to_query(table_names)
+ if tables_requested is None:
+ return "No (valid) tables requested."
+ tables_todo = self._get_tables_todo(tables_requested)
+ await asyncio.gather(*[self._aget_schema(table) for table in tables_todo])
+ return self._get_schema_for_tables(tables_requested)
+
+ def _get_schema(self, table: str) -> None:
+ """Get the schema for a table."""
+ try:
+ result = self.run(
+ f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
+ )
+ self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
+ except Timeout:
+ logger.warning("Timeout while getting table info for %s", table)
+ self.schemas[table] = "unknown"
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ logger.warning("Error while getting table info for %s: %s", table, exc)
+ self.schemas[table] = "unknown"
+
+ async def _aget_schema(self, table: str) -> None:
+ """Get the schema for a table."""
+ try:
+ result = await self.arun(
+ f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
+ )
+ self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
+ except ServerTimeoutError:
+ logger.warning("Timeout while getting table info for %s", table)
+ self.schemas[table] = "unknown"
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ logger.warning("Error while getting table info for %s: %s", table, exc)
+ self.schemas[table] = "unknown"
+
+ def _create_json_content(self, command: str) -> dict[str, Any]:
+ """Create the json content for the request."""
+ return {
+ "queries": [{"query": rf"{command}"}],
+ "impersonatedUserName": self.impersonated_user_name,
+ "serializerSettings": {"includeNulls": True},
+ }
+
+ def run(self, command: str) -> Any:
+ """Execute a DAX command and return a json representing the results."""
+ logger.debug("Running command: %s", command)
+ response = requests.post(
+ self.request_url,
+ json=self._create_json_content(command),
+ headers=self.headers,
+ timeout=10,
+ )
+ if response.status_code == 403:
+ return (
+ "TokenError: Could not login to PowerBI, please check your credentials."
+ )
+ return response.json()
+
+ async def arun(self, command: str) -> Any:
+ """Execute a DAX command and return the result asynchronously."""
+ logger.debug("Running command: %s", command)
+ if self.aiosession:
+ async with self.aiosession.post(
+ self.request_url,
+ headers=self.headers,
+ json=self._create_json_content(command),
+ timeout=10,
+ ) as response:
+ if response.status == 403:
+ return "TokenError: Could not login to PowerBI, please check your credentials." # noqa: E501
+ response_json = await response.json(content_type=response.content_type)
+ return response_json
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ self.request_url,
+ headers=self.headers,
+ json=self._create_json_content(command),
+ timeout=10,
+ ) as response:
+ if response.status == 403:
+ return "TokenError: Could not login to PowerBI, please check your credentials." # noqa: E501
+ response_json = await response.json(content_type=response.content_type)
+ return response_json
+
+
+def json_to_md(
+ json_contents: List[Dict[str, Union[str, int, float]]],
+ table_name: Optional[str] = None,
+) -> str:
+ """Converts a JSON object to a markdown table."""
+ if len(json_contents) == 0:
+ return ""
+ output_md = ""
+ headers = json_contents[0].keys()
+ for header in headers:
+        header = header.replace("[", ".").replace("]", "")
+        if table_name:
+            header = header.replace(f"{table_name}.", "")
+ output_md += f"| {header} "
+ output_md += "|\n"
+ for row in json_contents:
+ for value in row.values():
+ output_md += f"| {value} "
+ output_md += "|\n"
+ return output_md
+
+
+def fix_table_name(table: str) -> str:
+ """Add single quotes around table names that contain spaces."""
+ if " " in table and not table.startswith("'") and not table.endswith("'"):
+ return f"'{table}'"
+ return table
diff --git a/libs/community/langchain_community/utilities/pubmed.py b/libs/community/langchain_community/utilities/pubmed.py
new file mode 100644
index 00000000000..981799d9de3
--- /dev/null
+++ b/libs/community/langchain_community/utilities/pubmed.py
@@ -0,0 +1,200 @@
+import json
+import logging
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
+from typing import Any, Dict, Iterator, List
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class PubMedAPIWrapper(BaseModel):
+ """
+ Wrapper around PubMed API.
+
+ This wrapper will use the PubMed API to conduct searches and fetch
+ document summaries. By default, it will return the document summaries
+ of the top-k results of an input search.
+
+ Parameters:
+        top_k_results: number of top-scored documents used for the PubMed tool.
+ MAX_QUERY_LENGTH: maximum length of the query.
+ Default is 300 characters.
+ doc_content_chars_max: maximum length of the document content.
+ Content will be truncated if it exceeds this length.
+ Default is 2000 characters.
+ max_retry: maximum number of retries for a request. Default is 5.
+ sleep_time: time to wait between retries.
+ Default is 0.2 seconds.
+ email: email address to be used for the PubMed API.
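+
+    Example (an illustrative sketch; assumes the ``xmltodict`` package is installed
+    and the query string is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.utilities.pubmed import PubMedAPIWrapper
+
+            pubmed = PubMedAPIWrapper(top_k_results=1)
+            pubmed.run("chatgpt")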
+ """
+
+ parse: Any #: :meta private:
+
+ base_url_esearch: str = (
+ "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
+ )
+ base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
+ max_retry: int = 5
+ sleep_time: float = 0.2
+
+ # Default values for the parameters
+ top_k_results: int = 3
+ MAX_QUERY_LENGTH: int = 300
+ doc_content_chars_max: int = 2000
+ email: str = "your_email@example.com"
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ import xmltodict
+
+ values["parse"] = xmltodict.parse
+ except ImportError:
+ raise ImportError(
+ "Could not import xmltodict python package. "
+ "Please install it with `pip install xmltodict`."
+ )
+ return values
+
+ def run(self, query: str) -> str:
+ """
+ Run PubMed search and get the article meta information.
+ See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
+ It uses only the most informative fields of article meta information.
+ """
+
+ try:
+ # Retrieve the top-k results for the query
+ docs = [
+ f"Published: {result['Published']}\n"
+ f"Title: {result['Title']}\n"
+ f"Copyright Information: {result['Copyright Information']}\n"
+ f"Summary::\n{result['Summary']}"
+ for result in self.load(query[: self.MAX_QUERY_LENGTH])
+ ]
+
+ # Join the results and limit the character count
+ return (
+ "\n\n".join(docs)[: self.doc_content_chars_max]
+ if docs
+ else "No good PubMed Result was found"
+ )
+ except Exception as ex:
+ return f"PubMed exception: {ex}"
+
+ def lazy_load(self, query: str) -> Iterator[dict]:
+ """
+ Search PubMed for documents matching the query.
+ Return an iterator of dictionaries containing the document metadata.
+ """
+
+ url = (
+ self.base_url_esearch
+ + "db=pubmed&term="
+            + urllib.parse.quote(query)
+ + f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
+ )
+ result = urllib.request.urlopen(url)
+ text = result.read().decode("utf-8")
+ json_text = json.loads(text)
+
+ webenv = json_text["esearchresult"]["webenv"]
+ for uid in json_text["esearchresult"]["idlist"]:
+ yield self.retrieve_article(uid, webenv)
+
+ def load(self, query: str) -> List[dict]:
+ """
+ Search PubMed for documents matching the query.
+ Return a list of dictionaries containing the document metadata.
+ """
+ return list(self.lazy_load(query))
+
+ def _dict2document(self, doc: dict) -> Document:
+ summary = doc.pop("Summary")
+ return Document(page_content=summary, metadata=doc)
+
+ def lazy_load_docs(self, query: str) -> Iterator[Document]:
+ for d in self.lazy_load(query=query):
+ yield self._dict2document(d)
+
+ def load_docs(self, query: str) -> List[Document]:
+ return list(self.lazy_load_docs(query=query))
+
+ def retrieve_article(self, uid: str, webenv: str) -> dict:
+ url = (
+ self.base_url_efetch
+ + "db=pubmed&retmode=xml&id="
+ + uid
+ + "&webenv="
+ + webenv
+ )
+
+ retry = 0
+ while True:
+ try:
+ result = urllib.request.urlopen(url)
+ break
+ except urllib.error.HTTPError as e:
+ if e.code == 429 and retry < self.max_retry:
+ # Too Many Requests errors
+ # wait for an exponentially increasing amount of time
+                    logger.warning(
+                        "Too Many Requests, waiting for %.2f seconds...",
+                        self.sleep_time,
+                    )
+ time.sleep(self.sleep_time)
+ self.sleep_time *= 2
+ retry += 1
+ else:
+ raise e
+
+ xml_text = result.read().decode("utf-8")
+ text_dict = self.parse(xml_text)
+ return self._parse_article(uid, text_dict)
+
+ def _parse_article(self, uid: str, text_dict: dict) -> dict:
+ try:
+ ar = text_dict["PubmedArticleSet"]["PubmedArticle"]["MedlineCitation"][
+ "Article"
+ ]
+ except KeyError:
+ ar = text_dict["PubmedArticleSet"]["PubmedBookArticle"]["BookDocument"]
+ abstract_text = ar.get("Abstract", {}).get("AbstractText", [])
+ summaries = [
+ f"{txt['@Label']}: {txt['#text']}"
+ for txt in abstract_text
+ if "#text" in txt and "@Label" in txt
+ ]
+ summary = (
+ "\n".join(summaries)
+ if summaries
+ else (
+ abstract_text
+ if isinstance(abstract_text, str)
+ else (
+ "\n".join(str(value) for value in abstract_text.values())
+ if isinstance(abstract_text, dict)
+ else "No abstract available"
+ )
+ )
+ )
+ a_d = ar.get("ArticleDate", {})
+ pub_date = "-".join(
+ [a_d.get("Year", ""), a_d.get("Month", ""), a_d.get("Day", "")]
+ )
+
+ return {
+ "uid": uid,
+ "Title": ar.get("ArticleTitle", ""),
+ "Published": pub_date,
+ "Copyright Information": ar.get("Abstract", {}).get(
+ "CopyrightInformation", ""
+ ),
+ "Summary": summary,
+ }
diff --git a/libs/community/langchain_community/utilities/python.py b/libs/community/langchain_community/utilities/python.py
new file mode 100644
index 00000000000..70c3119e5f6
--- /dev/null
+++ b/libs/community/langchain_community/utilities/python.py
@@ -0,0 +1,71 @@
+import functools
+import logging
+import multiprocessing
+import sys
+from io import StringIO
+from typing import Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+
+@functools.lru_cache(maxsize=None)
+def warn_once() -> None:
+ """Warn once about the dangers of PythonREPL."""
+ logger.warning("Python REPL can execute arbitrary code. Use with caution.")
+
+
+class PythonREPL(BaseModel):
+ """Simulates a standalone Python REPL."""
+
+ globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
+ locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")
+
+ @classmethod
+ def worker(
+ cls,
+ command: str,
+ globals: Optional[Dict],
+ locals: Optional[Dict],
+ queue: multiprocessing.Queue,
+ ) -> None:
+ old_stdout = sys.stdout
+ sys.stdout = mystdout = StringIO()
+ try:
+ exec(command, globals, locals)
+ sys.stdout = old_stdout
+ queue.put(mystdout.getvalue())
+ except Exception as e:
+ sys.stdout = old_stdout
+ queue.put(repr(e))
+
+ def run(self, command: str, timeout: Optional[int] = None) -> str:
+ """Run command with own globals/locals and returns anything printed.
+ Timeout after the specified number of seconds."""
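+
+        Example (an illustrative sketch; the snippets below are placeholders):
+            .. code-block:: python
+
+                repl = PythonREPL()
+                repl.run("print(1 + 1)")  # returns "2" followed by a newline
+                repl.run("while True: pass", timeout=2)  # "Execution timed out"
+        """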
+
+ # Warn against dangers of PythonREPL
+ warn_once()
+
+ queue: multiprocessing.Queue = multiprocessing.Queue()
+
+ # Only use multiprocessing if we are enforcing a timeout
+ if timeout is not None:
+ # create a Process
+ p = multiprocessing.Process(
+ target=self.worker, args=(command, self.globals, self.locals, queue)
+ )
+
+ # start it
+ p.start()
+
+ # wait for the process to finish or kill it after timeout seconds
+ p.join(timeout)
+
+ if p.is_alive():
+ p.terminate()
+ return "Execution timed out"
+ else:
+ self.worker(command, self.globals, self.locals, queue)
+ # get the result from the worker function
+ return queue.get()
diff --git a/libs/community/langchain_community/utilities/reddit_search.py b/libs/community/langchain_community/utilities/reddit_search.py
new file mode 100644
index 00000000000..e17971b1b2c
--- /dev/null
+++ b/libs/community/langchain_community/utilities/reddit_search.py
@@ -0,0 +1,121 @@
+"""Wrapper for the Reddit API"""
+
+from typing import Any, Dict, List, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class RedditSearchAPIWrapper(BaseModel):
+ """Wrapper for Reddit API
+
+ To use, set the environment variables ``REDDIT_CLIENT_ID``,
+ ``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` to set the client ID,
+ client secret, and user agent, respectively, as given by Reddit's API.
+ Alternatively, all three can be supplied as named parameters in the
+ constructor: ``reddit_client_id``, ``reddit_client_secret``, and
+ ``reddit_user_agent``, respectively.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import RedditSearchAPIWrapper
+ reddit_search = RedditSearchAPIWrapper()
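+            # An illustrative query (a sketch; the argument values are placeholders
+            # using standard praw sort and time_filter options)
+            reddit_search.run(
+                query="langchain",
+                sort="relevance",
+                time_filter="week",
+                subreddit="python",
+                limit=5,
+            )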
+ """
+
+ reddit_client: Any
+
+ # Values required to access Reddit API via praw
+ reddit_client_id: Optional[str]
+ reddit_client_secret: Optional[str]
+ reddit_user_agent: Optional[str]
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the API ID, secret and user agent exists in environment
+ and check that praw module is present.
+ """
+ reddit_client_id = get_from_dict_or_env(
+ values, "reddit_client_id", "REDDIT_CLIENT_ID"
+ )
+ values["reddit_client_id"] = reddit_client_id
+
+ reddit_client_secret = get_from_dict_or_env(
+ values, "reddit_client_secret", "REDDIT_CLIENT_SECRET"
+ )
+ values["reddit_client_secret"] = reddit_client_secret
+
+ reddit_user_agent = get_from_dict_or_env(
+ values, "reddit_user_agent", "REDDIT_USER_AGENT"
+ )
+ values["reddit_user_agent"] = reddit_user_agent
+
+ try:
+ import praw
+ except ImportError:
+ raise ImportError(
+ "praw package not found, please install it with pip install praw"
+ )
+
+ reddit_client = praw.Reddit(
+ client_id=reddit_client_id,
+ client_secret=reddit_client_secret,
+ user_agent=reddit_user_agent,
+ )
+ values["reddit_client"] = reddit_client
+
+ return values
+
+ def run(
+ self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
+ ) -> str:
+ """Search Reddit and return posts as a single string."""
+ results: List[Dict] = self.results(
+ query=query,
+ sort=sort,
+ time_filter=time_filter,
+ subreddit=subreddit,
+ limit=limit,
+ )
+ if len(results) > 0:
+ output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
+ for r in results:
+ category = "N/A" if r["post_category"] is None else r["post_category"]
+ p = f"Post Title: '{r['post_title']}'\n\
+ User: {r['post_author']}\n\
+ Subreddit: {r['post_subreddit']}:\n\
+ Text body: {r['post_text']}\n\
+ Post URL: {r['post_url']}\n\
+ Post Category: {category}.\n\
+ Score: {r['post_score']}\n"
+ output.append(p)
+ return "\n".join(output)
+ else:
+ return f"Searching r/{subreddit} did not find any posts:"
+
+ def results(
+ self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
+ ) -> List[Dict]:
+ """Use praw to search Reddit and return a list of dictionaries,
+ one for each post.
+ """
+        subreddit_object = self.reddit_client.subreddit(subreddit)
+        search_results = subreddit_object.search(
+            query=query, sort=sort, time_filter=time_filter, limit=limit
+        )
+        search_results = list(search_results)
+ results_object = []
+ for submission in search_results:
+ results_object.append(
+ {
+ "post_subreddit": submission.subreddit_name_prefixed,
+ "post_category": submission.category,
+ "post_title": submission.title,
+ "post_text": submission.selftext,
+ "post_score": submission.score,
+ "post_id": submission.id,
+ "post_url": submission.url,
+ "post_author": submission.author,
+ }
+ )
+ return results_object
diff --git a/libs/community/langchain_community/utilities/redis.py b/libs/community/langchain_community/utilities/redis.py
new file mode 100644
index 00000000000..5b613667b64
--- /dev/null
+++ b/libs/community/langchain_community/utilities/redis.py
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import logging
+import re
+from typing import TYPE_CHECKING, Any, List, Optional, Pattern
+from urllib.parse import urlparse
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from redis.client import Redis as RedisType
+
+
+def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
+ return np.array(array).astype(dtype).tobytes()
+
+
+def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
+ return np.frombuffer(buffer, dtype=dtype).tolist()
+
+
+class TokenEscaper:
+ """
+ Escape punctuation within an input string.
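+
+    Example (an illustrative sketch):
+        .. code-block:: python
+
+            escaper = TokenEscaper()
+            # punctuation and spaces are prefixed with a backslash for RediSearch
+            escaper.escape("user@example.com")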
+ """
+
+ # Characters that RediSearch requires us to escape during queries.
+ # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
+ DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
+
+ def __init__(self, escape_chars_re: Optional[Pattern] = None):
+ if escape_chars_re:
+ self.escaped_chars_re = escape_chars_re
+ else:
+ self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
+
+ def escape(self, value: str) -> str:
+ if not isinstance(value, str):
+ raise TypeError(
+ "Value must be a string object for token escaping."
+ f"Got type {type(value)}"
+ )
+
+ def escape_symbol(match: re.Match) -> str:
+ value = match.group(0)
+ return f"\\{value}"
+
+ return self.escaped_chars_re.sub(escape_symbol, value)
+
+
+def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
+ """Check if the correct Redis modules are installed."""
+ installed_modules = client.module_list()
+ installed_modules = {
+ module[b"name"].decode("utf-8"): module for module in installed_modules
+ }
+ for module in required_modules:
+ if module["name"] in installed_modules and int(
+ installed_modules[module["name"]][b"ver"]
+ ) >= int(module["ver"]):
+ return
+ # otherwise raise error
+ error_message = (
+ "Redis cannot be used as a vector database without RediSearch >=2.4"
+ "Please head to https://redis.io/docs/stack/search/quick_start/"
+ "to know more about installing the RediSearch module within Redis Stack."
+ )
+ logger.error(error_message)
+ raise ValueError(error_message)
+
+
+def get_client(redis_url: str, **kwargs: Any) -> RedisType:
+ """Get a redis client from the connection url given. This helper accepts
+ urls for Redis server (TCP with/without TLS or UnixSocket) as well as
+ Redis Sentinel connections.
+
+ Redis Cluster is not supported.
+
+    Before creating a connection, the existence of the database driver is checked
+    and a ValueError is raised otherwise.
+
+ To use, you should have the ``redis`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities.redis import get_client
+            redis_client = get_client(
+                redis_url="redis://username:password@localhost:6379"
+            )
+
+    To use a redis replication setup with multiple redis servers and redis sentinels,
+    set "redis_url" to the "redis+sentinel://" scheme. With this url format a path is
+    needed that holds the name of the redis service within the sentinels to get the
+    correct redis server connection. The default service name is "mymaster". The
+    optional second part of the path is the redis db number to connect to.
+
+    An optional username or password is used for both connections, to the redis server
+    and the sentinel; different passwords for server and sentinel are not supported.
+    As an additional constraint, only one sentinel instance can be given:
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities.redis import get_client
+            redis_client = get_client(
+                redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0"
+            )
+ """
+
+ # Initialize with necessary components.
+ try:
+ import redis
+ except ImportError:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis>=4.1.0`."
+ )
+
+ # check if normal redis:// or redis+sentinel:// url
+ if redis_url.startswith("redis+sentinel"):
+ redis_client = _redis_sentinel_client(redis_url, **kwargs)
+ elif redis_url.startswith("rediss+sentinel"): # sentinel with TLS support enables
+ kwargs["ssl"] = True
+ if "ssl_cert_reqs" not in kwargs:
+ kwargs["ssl_cert_reqs"] = "none"
+ redis_client = _redis_sentinel_client(redis_url, **kwargs)
+ else:
+ # connect to redis server from url, reconnect with cluster client if needed
+ redis_client = redis.from_url(redis_url, **kwargs)
+ if _check_for_cluster(redis_client):
+ redis_client.close()
+ redis_client = _redis_cluster_client(redis_url, **kwargs)
+ return redis_client
+
+
+def _redis_sentinel_client(redis_url: str, **kwargs: Any) -> RedisType:
+ """helper method to parse an (un-official) redis+sentinel url
+ and create a Sentinel connection to fetch the final redis client
+ connection to a replica-master for read-write operations.
+
+ If username and/or password for authentication is given the
+ same credentials are used for the Redis Sentinel as well as Redis Server.
+ With this implementation using a redis url only it is not possible
+ to use different data for authentication on booth systems.
+ """
+ import redis
+
+ parsed_url = urlparse(redis_url)
+ # sentinel needs list with (host, port) tuple, use default port if none available
+ sentinel_list = [(parsed_url.hostname or "localhost", parsed_url.port or 26379)]
+ if parsed_url.path:
+ # "/mymaster/0" first part is service name, optional second part is db number
+ path_parts = parsed_url.path.split("/")
+ service_name = path_parts[1] or "mymaster"
+ if len(path_parts) > 2:
+ kwargs["db"] = path_parts[2]
+ else:
+ service_name = "mymaster"
+
+ sentinel_args = {}
+ if parsed_url.password:
+ sentinel_args["password"] = parsed_url.password
+ kwargs["password"] = parsed_url.password
+ if parsed_url.username:
+ sentinel_args["username"] = parsed_url.username
+ kwargs["username"] = parsed_url.username
+
+ # check for all SSL related properties and copy them into sentinel_kwargs too,
+ # add client_name also
+ for arg in kwargs:
+ if arg.startswith("ssl") or arg == "client_name":
+ sentinel_args[arg] = kwargs[arg]
+
+ # sentinel user/pass is part of sentinel_kwargs, user/pass for redis server
+ # connection as direct parameter in kwargs
+ sentinel_client = redis.sentinel.Sentinel(
+ sentinel_list, sentinel_kwargs=sentinel_args, **kwargs
+ )
+
+ # redis server might have password but not sentinel - fetch this error and try
+ # again without pass, everything else cannot be handled here -> user needed
+ try:
+ sentinel_client.execute_command("ping")
+ except redis.exceptions.AuthenticationError as ae:
+ if "no password is set" in ae.args[0]:
+ logger.warning(
+ "Redis sentinel connection configured with password but Sentinel \
+answered NO PASSWORD NEEDED - Please check Sentinel configuration"
+ )
+ sentinel_client = redis.sentinel.Sentinel(sentinel_list, **kwargs)
+ else:
+ raise ae
+
+ return sentinel_client.master_for(service_name)
+
+
+def _check_for_cluster(redis_client: RedisType) -> bool:
+ import redis
+
+ try:
+ cluster_info = redis_client.info("cluster")
+ return cluster_info["cluster_enabled"] == 1
+ except redis.exceptions.RedisError:
+ return False
+
+
+def _redis_cluster_client(redis_url: str, **kwargs: Any) -> RedisType:
+ from redis.cluster import RedisCluster
+
+ return RedisCluster.from_url(redis_url, **kwargs)
diff --git a/libs/community/langchain_community/utilities/requests.py b/libs/community/langchain_community/utilities/requests.py
new file mode 100644
index 00000000000..651616ff531
--- /dev/null
+++ b/libs/community/langchain_community/utilities/requests.py
@@ -0,0 +1,180 @@
+"""Lightweight wrapper around requests library, with async support."""
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, Dict, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra
+
+
+class Requests(BaseModel):
+ """Wrapper around requests to handle auth and async.
+
+ The main purpose of this wrapper is to handle authentication (by saving
+ headers) and enable easy async methods on the same base object.
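+
+    Example (an illustrative sketch; the URL and token are placeholders):
+        .. code-block:: python
+
+            requests_wrapper = Requests(headers={"Authorization": "Bearer <token>"})
+            response = requests_wrapper.get("https://example.com")
+            response.status_code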
+ """
+
+ headers: Optional[Dict[str, str]] = None
+ aiosession: Optional[aiohttp.ClientSession] = None
+ auth: Optional[Any] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ def get(self, url: str, **kwargs: Any) -> requests.Response:
+ """GET the URL and return the text."""
+ return requests.get(url, headers=self.headers, auth=self.auth, **kwargs)
+
+ def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
+ """POST to the URL and return the text."""
+ return requests.post(
+ url, json=data, headers=self.headers, auth=self.auth, **kwargs
+ )
+
+ def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
+ """PATCH the URL and return the text."""
+ return requests.patch(
+ url, json=data, headers=self.headers, auth=self.auth, **kwargs
+ )
+
+ def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
+ """PUT the URL and return the text."""
+ return requests.put(
+ url, json=data, headers=self.headers, auth=self.auth, **kwargs
+ )
+
+ def delete(self, url: str, **kwargs: Any) -> requests.Response:
+ """DELETE the URL and return the text."""
+ return requests.delete(url, headers=self.headers, auth=self.auth, **kwargs)
+
+ @asynccontextmanager
+ async def _arequest(
+ self, method: str, url: str, **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """Make an async request."""
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.request(
+ method, url, headers=self.headers, auth=self.auth, **kwargs
+ ) as response:
+ yield response
+ else:
+ async with self.aiosession.request(
+ method, url, headers=self.headers, auth=self.auth, **kwargs
+ ) as response:
+ yield response
+
+ @asynccontextmanager
+ async def aget(
+ self, url: str, **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """GET the URL and return the text asynchronously."""
+ async with self._arequest("GET", url, **kwargs) as response:
+ yield response
+
+ @asynccontextmanager
+ async def apost(
+ self, url: str, data: Dict[str, Any], **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """POST to the URL and return the text asynchronously."""
+ async with self._arequest("POST", url, json=data, **kwargs) as response:
+ yield response
+
+ @asynccontextmanager
+ async def apatch(
+ self, url: str, data: Dict[str, Any], **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """PATCH the URL and return the text asynchronously."""
+ async with self._arequest("PATCH", url, json=data, **kwargs) as response:
+ yield response
+
+ @asynccontextmanager
+ async def aput(
+ self, url: str, data: Dict[str, Any], **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """PUT the URL and return the text asynchronously."""
+ async with self._arequest("PUT", url, json=data, **kwargs) as response:
+ yield response
+
+ @asynccontextmanager
+ async def adelete(
+ self, url: str, **kwargs: Any
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+ """DELETE the URL and return the text asynchronously."""
+ async with self._arequest("DELETE", url, **kwargs) as response:
+ yield response
+
+
+class TextRequestsWrapper(BaseModel):
+ """Lightweight wrapper around requests library.
+
+ The main purpose of this wrapper is to always return a text output.
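+
+    Example (an illustrative sketch; the URL is a placeholder):
+        .. code-block:: python
+
+            requests_wrapper = TextRequestsWrapper()
+            # returns the response body as text
+            requests_wrapper.get("https://example.com")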
+ """
+
+ headers: Optional[Dict[str, str]] = None
+ aiosession: Optional[aiohttp.ClientSession] = None
+ auth: Optional[Any] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @property
+ def requests(self) -> Requests:
+ return Requests(
+ headers=self.headers, aiosession=self.aiosession, auth=self.auth
+ )
+
+ def get(self, url: str, **kwargs: Any) -> str:
+ """GET the URL and return the text."""
+ return self.requests.get(url, **kwargs).text
+
+ def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """POST to the URL and return the text."""
+ return self.requests.post(url, data, **kwargs).text
+
+ def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """PATCH the URL and return the text."""
+ return self.requests.patch(url, data, **kwargs).text
+
+ def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """PUT the URL and return the text."""
+ return self.requests.put(url, data, **kwargs).text
+
+ def delete(self, url: str, **kwargs: Any) -> str:
+ """DELETE the URL and return the text."""
+ return self.requests.delete(url, **kwargs).text
+
+ async def aget(self, url: str, **kwargs: Any) -> str:
+ """GET the URL and return the text asynchronously."""
+ async with self.requests.aget(url, **kwargs) as response:
+ return await response.text()
+
+ async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """POST to the URL and return the text asynchronously."""
+ async with self.requests.apost(url, data, **kwargs) as response:
+ return await response.text()
+
+ async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """PATCH the URL and return the text asynchronously."""
+ async with self.requests.apatch(url, data, **kwargs) as response:
+ return await response.text()
+
+ async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
+ """PUT the URL and return the text asynchronously."""
+ async with self.requests.aput(url, data, **kwargs) as response:
+ return await response.text()
+
+ async def adelete(self, url: str, **kwargs: Any) -> str:
+ """DELETE the URL and return the text asynchronously."""
+ async with self.requests.adelete(url, **kwargs) as response:
+ return await response.text()
+
+
+# For backwards compatibility
+RequestsWrapper = TextRequestsWrapper
diff --git a/libs/community/langchain_community/utilities/scenexplain.py b/libs/community/langchain_community/utilities/scenexplain.py
new file mode 100644
index 00000000000..f8ffcb41e13
--- /dev/null
+++ b/libs/community/langchain_community/utilities/scenexplain.py
@@ -0,0 +1,65 @@
+"""Util that calls SceneXplain.
+
+In order to set this up, you need an API key for the SceneXplain API.
+You can obtain a key by following the steps below.
+- Sign up for a free account at https://scenex.jina.ai/.
+- Navigate to the API Access page (https://scenex.jina.ai/api) and create a new API key.
+"""
+from typing import Dict
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, BaseSettings, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class SceneXplainAPIWrapper(BaseSettings, BaseModel):
+ """Wrapper for SceneXplain API.
+
+    In order to set this up, you need an API key for the SceneXplain API.
+ You can obtain a key by following the steps below.
+ - Sign up for a free account at https://scenex.jina.ai/.
+ - Navigate to the API Access page (https://scenex.jina.ai/api)
+ and create a new API key.
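+
+    Example (an illustrative sketch; the key and image URL are placeholders):
+        .. code-block:: python
+
+            scenex = SceneXplainAPIWrapper(scenex_api_key="<your-api-key>")
+            scenex.run("https://example.com/image.jpg")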
+ """
+
+ scenex_api_key: str = Field(..., env="SCENEX_API_KEY")
+ scenex_api_url: str = "https://api.scenex.jina.ai/v1/describe"
+
+ def _describe_image(self, image: str) -> str:
+ headers = {
+ "x-api-key": f"token {self.scenex_api_key}",
+ "content-type": "application/json",
+ }
+ payload = {
+ "data": [
+ {
+ "image": image,
+ "algorithm": "Ember",
+ "languages": ["en"],
+ }
+ ]
+ }
+ response = requests.post(self.scenex_api_url, headers=headers, json=payload)
+ response.raise_for_status()
+ result = response.json().get("result", [])
+ img = result[0] if result else {}
+
+ return img.get("text", "")
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+ scenex_api_key = get_from_dict_or_env(
+ values, "scenex_api_key", "SCENEX_API_KEY"
+ )
+ values["scenex_api_key"] = scenex_api_key
+
+ return values
+
+ def run(self, image: str) -> str:
+ """Run SceneXplain image explainer."""
+ description = self._describe_image(image)
+ if not description:
+ return "No description found."
+
+ return description
diff --git a/libs/community/langchain_community/utilities/searchapi.py b/libs/community/langchain_community/utilities/searchapi.py
new file mode 100644
index 00000000000..3934b59d0ed
--- /dev/null
+++ b/libs/community/langchain_community/utilities/searchapi.py
@@ -0,0 +1,138 @@
+from typing import Any, Dict, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class SearchApiAPIWrapper(BaseModel):
+ """
+ Wrapper around SearchApi API.
+
+ To use, you should have the environment variable ``SEARCHAPI_API_KEY``
+ set with your API key, or pass `searchapi_api_key`
+ as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import SearchApiAPIWrapper
+ searchapi = SearchApiAPIWrapper()
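+            # Illustrative call (a sketch; any search string works as the query)
+            searchapi.run("What is the capital of France?")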
+ """
+
+ # Use "google" engine by default.
+ # Full list of supported ones can be found in https://www.searchapi.io docs
+ engine: str = "google"
+ searchapi_api_key: Optional[str] = None
+ aiosession: Optional[aiohttp.ClientSession] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that API key exists in environment."""
+ searchapi_api_key = get_from_dict_or_env(
+ values, "searchapi_api_key", "SEARCHAPI_API_KEY"
+ )
+ values["searchapi_api_key"] = searchapi_api_key
+ return values
+
+ def run(self, query: str, **kwargs: Any) -> str:
+ results = self.results(query, **kwargs)
+ return self._result_as_string(results)
+
+ async def arun(self, query: str, **kwargs: Any) -> str:
+ results = await self.aresults(query, **kwargs)
+ return self._result_as_string(results)
+
+ def results(self, query: str, **kwargs: Any) -> dict:
+ results = self._search_api_results(query, **kwargs)
+ return results
+
+ async def aresults(self, query: str, **kwargs: Any) -> dict:
+ results = await self._async_search_api_results(query, **kwargs)
+ return results
+
+ def _prepare_request(self, query: str, **kwargs: Any) -> dict:
+ return {
+ "url": "https://www.searchapi.io/api/v1/search",
+ "headers": {
+ "Authorization": f"Bearer {self.searchapi_api_key}",
+ },
+ "params": {
+ "engine": self.engine,
+ "q": query,
+ **{key: value for key, value in kwargs.items() if value is not None},
+ },
+ }
+
+ def _search_api_results(self, query: str, **kwargs: Any) -> dict:
+ request_details = self._prepare_request(query, **kwargs)
+ response = requests.get(
+ url=request_details["url"],
+ params=request_details["params"],
+ headers=request_details["headers"],
+ )
+ response.raise_for_status()
+ return response.json()
+
+ async def _async_search_api_results(self, query: str, **kwargs: Any) -> dict:
+ """Use aiohttp to send request to SearchApi API and return results async."""
+ request_details = self._prepare_request(query, **kwargs)
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ url=request_details["url"],
+ headers=request_details["headers"],
+ params=request_details["params"],
+ raise_for_status=True,
+ ) as response:
+ results = await response.json()
+ else:
+ async with self.aiosession.get(
+ url=request_details["url"],
+ headers=request_details["headers"],
+ params=request_details["params"],
+ raise_for_status=True,
+ ) as response:
+ results = await response.json()
+ return results
+
+ @staticmethod
+ def _result_as_string(result: dict) -> str:
+ toret = "No good search result found"
+ if "answer_box" in result.keys() and "answer" in result["answer_box"].keys():
+ toret = result["answer_box"]["answer"]
+ elif "answer_box" in result.keys() and "snippet" in result["answer_box"].keys():
+ toret = result["answer_box"]["snippet"]
+ elif "knowledge_graph" in result.keys():
+ toret = result["knowledge_graph"]["description"]
+ elif "organic_results" in result.keys():
+ snippets = [
+ r["snippet"] for r in result["organic_results"] if "snippet" in r.keys()
+ ]
+ toret = "\n".join(snippets)
+ elif "jobs" in result.keys():
+ jobs = [
+ r["description"] for r in result["jobs"] if "description" in r.keys()
+ ]
+ toret = "\n".join(jobs)
+ elif "videos" in result.keys():
+ videos = [
+ f"""Title: "{r["title"]}" Link: {r["link"]}"""
+ for r in result["videos"]
+ if "title" in r.keys()
+ ]
+ toret = "\n".join(videos)
+ elif "images" in result.keys():
+ images = [
+ f"""Title: "{r["title"]}" Link: {r["original"]["link"]}"""
+ for r in result["images"]
+ if "original" in r.keys()
+ ]
+ toret = "\n".join(images)
+ return toret
diff --git a/libs/community/langchain_community/utilities/searx_search.py b/libs/community/langchain_community/utilities/searx_search.py
new file mode 100644
index 00000000000..3ac9f2cf1c2
--- /dev/null
+++ b/libs/community/langchain_community/utilities/searx_search.py
@@ -0,0 +1,505 @@
+"""Utility for using SearxNG meta search API.
+
+SearxNG is a privacy-friendly free metasearch engine that aggregates results from
+multiple search engines and databases and
+supports the OpenSearch specification.
+
+More details on the installation instructions `here. <../../integrations/searx.html>`_
+
+For the search API refer to https://docs.searxng.org/dev/search_api.html
+
+Quick Start
+-----------
+
+
+In order to use this utility you need to provide the searx host. This can be done
+by passing the named parameter :attr:`searx_host`
+or exporting the environment variable SEARX_HOST.
+Note: this is the only required parameter.
+
+Then create a searx search instance like this:
+
+ .. code-block:: python
+
+ from langchain_community.utilities import SearxSearchWrapper
+
+ # when the host starts with `http` SSL is disabled and the connection
+ # is assumed to be on a private network
+ searx_host='http://self.hosted'
+
+ search = SearxSearchWrapper(searx_host=searx_host)
+
+
+You can now use the ``search`` instance to query the searx API.
+
+Searching
+---------
+
+Use the :meth:`run()` and
+:meth:`results()` methods to query the searx API.
+Other methods are available for convenience.
+
+:class:`SearxResults` is a convenience wrapper around the raw json result.
+
+Example usage of the ``run`` method to make a search:
+
+ .. code-block:: python
+
+ s.run(query="what is the best search engine?")
+
+Engine Parameters
+-----------------
+
+You can pass any accepted searx search API parameters to the
+:py:class:`SearxSearchWrapper` instance.
+
+In the following example we are using the
+:attr:`engines` and the ``language`` parameters:
+
+ .. code-block:: python
+
+ # assuming the searx host is set as above or exported as an env variable
+ s = SearxSearchWrapper(engines=['google', 'bing'],
+ language='es')
+
+Search Tips
+-----------
+
+Searx offers a special search syntax
+that can also be used instead of passing engine parameters.
+
+For example the following query:
+
+ .. code-block:: python
+
+ s = SearxSearchWrapper("langchain library", engines=['github'])
+
+ # can also be written as:
+ s = SearxSearchWrapper("langchain library !github")
+ # or even:
+ s = SearxSearchWrapper("langchain library !gh")
+
+
+In some situations you might want to pass an extra string to the search query.
+For example when the `run()` method is called by an agent. The search suffix can
+also be used as a way to pass extra parameters to searx or the underlying search
+engines.
+
+ .. code-block:: python
+
+ # select the github engine and pass the search suffix
+ s = SearchWrapper("langchain library", query_suffix="!gh")
+
+
+ s = SearchWrapper("langchain library")
+ # select github the conventional google search syntax
+ s.run("large language models", query_suffix="site:github.com")
+
+
+*NOTE*: A search suffix can be defined on both the instance and the method level.
+The resulting query will be the concatenation of the two with the former taking
+precedence.
+
+
+See the SearxNG documentation on configured engines and search syntax
+for more details.
+
+Notes
+-----
+This wrapper is based on the SearxNG fork https://github.com/searxng/searxng which is
+better maintained than the original Searx project and offers more features.
+
+Public SearxNG instances often use a rate limiter for API usage, so you might want to
+use a self-hosted instance and disable the rate limiter.
+
+If you are self-hosting an instance you can customize the rate limiter for your
+own network as described in the SearxNG documentation.
+
+
+For a list of public SearxNG instances see https://searx.space/
+"""
+
+import json
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import (
+ BaseModel,
+ Extra,
+ Field,
+ PrivateAttr,
+ root_validator,
+ validator,
+)
+from langchain_core.utils import get_from_dict_or_env
+
+
+def _get_default_params() -> dict:
+ return {"language": "en", "format": "json"}
+
+
+class SearxResults(dict):
+ """Dict like wrapper around search api results."""
+
+ _data: str = ""
+
+ def __init__(self, data: str):
+ """Take a raw result from Searx and make it into a dict like object."""
+ json_data = json.loads(data)
+ super().__init__(json_data)
+ self.__dict__ = self
+
+ def __str__(self) -> str:
+ """Text representation of searx result."""
+ return self._data
+
+ @property
+ def results(self) -> Any:
+ """Silence mypy for accessing this field.
+
+ :meta private:
+ """
+ return self.get("results")
+
+ @property
+ def answers(self) -> Any:
+ """Helper accessor on the json result."""
+ return self.get("answers")
+
+
+class SearxSearchWrapper(BaseModel):
+ """Wrapper for Searx API.
+
+ To use you need to provide the searx host by passing the named parameter
+ ``searx_host`` or exporting the environment variable ``SEARX_HOST``.
+
+ In some situations you might want to disable SSL verification, for example
+ if you are running searx locally. You can do this by passing the named parameter
+ ``unsecure``. You can also pass the host url scheme as ``http`` to disable SSL.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import SearxSearchWrapper
+ searx = SearxSearchWrapper(searx_host="http://localhost:8888")
+
+ Example with SSL disabled:
+ .. code-block:: python
+
+ from langchain_community.utilities import SearxSearchWrapper
+ # note the unsecure parameter is not needed if you pass the url scheme as
+ # http
+ searx = SearxSearchWrapper(searx_host="http://localhost:8888",
+ unsecure=True)
+
+
+ """
+
+ _result: SearxResults = PrivateAttr()
+ searx_host: str = ""
+ unsecure: bool = False
+ params: dict = Field(default_factory=_get_default_params)
+ headers: Optional[dict] = None
+ engines: Optional[List[str]] = []
+ categories: Optional[List[str]] = []
+ query_suffix: Optional[str] = ""
+ k: int = 10
+ aiosession: Optional[Any] = None
+
+ @validator("unsecure")
+ def disable_ssl_warnings(cls, v: bool) -> bool:
+ """Disable SSL warnings."""
+ if v:
+ # requests.urllib3.disable_warnings()
+ try:
+ import urllib3
+
+ urllib3.disable_warnings()
+ except ImportError as e:
+ print(e)
+
+ return v
+
+ @root_validator()
+ def validate_params(cls, values: Dict) -> Dict:
+ """Validate that custom searx params are merged with default ones."""
+ user_params = values["params"]
+ default = _get_default_params()
+ values["params"] = {**default, **user_params}
+
+ engines = values.get("engines")
+ if engines:
+ values["params"]["engines"] = ",".join(engines)
+
+ categories = values.get("categories")
+ if categories:
+ values["params"]["categories"] = ",".join(categories)
+
+ searx_host = get_from_dict_or_env(values, "searx_host", "SEARX_HOST")
+ if not searx_host.startswith("http"):
+ print(
+ f"Warning: missing the url scheme on host \
+ ! assuming secure https://{searx_host} "
+ )
+ searx_host = "https://" + searx_host
+ elif searx_host.startswith("http://"):
+ values["unsecure"] = True
+ cls.disable_ssl_warnings(True)
+ values["searx_host"] = searx_host
+
+ return values
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _searx_api_query(self, params: dict) -> SearxResults:
+ """Actual request to searx API."""
+ raw_result = requests.get(
+ self.searx_host,
+ headers=self.headers,
+ params=params,
+ verify=not self.unsecure,
+ )
+ # test if http result is ok
+ if not raw_result.ok:
+ raise ValueError("Searx API returned an error: ", raw_result.text)
+ res = SearxResults(raw_result.text)
+ self._result = res
+ return res
+
+ async def _asearx_api_query(self, params: dict) -> SearxResults:
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ self.searx_host,
+ headers=self.headers,
+ params=params,
+                    ssl=False if self.unsecure else None,
+ ) as response:
+ if not response.ok:
+ raise ValueError("Searx API returned an error: ", response.text)
+ result = SearxResults(await response.text())
+ self._result = result
+ else:
+ async with self.aiosession.get(
+ self.searx_host,
+ headers=self.headers,
+ params=params,
+                ssl=False if self.unsecure else None,
+ ) as response:
+ if not response.ok:
+ raise ValueError("Searx API returned an error: ", response.text)
+ result = SearxResults(await response.text())
+ self._result = result
+
+ return result
+
+ def run(
+ self,
+ query: str,
+ engines: Optional[List[str]] = None,
+ categories: Optional[List[str]] = None,
+ query_suffix: Optional[str] = "",
+ **kwargs: Any,
+ ) -> str:
+ """Run query through Searx API and parse results.
+
+ You can pass any other params to the searx query API.
+
+ Args:
+ query: The query to search for.
+ query_suffix: Extra suffix appended to the query.
+ engines: List of engines to use for the query.
+ categories: List of categories to use for the query.
+ **kwargs: extra parameters to pass to the searx API.
+
+ Returns:
+ str: The result of the query.
+
+ Raises:
+ ValueError: If an error occurred with the query.
+
+
+ Example:
+ This will make a query to the qwant engine:
+
+ .. code-block:: python
+
+ from langchain_community.utilities import SearxSearchWrapper
+ searx = SearxSearchWrapper(searx_host="http://my.searx.host")
+ searx.run("what is the weather in France ?", engine="qwant")
+
+ # the same result can be achieved using the `!` syntax of searx
+ # to select the engine using `query_suffix`
+ searx.run("what is the weather in France ?", query_suffix="!qwant")
+ """
+ _params = {
+ "q": query,
+ }
+ params = {**self.params, **_params, **kwargs}
+
+ if self.query_suffix and len(self.query_suffix) > 0:
+ params["q"] += " " + self.query_suffix
+
+ if isinstance(query_suffix, str) and len(query_suffix) > 0:
+ params["q"] += " " + query_suffix
+
+ if isinstance(engines, list) and len(engines) > 0:
+ params["engines"] = ",".join(engines)
+
+ if isinstance(categories, list) and len(categories) > 0:
+ params["categories"] = ",".join(categories)
+
+ res = self._searx_api_query(params)
+
+ if len(res.answers) > 0:
+ toret = res.answers[0]
+
+ # only return the content of the results list
+ elif len(res.results) > 0:
+ toret = "\n\n".join([r.get("content", "") for r in res.results[: self.k]])
+ else:
+ toret = "No good search result found"
+
+ return toret
+
+ async def arun(
+ self,
+ query: str,
+ engines: Optional[List[str]] = None,
+ query_suffix: Optional[str] = "",
+ **kwargs: Any,
+ ) -> str:
+ """Asynchronously version of `run`."""
+ _params = {
+ "q": query,
+ }
+ params = {**self.params, **_params, **kwargs}
+
+ if self.query_suffix and len(self.query_suffix) > 0:
+ params["q"] += " " + self.query_suffix
+
+ if isinstance(query_suffix, str) and len(query_suffix) > 0:
+ params["q"] += " " + query_suffix
+
+ if isinstance(engines, list) and len(engines) > 0:
+ params["engines"] = ",".join(engines)
+
+ res = await self._asearx_api_query(params)
+
+ if len(res.answers) > 0:
+ toret = res.answers[0]
+
+ # only return the content of the results list
+ elif len(res.results) > 0:
+ toret = "\n\n".join([r.get("content", "") for r in res.results[: self.k]])
+ else:
+ toret = "No good search result found"
+
+ return toret
+
+ def results(
+ self,
+ query: str,
+ num_results: int,
+ engines: Optional[List[str]] = None,
+ categories: Optional[List[str]] = None,
+ query_suffix: Optional[str] = "",
+ **kwargs: Any,
+ ) -> List[Dict]:
+ """Run query through Searx API and returns the results with metadata.
+
+ Args:
+ query: The query to search for.
+ query_suffix: Extra suffix appended to the query.
+ num_results: Limit the number of results to return.
+ engines: List of engines to use for the query.
+ categories: List of categories to use for the query.
+ **kwargs: extra parameters to pass to the searx API.
+
+ Returns:
+ Dict with the following keys:
+ {
+ snippet: The description of the result.
+ title: The title of the result.
+ link: The link to the result.
+ engines: The engines used for the result.
+ category: Searx category of the result.
+ }
+
+ """
+ _params = {
+ "q": query,
+ }
+ params = {**self.params, **_params, **kwargs}
+ if self.query_suffix and len(self.query_suffix) > 0:
+ params["q"] += " " + self.query_suffix
+ if isinstance(query_suffix, str) and len(query_suffix) > 0:
+ params["q"] += " " + query_suffix
+ if isinstance(engines, list) and len(engines) > 0:
+ params["engines"] = ",".join(engines)
+ if isinstance(categories, list) and len(categories) > 0:
+ params["categories"] = ",".join(categories)
+ results = self._searx_api_query(params).results[:num_results]
+ if len(results) == 0:
+ return [{"Result": "No good Search Result was found"}]
+
+ return [
+ {
+ "snippet": result.get("content", ""),
+ "title": result["title"],
+ "link": result["url"],
+ "engines": result["engines"],
+ "category": result["category"],
+ }
+ for result in results
+ ]
+
+ async def aresults(
+ self,
+ query: str,
+ num_results: int,
+ engines: Optional[List[str]] = None,
+ query_suffix: Optional[str] = "",
+ **kwargs: Any,
+ ) -> List[Dict]:
+ """Asynchronously query with json results.
+
+ Uses aiohttp. See `results` for more info.
+ """
+ _params = {
+ "q": query,
+ }
+ params = {**self.params, **_params, **kwargs}
+
+ if self.query_suffix and len(self.query_suffix) > 0:
+ params["q"] += " " + self.query_suffix
+ if isinstance(query_suffix, str) and len(query_suffix) > 0:
+ params["q"] += " " + query_suffix
+ if isinstance(engines, list) and len(engines) > 0:
+ params["engines"] = ",".join(engines)
+ results = (await self._asearx_api_query(params)).results[:num_results]
+ if len(results) == 0:
+ return [{"Result": "No good Search Result was found"}]
+
+ return [
+ {
+ "snippet": result.get("content", ""),
+ "title": result["title"],
+ "link": result["url"],
+ "engines": result["engines"],
+ "category": result["category"],
+ }
+ for result in results
+ ]
diff --git a/libs/community/langchain_community/utilities/serpapi.py b/libs/community/langchain_community/utilities/serpapi.py
new file mode 100644
index 00000000000..f78b4afc90b
--- /dev/null
+++ b/libs/community/langchain_community/utilities/serpapi.py
@@ -0,0 +1,220 @@
+"""Chain that calls SerpAPI.
+
+Heavily borrowed from https://github.com/ofirpress/self-ask
+"""
+import os
+import sys
+from typing import Any, Dict, Optional, Tuple
+
+import aiohttp
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class HiddenPrints:
+ """Context manager to hide prints."""
+
+ def __enter__(self) -> None:
+ """Open file to pipe stdout to."""
+ self._original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, "w")
+
+ def __exit__(self, *_: Any) -> None:
+ """Close file that stdout was piped to."""
+ sys.stdout.close()
+ sys.stdout = self._original_stdout
+
+
+class SerpAPIWrapper(BaseModel):
+ """Wrapper around SerpAPI.
+
+ To use, you should have the ``google-search-results`` python package installed,
+ and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
+ `serpapi_api_key` as a named parameter to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import SerpAPIWrapper
+ serpapi = SerpAPIWrapper()
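+            # Illustrative call (a sketch; any search string works as the query)
+            serpapi.run("Obama's first name?")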
+ """
+
+ search_engine: Any #: :meta private:
+ params: dict = Field(
+ default={
+ "engine": "google",
+ "google_domain": "google.com",
+ "gl": "us",
+ "hl": "en",
+ }
+ )
+ serpapi_api_key: Optional[str] = None
+ aiosession: Optional[aiohttp.ClientSession] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ serpapi_api_key = get_from_dict_or_env(
+ values, "serpapi_api_key", "SERPAPI_API_KEY"
+ )
+ values["serpapi_api_key"] = serpapi_api_key
+ try:
+ from serpapi import GoogleSearch
+
+ values["search_engine"] = GoogleSearch
+ except ImportError:
+ raise ValueError(
+ "Could not import serpapi python package. "
+ "Please install it with `pip install google-search-results`."
+ )
+ return values
+
+ async def arun(self, query: str, **kwargs: Any) -> str:
+ """Run query through SerpAPI and parse result async."""
+ return self._process_response(await self.aresults(query))
+
+ def run(self, query: str, **kwargs: Any) -> str:
+ """Run query through SerpAPI and parse result."""
+ return self._process_response(self.results(query))
+
+ def results(self, query: str) -> dict:
+ """Run query through SerpAPI and return the raw result."""
+ params = self.get_params(query)
+ with HiddenPrints():
+ search = self.search_engine(params)
+ res = search.get_dict()
+ return res
+
+ async def aresults(self, query: str) -> dict:
+ """Use aiohttp to run query through SerpAPI and return the results async."""
+
+ def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
+ params = self.get_params(query)
+ params["source"] = "python"
+ if self.serpapi_api_key:
+ params["serp_api_key"] = self.serpapi_api_key
+ params["output"] = "json"
+ url = "https://serpapi.com/search"
+ return url, params
+
+ url, params = construct_url_and_params()
+ if not self.aiosession:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url, params=params) as response:
+ res = await response.json()
+ else:
+ async with self.aiosession.get(url, params=params) as response:
+ res = await response.json()
+
+ return res
+
+ def get_params(self, query: str) -> Dict[str, str]:
+ """Get parameters for SerpAPI."""
+ _params = {
+ "api_key": self.serpapi_api_key,
+ "q": query,
+ }
+ params = {**self.params, **_params}
+ return params
+
+ @staticmethod
+ def _process_response(res: dict) -> str:
+ """Process response from SerpAPI."""
+ if "error" in res.keys():
+ raise ValueError(f"Got error from SerpAPI: {res['error']}")
+ if "answer_box_list" in res.keys():
+ res["answer_box"] = res["answer_box_list"]
+ if "answer_box" in res.keys():
+ answer_box = res["answer_box"]
+ if isinstance(answer_box, list):
+ answer_box = answer_box[0]
+ if "result" in answer_box.keys():
+ return answer_box["result"]
+ elif "answer" in answer_box.keys():
+ return answer_box["answer"]
+ elif "snippet" in answer_box.keys():
+ return answer_box["snippet"]
+ elif "snippet_highlighted_words" in answer_box.keys():
+ return answer_box["snippet_highlighted_words"]
+ else:
+ answer = {}
+ for key, value in answer_box.items():
+ if not isinstance(value, (list, dict)) and not (
+ isinstance(value, str) and value.startswith("http")
+ ):
+ answer[key] = value
+ return str(answer)
+ elif "events_results" in res.keys():
+ return res["events_results"][:10]
+ elif "sports_results" in res.keys():
+ return res["sports_results"]
+ elif "top_stories" in res.keys():
+ return res["top_stories"]
+ elif "news_results" in res.keys():
+ return res["news_results"]
+ elif "jobs_results" in res.keys() and "jobs" in res["jobs_results"].keys():
+ return res["jobs_results"]["jobs"]
+ elif (
+ "shopping_results" in res.keys()
+ and "title" in res["shopping_results"][0].keys()
+ ):
+ return res["shopping_results"][:3]
+ elif "questions_and_answers" in res.keys():
+ return res["questions_and_answers"]
+ elif (
+ "popular_destinations" in res.keys()
+ and "destinations" in res["popular_destinations"].keys()
+ ):
+ return res["popular_destinations"]["destinations"]
+ elif "top_sights" in res.keys() and "sights" in res["top_sights"].keys():
+ return res["top_sights"]["sights"]
+ elif (
+ "images_results" in res.keys()
+ and "thumbnail" in res["images_results"][0].keys()
+ ):
+ return str([item["thumbnail"] for item in res["images_results"][:10]])
+
+ snippets = []
+ if "knowledge_graph" in res.keys():
+ knowledge_graph = res["knowledge_graph"]
+ title = knowledge_graph["title"] if "title" in knowledge_graph else ""
+ if "description" in knowledge_graph.keys():
+ snippets.append(knowledge_graph["description"])
+ for key, value in knowledge_graph.items():
+ if (
+ isinstance(key, str)
+ and isinstance(value, str)
+ and key not in ["title", "description"]
+ and not key.endswith("_stick")
+ and not key.endswith("_link")
+ and not value.startswith("http")
+ ):
+ snippets.append(f"{title} {key}: {value}.")
+
+ for organic_result in res.get("organic_results", []):
+ if "snippet" in organic_result.keys():
+ snippets.append(organic_result["snippet"])
+ elif "snippet_highlighted_words" in organic_result.keys():
+ snippets.append(organic_result["snippet_highlighted_words"])
+ elif "rich_snippet" in organic_result.keys():
+ snippets.append(organic_result["rich_snippet"])
+ elif "rich_snippet_table" in organic_result.keys():
+ snippets.append(organic_result["rich_snippet_table"])
+ elif "link" in organic_result.keys():
+ snippets.append(organic_result["link"])
+
+ if "buying_guide" in res.keys():
+ snippets.append(res["buying_guide"])
+ if "local_results" in res.keys() and "places" in res["local_results"].keys():
+ snippets.append(res["local_results"]["places"])
+
+ if len(snippets) > 0:
+ return str(snippets)
+ else:
+ return "No good search result found"
diff --git a/libs/community/langchain_community/utilities/spark_sql.py b/libs/community/langchain_community/utilities/spark_sql.py
new file mode 100644
index 00000000000..20c1e8e5b2f
--- /dev/null
+++ b/libs/community/langchain_community/utilities/spark_sql.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional
+
+if TYPE_CHECKING:
+ from pyspark.sql import DataFrame, Row, SparkSession
+
+
+class SparkSQL:
+ """SparkSQL is a utility class for interacting with Spark SQL."""
+
+ def __init__(
+ self,
+ spark_session: Optional[SparkSession] = None,
+ catalog: Optional[str] = None,
+ schema: Optional[str] = None,
+ ignore_tables: Optional[List[str]] = None,
+ include_tables: Optional[List[str]] = None,
+ sample_rows_in_table_info: int = 3,
+ ):
+ """Initialize a SparkSQL object.
+
+ Args:
+ spark_session: A SparkSession object.
+ If not provided, one will be created.
+ catalog: The catalog to use.
+ If not provided, the default catalog will be used.
+ schema: The schema to use.
+ If not provided, the default schema will be used.
+ ignore_tables: A list of tables to ignore.
+ If not provided, all tables will be used.
+ include_tables: A list of tables to include.
+ If not provided, all tables will be used.
+ sample_rows_in_table_info: The number of rows to include in the table info.
+ Defaults to 3.
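+
+        Example:
+            .. code-block:: python
+
+                # A minimal sketch: assumes pyspark is installed and a local
+                # SparkSession can be created; table names depend on your catalog.
+                from langchain_community.utilities.spark_sql import SparkSQL
+
+                spark_sql = SparkSQL(sample_rows_in_table_info=3)
+                print(spark_sql.get_usable_table_names())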
+ """
+ try:
+ from pyspark.sql import SparkSession
+ except ImportError:
+ raise ImportError(
+ "pyspark is not installed. Please install it with `pip install pyspark`"
+ )
+
+ self._spark = (
+ spark_session if spark_session else SparkSession.builder.getOrCreate()
+ )
+ if catalog is not None:
+ self._spark.catalog.setCurrentCatalog(catalog)
+ if schema is not None:
+ self._spark.catalog.setCurrentDatabase(schema)
+
+ self._all_tables = set(self._get_all_table_names())
+ self._include_tables = set(include_tables) if include_tables else set()
+ if self._include_tables:
+ missing_tables = self._include_tables - self._all_tables
+ if missing_tables:
+ raise ValueError(
+ f"include_tables {missing_tables} not found in database"
+ )
+ self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+ if self._ignore_tables:
+ missing_tables = self._ignore_tables - self._all_tables
+ if missing_tables:
+ raise ValueError(
+ f"ignore_tables {missing_tables} not found in database"
+ )
+ usable_tables = self.get_usable_table_names()
+ self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
+
+ if not isinstance(sample_rows_in_table_info, int):
+ raise TypeError("sample_rows_in_table_info must be an integer")
+
+ self._sample_rows_in_table_info = sample_rows_in_table_info
+
+ @classmethod
+ def from_uri(
+ cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+ ) -> SparkSQL:
+        """Create a remote Spark session via Spark Connect.
+ For example: SparkSQL.from_uri("sc://localhost:15002")
+ """
+ try:
+ from pyspark.sql import SparkSession
+ except ImportError:
+ raise ValueError(
+ "pyspark is not installed. Please install it with `pip install pyspark`"
+ )
+
+ spark = SparkSession.builder.remote(database_uri).getOrCreate()
+ return cls(spark, **kwargs)
+
+ def get_usable_table_names(self) -> Iterable[str]:
+ """Get names of tables available."""
+ if self._include_tables:
+ return self._include_tables
+ # sorting the result can help LLM understanding it.
+ return sorted(self._all_tables - self._ignore_tables)
+
+ def _get_all_table_names(self) -> Iterable[str]:
+ rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
+ return list(map(lambda row: row.tableName, rows))
+
+ def _get_create_table_stmt(self, table: str) -> str:
+ statement = (
+ self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
+ )
+ # Ignore the data source provider and options to reduce the number of tokens.
+ using_clause_index = statement.find("USING")
+ return statement[:using_clause_index] + ";"
+
+ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+ all_table_names = self.get_usable_table_names()
+ if table_names is not None:
+ missing_tables = set(table_names).difference(all_table_names)
+ if missing_tables:
+ raise ValueError(f"table_names {missing_tables} not found in database")
+ all_table_names = table_names
+ tables = []
+ for table_name in all_table_names:
+ table_info = self._get_create_table_stmt(table_name)
+ if self._sample_rows_in_table_info:
+ table_info += "\n\n/*"
+ table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
+ table_info += "*/"
+ tables.append(table_info)
+ final_str = "\n\n".join(tables)
+ return final_str
+
+ def _get_sample_spark_rows(self, table: str) -> str:
+ query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
+ df = self._spark.sql(query)
+ columns_str = "\t".join(list(map(lambda f: f.name, df.schema.fields)))
+ try:
+ sample_rows = self._get_dataframe_results(df)
+ # save the sample rows in string format
+ sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
+ except Exception:
+ sample_rows_str = ""
+
+ return (
+ f"{self._sample_rows_in_table_info} rows from {table} table:\n"
+ f"{columns_str}\n"
+ f"{sample_rows_str}"
+ )
+
+ def _convert_row_as_tuple(self, row: Row) -> tuple:
+ return tuple(map(str, row.asDict().values()))
+
+ def _get_dataframe_results(self, df: DataFrame) -> list:
+ return list(map(self._convert_row_as_tuple, df.collect()))
+
+ def run(self, command: str, fetch: str = "all") -> str:
+ df = self._spark.sql(command)
+ if fetch == "one":
+ df = df.limit(1)
+ return str(self._get_dataframe_results(df))
+
+ def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
+ """Get information about specified tables.
+
+ Follows best practices as specified in: Rajkumar et al, 2022
+ (https://arxiv.org/abs/2204.00498)
+
+ If `sample_rows_in_table_info`, the specified number of sample rows will be
+ appended to each table description. This can increase performance as
+ demonstrated in the paper.
+ """
+ try:
+ return self.get_table_info(table_names)
+ except ValueError as e:
+            # Format the error message and return it instead of raising.
+ return f"Error: {e}"
+
+ def run_no_throw(self, command: str, fetch: str = "all") -> str:
+ """Execute a SQL command and return a string representing the results.
+
+ If the statement returns rows, a string of the results is returned.
+ If the statement returns no rows, an empty string is returned.
+
+ If the statement throws an error, the error message is returned.
+ """
+ try:
+ return self.run(command, fetch)
+ except Exception as e:
+            # Format the error message and return it instead of raising.
+ return f"Error: {e}"
diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py
new file mode 100644
index 00000000000..969d682a541
--- /dev/null
+++ b/libs/community/langchain_community/utilities/sql_database.py
@@ -0,0 +1,476 @@
+"""SQLAlchemy wrapper around a database."""
+from __future__ import annotations
+
+import warnings
+from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
+
+import sqlalchemy
+from langchain_core.utils import get_from_env
+from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
+from sqlalchemy.engine import Engine
+from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
+from sqlalchemy.schema import CreateTable
+from sqlalchemy.types import NullType
+
+
+def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
+ return (
+ f'Name: {index["name"]}, Unique: {index["unique"]},'
+ f' Columns: {str(index["column_names"])}'
+ )
+
+
+def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
+ """
+    Truncate a string to at most `length` characters, cutting at the last whole
+    word that fits and appending `suffix`.
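+
+    For example, truncate_word("hello wonderful world", length=12) returns
+    "hello...".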
+ """
+
+ if not isinstance(content, str) or length <= 0:
+ return content
+
+ if len(content) <= length:
+ return content
+
+ return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
+
+
+class SQLDatabase:
+ """SQLAlchemy wrapper around a database."""
+
+ def __init__(
+ self,
+ engine: Engine,
+ schema: Optional[str] = None,
+ metadata: Optional[MetaData] = None,
+ ignore_tables: Optional[List[str]] = None,
+ include_tables: Optional[List[str]] = None,
+ sample_rows_in_table_info: int = 3,
+ indexes_in_table_info: bool = False,
+ custom_table_info: Optional[dict] = None,
+ view_support: bool = False,
+ max_string_length: int = 300,
+ ):
+        """Create an SQLDatabase wrapper from a SQLAlchemy engine."""
+ self._engine = engine
+ self._schema = schema
+ if include_tables and ignore_tables:
+ raise ValueError("Cannot specify both include_tables and ignore_tables")
+
+ self._inspector = inspect(self._engine)
+
+ # including view support by adding the views as well as tables to the all
+ # tables list if view_support is True
+ self._all_tables = set(
+ self._inspector.get_table_names(schema=schema)
+ + (self._inspector.get_view_names(schema=schema) if view_support else [])
+ )
+
+ self._include_tables = set(include_tables) if include_tables else set()
+ if self._include_tables:
+ missing_tables = self._include_tables - self._all_tables
+ if missing_tables:
+ raise ValueError(
+ f"include_tables {missing_tables} not found in database"
+ )
+ self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+ if self._ignore_tables:
+ missing_tables = self._ignore_tables - self._all_tables
+ if missing_tables:
+ raise ValueError(
+ f"ignore_tables {missing_tables} not found in database"
+ )
+ usable_tables = self.get_usable_table_names()
+ self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
+
+ if not isinstance(sample_rows_in_table_info, int):
+ raise TypeError("sample_rows_in_table_info must be an integer")
+
+ self._sample_rows_in_table_info = sample_rows_in_table_info
+ self._indexes_in_table_info = indexes_in_table_info
+
+ self._custom_table_info = custom_table_info
+ if self._custom_table_info:
+ if not isinstance(self._custom_table_info, dict):
+ raise TypeError(
+ "table_info must be a dictionary with table names as keys and the "
+ "desired table info as values"
+ )
+ # only keep the tables that are also present in the database
+ intersection = set(self._custom_table_info).intersection(self._all_tables)
+ self._custom_table_info = dict(
+ (table, self._custom_table_info[table])
+ for table in self._custom_table_info
+ if table in intersection
+ )
+
+ self._max_string_length = max_string_length
+
+ self._metadata = metadata or MetaData()
+ # including view support if view_support = true
+ self._metadata.reflect(
+ views=view_support,
+ bind=self._engine,
+ only=list(self._usable_tables),
+ schema=self._schema,
+ )
+
+ @classmethod
+ def from_uri(
+ cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+ ) -> SQLDatabase:
+ """Construct a SQLAlchemy engine from URI."""
+ _engine_args = engine_args or {}
+ return cls(create_engine(database_uri, **_engine_args), **kwargs)
+
+ @classmethod
+ def from_databricks(
+ cls,
+ catalog: str,
+ schema: str,
+ host: Optional[str] = None,
+ api_token: Optional[str] = None,
+ warehouse_id: Optional[str] = None,
+ cluster_id: Optional[str] = None,
+ engine_args: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> SQLDatabase:
+ """
+ Class method to create an SQLDatabase instance from a Databricks connection.
+ This method requires the 'databricks-sql-connector' package. If not installed,
+ it can be added using `pip install databricks-sql-connector`.
+
+ Args:
+ catalog (str): The catalog name in the Databricks database.
+ schema (str): The schema name in the catalog.
+ host (Optional[str]): The Databricks workspace hostname, excluding
+ 'https://' part. If not provided, it attempts to fetch from the
+ environment variable 'DATABRICKS_HOST'. If still unavailable and if
+ running in a Databricks notebook, it defaults to the current workspace
+ hostname. Defaults to None.
+ api_token (Optional[str]): The Databricks personal access token for
+ accessing the Databricks SQL warehouse or the cluster. If not provided,
+ it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
+ and running in a Databricks notebook, a temporary token for the current
+ user is generated. Defaults to None.
+ warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
+ provided, the method configures the connection to use this warehouse.
+ Cannot be used with 'cluster_id'. Defaults to None.
+ cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
+ provided, the method configures the connection to use this cluster.
+ Cannot be used with 'warehouse_id'. If running in a Databricks notebook
+ and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
+ cluster the notebook is attached to. Defaults to None.
+ engine_args (Optional[dict]): The arguments to be used when connecting
+ Databricks. Defaults to None.
+ **kwargs (Any): Additional keyword arguments for the `from_uri` method.
+
+ Returns:
+ SQLDatabase: An instance of SQLDatabase configured with the provided
+ Databricks connection details.
+
+ Raises:
+ ValueError: If 'databricks-sql-connector' is not found, or if both
+ 'warehouse_id' and 'cluster_id' are provided, or if neither
+ 'warehouse_id' nor 'cluster_id' are provided and it's not executing
+ inside a Databricks notebook.
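+
+        Example:
+            .. code-block:: python
+
+                # Illustrative values only: the catalog, schema, and warehouse id
+                # are placeholders, and DATABRICKS_HOST / DATABRICKS_TOKEN are
+                # assumed to be set in the environment.
+                db = SQLDatabase.from_databricks(
+                    catalog="samples",
+                    schema="nyctaxi",
+                    warehouse_id="1234567890abcdef",
+                )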
+ """
+ try:
+ from databricks import sql # noqa: F401
+ except ImportError:
+ raise ValueError(
+ "databricks-sql-connector package not found, please install with"
+ " `pip install databricks-sql-connector`"
+ )
+ context = None
+ try:
+ from dbruntime.databricks_repl_context import get_context
+
+ context = get_context()
+ except ImportError:
+ pass
+
+ default_host = context.browserHostName if context else None
+ if host is None:
+ host = get_from_env("host", "DATABRICKS_HOST", default_host)
+
+ default_api_token = context.apiToken if context else None
+ if api_token is None:
+ api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)
+
+ if warehouse_id is None and cluster_id is None:
+ if context:
+ cluster_id = context.clusterId
+ else:
+ raise ValueError(
+ "Need to provide either 'warehouse_id' or 'cluster_id'."
+ )
+
+ if warehouse_id and cluster_id:
+            raise ValueError("Can't have both 'warehouse_id' and 'cluster_id'.")
+
+ if warehouse_id:
+ http_path = f"/sql/1.0/warehouses/{warehouse_id}"
+ else:
+ http_path = f"/sql/protocolv1/o/0/{cluster_id}"
+
+ uri = (
+ f"databricks://token:{api_token}@{host}?"
+ f"http_path={http_path}&catalog={catalog}&schema={schema}"
+ )
+ return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)
+
+ @classmethod
+ def from_cnosdb(
+ cls,
+ url: str = "127.0.0.1:8902",
+ user: str = "root",
+ password: str = "",
+ tenant: str = "cnosdb",
+ database: str = "public",
+ ) -> SQLDatabase:
+ """
+ Class method to create an SQLDatabase instance from a CnosDB connection.
+ This method requires the 'cnos-connector' package. If not installed, it
+ can be added using `pip install cnos-connector`.
+
+ Args:
+ url (str): The HTTP connection host name and port number of the CnosDB
+ service, excluding "http://" or "https://", with a default value
+ of "127.0.0.1:8902".
+ user (str): The username used to connect to the CnosDB service, with a
+ default value of "root".
+ password (str): The password of the user connecting to the CnosDB service,
+ with a default value of "".
+ tenant (str): The name of the tenant used to connect to the CnosDB service,
+ with a default value of "cnosdb".
+ database (str): The name of the database in the CnosDB tenant.
+
+ Returns:
+ SQLDatabase: An instance of SQLDatabase configured with the provided
+ CnosDB connection details.
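+
+        Example:
+            .. code-block:: python
+
+                # A sketch using the documented defaults; point the url at a real
+                # CnosDB deployment and install cnos-connector before use.
+                db = SQLDatabase.from_cnosdb(url="127.0.0.1:8902", database="public")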
+ """
+ try:
+ from cnosdb_connector import make_cnosdb_langchain_uri
+
+ uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
+ return cls.from_uri(database_uri=uri)
+ except ImportError:
+ raise ValueError(
+ "cnos-connector package not found, please install with"
+ " `pip install cnos-connector`"
+ )
+
+ @property
+ def dialect(self) -> str:
+ """Return string representation of dialect to use."""
+ return self._engine.dialect.name
+
+ def get_usable_table_names(self) -> Iterable[str]:
+ """Get names of tables available."""
+ if self._include_tables:
+ return sorted(self._include_tables)
+ return sorted(self._all_tables - self._ignore_tables)
+
+ def get_table_names(self) -> Iterable[str]:
+ """Get names of tables available."""
+ warnings.warn(
+ "This method is deprecated - please use `get_usable_table_names`."
+ )
+ return self.get_usable_table_names()
+
+ @property
+ def table_info(self) -> str:
+ """Information about all tables in the database."""
+ return self.get_table_info()
+
+ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+ """Get information about specified tables.
+
+ Follows best practices as specified in: Rajkumar et al, 2022
+ (https://arxiv.org/abs/2204.00498)
+
+ If `sample_rows_in_table_info`, the specified number of sample rows will be
+ appended to each table description. This can increase performance as
+ demonstrated in the paper.
+ """
+ all_table_names = self.get_usable_table_names()
+ if table_names is not None:
+ missing_tables = set(table_names).difference(all_table_names)
+ if missing_tables:
+ raise ValueError(f"table_names {missing_tables} not found in database")
+ all_table_names = table_names
+
+ meta_tables = [
+ tbl
+ for tbl in self._metadata.sorted_tables
+ if tbl.name in set(all_table_names)
+ and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
+ ]
+
+ tables = []
+ for table in meta_tables:
+ if self._custom_table_info and table.name in self._custom_table_info:
+ tables.append(self._custom_table_info[table.name])
+ continue
+
+ # Ignore JSON datatyped columns
+ for k, v in table.columns.items():
+ if type(v.type) is NullType:
+ table._columns.remove(v)
+
+ # add create table command
+ create_table = str(CreateTable(table).compile(self._engine))
+ table_info = f"{create_table.rstrip()}"
+ has_extra_info = (
+ self._indexes_in_table_info or self._sample_rows_in_table_info
+ )
+ if has_extra_info:
+ table_info += "\n\n/*"
+ if self._indexes_in_table_info:
+ table_info += f"\n{self._get_table_indexes(table)}\n"
+ if self._sample_rows_in_table_info:
+ table_info += f"\n{self._get_sample_rows(table)}\n"
+ if has_extra_info:
+ table_info += "*/"
+ tables.append(table_info)
+ tables.sort()
+ final_str = "\n\n".join(tables)
+ return final_str
+
+ def _get_table_indexes(self, table: Table) -> str:
+ indexes = self._inspector.get_indexes(table.name)
+ indexes_formatted = "\n".join(map(_format_index, indexes))
+ return f"Table Indexes:\n{indexes_formatted}"
+
+ def _get_sample_rows(self, table: Table) -> str:
+ # build the select command
+ command = select(table).limit(self._sample_rows_in_table_info)
+
+ # save the columns in string format
+ columns_str = "\t".join([col.name for col in table.columns])
+
+ try:
+ # get the sample rows
+ with self._engine.connect() as connection:
+ sample_rows_result = connection.execute(command) # type: ignore
+ # shorten values in the sample rows
+ sample_rows = list(
+ map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
+ )
+
+ # save the sample rows in string format
+ sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
+
+ # in some dialects when there are no rows in the table a
+ # 'ProgrammingError' is returned
+ except ProgrammingError:
+ sample_rows_str = ""
+
+ return (
+ f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
+ f"{columns_str}\n"
+ f"{sample_rows_str}"
+ )
+
+ def _execute(
+ self,
+ command: str,
+ fetch: Union[Literal["all"], Literal["one"]] = "all",
+ ) -> Sequence[Dict[str, Any]]:
+ """
+ Executes SQL command through underlying engine.
+
+ If the statement returns no rows, an empty list is returned.
+ """
+ with self._engine.begin() as connection:
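+            # Set the session's default schema up front; the syntax for doing so
+            # differs per dialect.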
+ if self._schema is not None:
+ if self.dialect == "snowflake":
+ connection.exec_driver_sql(
+ "ALTER SESSION SET search_path = %s", (self._schema,)
+ )
+ elif self.dialect == "bigquery":
+ connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
+ elif self.dialect == "mssql":
+ pass
+ elif self.dialect == "trino":
+ connection.exec_driver_sql("USE ?", (self._schema,))
+ elif self.dialect == "duckdb":
+ # Unclear which parameterized argument syntax duckdb supports.
+ # The docs for the duckdb client say they support multiple,
+ # but `duckdb_engine` seemed to struggle with all of them:
+ # https://github.com/Mause/duckdb_engine/issues/796
+ connection.exec_driver_sql(f"SET search_path TO {self._schema}")
+ elif self.dialect == "oracle":
+ connection.exec_driver_sql(
+ f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
+ )
+ else: # postgresql and other compatible dialects
+ connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
+ cursor = connection.execute(text(command))
+ if cursor.returns_rows:
+ if fetch == "all":
+ result = [x._asdict() for x in cursor.fetchall()]
+ elif fetch == "one":
+ first_result = cursor.fetchone()
+ result = [] if first_result is None else [first_result._asdict()]
+ else:
+ raise ValueError("Fetch parameter must be either 'one' or 'all'")
+ return result
+ return []
+
+ def run(
+ self,
+ command: str,
+ fetch: Union[Literal["all"], Literal["one"]] = "all",
+ ) -> str:
+ """Execute a SQL command and return a string representing the results.
+
+ If the statement returns rows, a string of the results is returned.
+ If the statement returns no rows, an empty string is returned.
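+
+        Example:
+            .. code-block:: python
+
+                # "my_table" is a hypothetical table name.
+                db.run("SELECT count(*) FROM my_table", fetch="one")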
+ """
+ result = self._execute(command, fetch)
+ # Convert columns values to string to avoid issues with sqlalchemy
+ # truncating text
+ res = [
+ tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
+ for r in result
+ ]
+ if not res:
+ return ""
+ else:
+ return str(res)
+
+ def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
+ """Get information about specified tables.
+
+ Follows best practices as specified in: Rajkumar et al, 2022
+ (https://arxiv.org/abs/2204.00498)
+
+ If `sample_rows_in_table_info`, the specified number of sample rows will be
+ appended to each table description. This can increase performance as
+ demonstrated in the paper.
+ """
+ try:
+ return self.get_table_info(table_names)
+ except ValueError as e:
+            # Format the error message and return it instead of raising.
+ return f"Error: {e}"
+
+ def run_no_throw(
+ self,
+ command: str,
+ fetch: Union[Literal["all"], Literal["one"]] = "all",
+ ) -> str:
+ """Execute a SQL command and return a string representing the results.
+
+ If the statement returns rows, a string of the results is returned.
+ If the statement returns no rows, an empty string is returned.
+
+ If the statement throws an error, the error message is returned.
+ """
+ try:
+ return self.run(command, fetch)
+ except SQLAlchemyError as e:
+            # Format the error message and return it instead of raising.
+ return f"Error: {e}"
diff --git a/libs/community/langchain_community/utilities/stackexchange.py b/libs/community/langchain_community/utilities/stackexchange.py
new file mode 100644
index 00000000000..6ad72b42541
--- /dev/null
+++ b/libs/community/langchain_community/utilities/stackexchange.py
@@ -0,0 +1,68 @@
+import html
+from typing import Any, Dict, Literal
+
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+
+
+class StackExchangeAPIWrapper(BaseModel):
+ """Wrapper for Stack Exchange API."""
+
+ client: Any #: :meta private:
+ max_results: int = 3
+ """Max number of results to include in output."""
+ query_type: Literal["all", "title", "body"] = "all"
+    """Which part of StackOverflow items to match against. One of 'all', 'title',
+ 'body'. Defaults to 'all'.
+ """
+ fetch_params: Dict[str, Any] = Field(default_factory=dict)
+ """Additional params to pass to StackApi.fetch."""
+ result_separator: str = "\n\n"
+    """Separator between question and answer pairs."""
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the required Python package exists."""
+ try:
+ from stackapi import StackAPI
+
+ values["client"] = StackAPI("stackoverflow")
+ except ImportError:
+ raise ImportError(
+ "The 'stackapi' Python package is not installed. "
+ "Please install it with `pip install stackapi`."
+ )
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through StackExchange API and parse results."""
+
+ query_key = "q" if self.query_type == "all" else self.query_type
+ output = self.client.fetch(
+ "search/excerpts", **{query_key: query}, **self.fetch_params
+ )
+ if len(output["items"]) < 1:
+ return f"No relevant results found for '{query}' on Stack Overflow."
+ questions = [
+ item for item in output["items"] if item["item_type"] == "question"
+ ][: self.max_results]
+ answers = [item for item in output["items"] if item["item_type"] == "answer"]
+ results = []
+ for question in questions:
+ res_text = f"Question: {question['title']}\n{question['excerpt']}"
+ relevant_answers = [
+ answer
+ for answer in answers
+ if answer["question_id"] == question["question_id"]
+ ]
+ accepted_answers = [
+ answer for answer in relevant_answers if answer["is_accepted"]
+ ]
+ if relevant_answers:
+ top_answer = (
+ accepted_answers[0] if accepted_answers else relevant_answers[0]
+ )
+ excerpt = html.unescape(top_answer["excerpt"])
+ res_text += f"\nAnswer: {excerpt}"
+ results.append(res_text)
+
+ return self.result_separator.join(results)
diff --git a/libs/community/langchain_community/utilities/steam.py b/libs/community/langchain_community/utilities/steam.py
new file mode 100644
index 00000000000..778c3c6870e
--- /dev/null
+++ b/libs/community/langchain_community/utilities/steam.py
@@ -0,0 +1,164 @@
+"""Util that calls Steam-WebAPI."""
+
+from typing import Any, List
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+
+
+class SteamWebAPIWrapper(BaseModel):
+ """Wrapper for Steam API."""
+
+ steam: Any # for python-steam-api
+
+ from langchain_community.tools.steam.prompt import (
+ STEAM_GET_GAMES_DETAILS,
+ STEAM_GET_RECOMMENDED_GAMES,
+ )
+
+ # operations: a list of dictionaries, each representing a specific operation that
+ # can be performed with the API
+ operations: List[dict] = [
+ {
+            "mode": "get_games_details",
+ "name": "Get Game Details",
+ "description": STEAM_GET_GAMES_DETAILS,
+ },
+ {
+ "mode": "get_recommended_games",
+ "name": "Get Recommended Games",
+ "description": STEAM_GET_RECOMMENDED_GAMES,
+ },
+ ]
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def get_operations(self) -> List[dict]:
+ """Return a list of operations."""
+ return self.operations
+
+ @root_validator
+ def validate_environment(cls, values: dict) -> dict:
+ """Validate api key and python package has been configured."""
+
+ # check if the python package is installed
+ try:
+ from steam import Steam
+ except ImportError:
+            raise ImportError(
+                "python-steam-api library is not installed. "
+                "Please install it with `pip install python-steam-api`."
+            )
+
+ try:
+ from decouple import config
+ except ImportError:
+            raise ImportError(
+                "decouple library is not installed. "
+                "Please install it with `pip install python-decouple`."
+            )
+
+ # initialize the steam attribute for python-steam-api usage
+ KEY = config("STEAM_KEY")
+ steam = Steam(KEY)
+ values["steam"] = steam
+ return values
+
+ def parse_to_str(self, details: dict) -> str: # For later parsing
+ """Parse the details result."""
+ result = ""
+ for key, value in details.items():
+ result += "The " + str(key) + " is: " + str(value) + "\n"
+ return result
+
+ def get_id_link_price(self, games: dict) -> dict:
+ """The response may contain more than one game, so we need to choose the right
+ one and return the id."""
+
+ game_info = {}
+ for app in games["apps"]:
+ game_info["id"] = app["id"]
+ game_info["link"] = app["link"]
+ game_info["price"] = app["price"]
+ break
+ return game_info
+
+ def remove_html_tags(self, html_string: str) -> str:
+ from bs4 import BeautifulSoup
+
+ soup = BeautifulSoup(html_string, "html.parser")
+ return soup.get_text()
+
+ def details_of_games(self, name: str) -> str:
+ games = self.steam.apps.search_games(name)
+ info_partOne_dict = self.get_id_link_price(games)
+ info_partOne = self.parse_to_str(info_partOne_dict)
+ id = str(info_partOne_dict.get("id"))
+ info_dict = self.steam.apps.get_app_details(id)
+ data = info_dict.get(id).get("data")
+ detailed_description = data.get("detailed_description")
+
+        # detailed_description contains some other html tags, so we need to
+        # remove them
+ detailed_description = self.remove_html_tags(detailed_description)
+ supported_languages = info_dict.get(id).get("data").get("supported_languages")
+ info_partTwo = (
+ "The summary of the game is: "
+ + detailed_description
+ + "\n"
+ + "The supported languages of the game are: "
+ + supported_languages
+ + "\n"
+ )
+ info = info_partOne + info_partTwo
+ return info
+
+ def get_steam_id(self, name: str) -> str:
+ user = self.steam.users.search_user(name)
+ steam_id = user["player"]["steamid"]
+ return steam_id
+
+ def get_users_games(self, steam_id: str) -> List[str]:
+ return self.steam.users.get_owned_games(steam_id, False, False)
+
+ def recommended_games(self, steam_id: str) -> str:
+ try:
+ import steamspypi
+ except ImportError:
+ raise ImportError("steamspypi library is not installed.")
+ users_games = self.get_users_games(steam_id)
+ result = {} # type: ignore
+ most_popular_genre = ""
+ most_popular_genre_count = 0
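+        # Tally genres across the user's owned games (via steamspypi appdetails)
+        # and keep track of the most common one.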
+ for game in users_games["games"]: # type: ignore
+ appid = game["appid"]
+ data_request = {"request": "appdetails", "appid": appid}
+ genreStore = steamspypi.download(data_request)
+ genreList = genreStore.get("genre", "").split(", ")
+
+ for genre in genreList:
+ if genre in result:
+ result[genre] += 1
+ else:
+ result[genre] = 1
+ if result[genre] > most_popular_genre_count:
+ most_popular_genre_count = result[genre]
+ most_popular_genre = genre
+
+ data_request = dict()
+ data_request["request"] = "genre"
+ data_request["genre"] = most_popular_genre
+ data = steamspypi.download(data_request)
+ sorted_data = sorted(
+ data.values(), key=lambda x: x.get("average_forever", 0), reverse=True
+ )
+ owned_games = [game["appid"] for game in users_games["games"]] # type: ignore
+ remaining_games = [
+ game for game in sorted_data if game["appid"] not in owned_games
+ ]
+ top_5_popular_not_owned = [game["name"] for game in remaining_games[:5]]
+ return str(top_5_popular_not_owned)
+
+ def run(self, mode: str, game: str) -> str:
+ if mode == "get_games_details":
+ return self.details_of_games(game)
+ elif mode == "get_recommended_games":
+ return self.recommended_games(game)
+ else:
+ raise ValueError(f"Invalid mode {mode} for Steam API.")
diff --git a/libs/community/langchain_community/utilities/tavily_search.py b/libs/community/langchain_community/utilities/tavily_search.py
new file mode 100644
index 00000000000..54cd0810cc2
--- /dev/null
+++ b/libs/community/langchain_community/utilities/tavily_search.py
@@ -0,0 +1,183 @@
+"""Util that calls Tavily Search API.
+
+In order to set this up, follow instructions at:
+"""
+import json
+from typing import Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+TAVILY_API_URL = "https://api.tavily.com"
+
+
+class TavilySearchAPIWrapper(BaseModel):
+ """Wrapper for Tavily Search API."""
+
+ tavily_api_key: str
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and endpoint exists in environment."""
+ tavily_api_key = get_from_dict_or_env(
+ values, "tavily_api_key", "TAVILY_API_KEY"
+ )
+ values["tavily_api_key"] = tavily_api_key
+
+ return values
+
+ def raw_results(
+ self,
+ query: str,
+ max_results: Optional[int] = 5,
+ search_depth: Optional[str] = "advanced",
+ include_domains: Optional[List[str]] = [],
+ exclude_domains: Optional[List[str]] = [],
+ include_answer: Optional[bool] = False,
+ include_raw_content: Optional[bool] = False,
+ include_images: Optional[bool] = False,
+ ) -> Dict:
+ params = {
+ "api_key": self.tavily_api_key,
+ "query": query,
+ "max_results": max_results,
+ "search_depth": search_depth,
+ "include_domains": include_domains,
+ "exclude_domains": exclude_domains,
+ "include_answer": include_answer,
+ "include_raw_content": include_raw_content,
+ "include_images": include_images,
+ }
+        response = requests.post(  # type: ignore
+            f"{TAVILY_API_URL}/search",
+ json=params,
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def results(
+ self,
+ query: str,
+ max_results: Optional[int] = 5,
+ search_depth: Optional[str] = "advanced",
+ include_domains: Optional[List[str]] = [],
+ exclude_domains: Optional[List[str]] = [],
+ include_answer: Optional[bool] = False,
+ include_raw_content: Optional[bool] = False,
+ include_images: Optional[bool] = False,
+ ) -> List[Dict]:
+ """Run query through Tavily Search and return metadata.
+
+ Args:
+ query: The query to search for.
+ max_results: The maximum number of results to return.
+ search_depth: The depth of the search. Can be "basic" or "advanced".
+ include_domains: A list of domains to include in the search.
+ exclude_domains: A list of domains to exclude from the search.
+ include_answer: Whether to include the answer in the results.
+ include_raw_content: Whether to include the raw content in the results.
+ include_images: Whether to include images in the results.
+ Returns:
+ query: The query that was searched for.
+ follow_up_questions: A list of follow up questions.
+ response_time: The response time of the query.
+ answer: The answer to the query.
+ images: A list of images.
+ results: A list of dictionaries containing the results:
+ title: The title of the result.
+ url: The url of the result.
+ content: The content of the result.
+ score: The score of the result.
+ raw_content: The raw content of the result.
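+
+        Example:
+            .. code-block:: python
+
+                # Requires the TAVILY_API_KEY environment variable; the query
+                # text below is illustrative.
+                wrapper = TavilySearchAPIWrapper()
+                wrapper.results("latest LangChain release", max_results=3)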
+ """ # noqa: E501
+ raw_search_results = self.raw_results(
+ query,
+ max_results=max_results,
+ search_depth=search_depth,
+ include_domains=include_domains,
+ exclude_domains=exclude_domains,
+ include_answer=include_answer,
+ include_raw_content=include_raw_content,
+ include_images=include_images,
+ )
+ return self.clean_results(raw_search_results["results"])
+
+ async def raw_results_async(
+ self,
+ query: str,
+ max_results: Optional[int] = 5,
+ search_depth: Optional[str] = "advanced",
+ include_domains: Optional[List[str]] = [],
+ exclude_domains: Optional[List[str]] = [],
+ include_answer: Optional[bool] = False,
+ include_raw_content: Optional[bool] = False,
+ include_images: Optional[bool] = False,
+ ) -> Dict:
+ """Get results from the Tavily Search API asynchronously."""
+
+ # Function to perform the API call
+ async def fetch() -> str:
+ params = {
+ "api_key": self.tavily_api_key,
+ "query": query,
+ "max_results": max_results,
+ "search_depth": search_depth,
+ "include_domains": include_domains,
+ "exclude_domains": exclude_domains,
+ "include_answer": include_answer,
+ "include_raw_content": include_raw_content,
+ "include_images": include_images,
+ }
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
+ if res.status == 200:
+ data = await res.text()
+ return data
+ else:
+ raise Exception(f"Error {res.status}: {res.reason}")
+
+ results_json_str = await fetch()
+ return json.loads(results_json_str)
+
+ async def results_async(
+ self,
+ query: str,
+ max_results: Optional[int] = 5,
+ search_depth: Optional[str] = "advanced",
+ include_domains: Optional[List[str]] = [],
+ exclude_domains: Optional[List[str]] = [],
+ include_answer: Optional[bool] = False,
+ include_raw_content: Optional[bool] = False,
+ include_images: Optional[bool] = False,
+ ) -> List[Dict]:
+ results_json = await self.raw_results_async(
+ query=query,
+ max_results=max_results,
+ search_depth=search_depth,
+ include_domains=include_domains,
+ exclude_domains=exclude_domains,
+ include_answer=include_answer,
+ include_raw_content=include_raw_content,
+ include_images=include_images,
+ )
+ return self.clean_results(results_json["results"])
+
+ def clean_results(self, results: List[Dict]) -> List[Dict]:
+ """Clean results from Tavily Search API."""
+ clean_results = []
+ for result in results:
+ clean_results.append(
+ {
+ "url": result["url"],
+ "content": result["content"],
+ }
+ )
+ return clean_results
diff --git a/libs/community/langchain_community/utilities/tensorflow_datasets.py b/libs/community/langchain_community/utilities/tensorflow_datasets.py
new file mode 100644
index 00000000000..8f96c9c0c29
--- /dev/null
+++ b/libs/community/langchain_community/utilities/tensorflow_datasets.py
@@ -0,0 +1,110 @@
+import logging
+from typing import Any, Callable, Dict, Iterator, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class TensorflowDatasets(BaseModel):
+ """Access to the TensorFlow Datasets.
+
+    The current implementation can work only with datasets that fit in memory.
+
+ `TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow
+ or other Python ML frameworks, such as Jax. All datasets are exposed
+ as `tf.data.Datasets`.
+ To get started see the Guide: https://www.tensorflow.org/datasets/overview and
+ the list of datasets: https://www.tensorflow.org/datasets/catalog/
+ overview#all_datasets
+
+ You have to provide the sample_to_document_function: a function that
+    converts a sample from the dataset-specific format to a Document.
+
+ Attributes:
+ dataset_name: the name of the dataset to load
+ split_name: the name of the split to load. Defaults to "train".
+ load_max_docs: a limit to the number of loaded documents. Defaults to 100.
+ sample_to_document_function: a function that converts a dataset sample
+ to a Document
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities import TensorflowDatasets
+
+ def mlqaen_example_to_document(example: dict) -> Document:
+ return Document(
+ page_content=decode_to_str(example["context"]),
+ metadata={
+ "id": decode_to_str(example["id"]),
+ "title": decode_to_str(example["title"]),
+ "question": decode_to_str(example["question"]),
+ "answer": decode_to_str(example["answers"]["text"][0]),
+ },
+ )
+
+ tsds_client = TensorflowDatasets(
+ dataset_name="mlqa/en",
+ split_name="train",
+ load_max_docs=MAX_DOCS,
+ sample_to_document_function=mlqaen_example_to_document,
+ )
+
+ """
+
+ dataset_name: str = ""
+ split_name: str = "train"
+ load_max_docs: int = 100
+ sample_to_document_function: Optional[Callable[[Dict], Document]] = None
+ dataset: Any #: :meta private:
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ import tensorflow # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import tensorflow python package. "
+ "Please install it with `pip install tensorflow`."
+ )
+ try:
+ import tensorflow_datasets
+ except ImportError:
+ raise ImportError(
+ "Could not import tensorflow_datasets python package. "
+ "Please install it with `pip install tensorflow-datasets`."
+ )
+ if values["sample_to_document_function"] is None:
+ raise ValueError(
+ "sample_to_document_function is None. "
+ "Please provide a function that converts a dataset sample to"
+ " a Document."
+ )
+ values["dataset"] = tensorflow_datasets.load(
+ values["dataset_name"], split=values["split_name"]
+ )
+
+ return values
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Download a selected dataset lazily.
+
+ Returns: an iterator of Documents.
+
+ """
+ return (
+ self.sample_to_document_function(s)
+ for s in self.dataset.take(self.load_max_docs)
+ if self.sample_to_document_function is not None
+ )
+
+ def load(self) -> List[Document]:
+ """Download a selected dataset.
+
+ Returns: a list of Documents.
+
+ """
+ return list(self.lazy_load())
diff --git a/libs/community/langchain_community/utilities/twilio.py b/libs/community/langchain_community/utilities/twilio.py
new file mode 100644
index 00000000000..a3ff0b23696
--- /dev/null
+++ b/libs/community/langchain_community/utilities/twilio.py
@@ -0,0 +1,82 @@
+"""Util that calls Twilio."""
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class TwilioAPIWrapper(BaseModel):
+ """Messaging Client using Twilio.
+
+ To use, you should have the ``twilio`` python package installed,
+ and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
+ ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
+ named parameters to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.utilities.twilio import TwilioAPIWrapper
+ twilio = TwilioAPIWrapper(
+ account_sid="ACxxx",
+ auth_token="xxx",
+ from_number="+10123456789"
+ )
+ twilio.run('test', '+12484345508')
+ """
+
+ client: Any #: :meta private:
+ account_sid: Optional[str] = None
+ """Twilio account string identifier."""
+ auth_token: Optional[str] = None
+ """Twilio auth token."""
+ from_number: Optional[str] = None
+ """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
+ format, an
+ [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
+ or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+ that is enabled for the type of message you want to send. Phone numbers or
+ [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
+ Twilio also work here. You cannot, for example, spoof messages from a private
+ cell phone number. If you are using `messaging_service_sid`, this parameter
+ must be empty.
+ """ # noqa: E501
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = False
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ try:
+ from twilio.rest import Client
+ except ImportError:
+ raise ImportError(
+ "Could not import twilio python package. "
+ "Please install it with `pip install twilio`."
+ )
+ account_sid = get_from_dict_or_env(values, "account_sid", "TWILIO_ACCOUNT_SID")
+ auth_token = get_from_dict_or_env(values, "auth_token", "TWILIO_AUTH_TOKEN")
+ values["from_number"] = get_from_dict_or_env(
+ values, "from_number", "TWILIO_FROM_NUMBER"
+ )
+ values["client"] = Client(account_sid, auth_token)
+ return values
+
+ def run(self, body: str, to: str) -> str:
+ """Run body through Twilio and respond with message sid.
+
+ Args:
+ body: The text of the message you want to send. Can be up to 1,600
+ characters in length.
+ to: The destination phone number in
+ [E.164](https://www.twilio.com/docs/glossary/what-e164) format for
+ SMS/MMS or
+ [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+ for other 3rd-party channels.
+ """ # noqa: E501
+ message = self.client.messages.create(to, from_=self.from_number, body=body)
+ return message.sid
diff --git a/libs/community/langchain_community/utilities/vertexai.py b/libs/community/langchain_community/utilities/vertexai.py
new file mode 100644
index 00000000000..f85a009df87
--- /dev/null
+++ b/libs/community/langchain_community/utilities/vertexai.py
@@ -0,0 +1,107 @@
+"""Utilities to init Vertex AI."""
+from importlib import metadata
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+if TYPE_CHECKING:
+ from google.api_core.gapic_v1.client_info import ClientInfo
+ from google.auth.credentials import Credentials
+
+
+def create_retry_decorator(
+ llm: BaseLLM,
+ *,
+ max_retries: int = 1,
+ run_manager: Optional[
+ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+ ] = None,
+) -> Callable[[Any], Any]:
+ """Creates a retry decorator for Vertex / Palm LLMs."""
+ import google.api_core
+
+ errors = [
+ google.api_core.exceptions.ResourceExhausted,
+ google.api_core.exceptions.ServiceUnavailable,
+ google.api_core.exceptions.Aborted,
+ google.api_core.exceptions.DeadlineExceeded,
+ google.api_core.exceptions.GoogleAPIError,
+ ]
+ decorator = create_base_retry_decorator(
+ error_types=errors, max_retries=max_retries, run_manager=run_manager
+ )
+ return decorator
+
+
+def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
+ """Raise ImportError related to Vertex SDK being not available.
+
+ Args:
+ minimum_expected_version: The lowest expected version of the SDK.
+ Raises:
+ ImportError: an ImportError that mentions a required version of the SDK.
+ """
+ raise ImportError(
+        "Please install or upgrade the google-cloud-aiplatform library: "
+ f"pip install google-cloud-aiplatform>={minimum_expected_version}"
+ )
+
+
+def init_vertexai(
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional["Credentials"] = None,
+) -> None:
+ """Init vertexai.
+
+ Args:
+ project: The default GCP project to use when making Vertex API calls.
+ location: The default location to use when making API calls.
+ credentials: The default custom
+ credentials to use when making API calls. If not provided credentials
+ will be ascertained from the environment.
+
+ Raises:
+ ImportError: If importing vertexai SDK did not succeed.
+ """
+ try:
+ import vertexai
+ except ImportError:
+ raise_vertex_import_error()
+
+ vertexai.init(
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+
+def get_client_info(module: Optional[str] = None) -> "ClientInfo":
+ r"""Returns a custom user agent header.
+
+ Args:
+ module (Optional[str]):
+ Optional. The module for a custom user agent header.
+ Returns:
+ google.api_core.gapic_v1.client_info.ClientInfo
+ """
+ try:
+ from google.api_core.gapic_v1.client_info import ClientInfo
+ except ImportError as exc:
+ raise ImportError(
+            "Could not import ClientInfo. Please install it with "
+ "pip install google-api-core"
+ ) from exc
+
+ langchain_version = metadata.version("langchain")
+ client_library_version = (
+ f"{langchain_version}-{module}" if module else langchain_version
+ )
+ return ClientInfo(
+ client_library_version=client_library_version,
+ user_agent=f"langchain/{client_library_version}",
+ )
diff --git a/libs/community/langchain_community/utilities/wikipedia.py b/libs/community/langchain_community/utilities/wikipedia.py
new file mode 100644
index 00000000000..37dc064ffb4
--- /dev/null
+++ b/libs/community/langchain_community/utilities/wikipedia.py
@@ -0,0 +1,116 @@
+"""Util that calls Wikipedia."""
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+logger = logging.getLogger(__name__)
+
+WIKIPEDIA_MAX_QUERY_LENGTH = 300
+
+
+class WikipediaAPIWrapper(BaseModel):
+ """Wrapper around WikipediaAPI.
+
+ To use, you should have the ``wikipedia`` python package installed.
+ This wrapper will use the Wikipedia API to conduct searches and
+ fetch page summaries. By default, it will return the page summaries
+ of the top-k results.
+ It limits the Document content by doc_content_chars_max.
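+
+    Example:
+        .. code-block:: python
+
+            # A minimal sketch; requires the `wikipedia` package and network
+            # access, and the query text is illustrative.
+            from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
+
+            wiki = WikipediaAPIWrapper(top_k_results=2)
+            print(wiki.run("Alan Turing"))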
+ """
+
+ wiki_client: Any #: :meta private:
+ top_k_results: int = 3
+ lang: str = "en"
+ load_all_available_meta: bool = False
+ doc_content_chars_max: int = 4000
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in environment."""
+ try:
+ import wikipedia
+
+ wikipedia.set_lang(values["lang"])
+ values["wiki_client"] = wikipedia
+ except ImportError:
+ raise ImportError(
+ "Could not import wikipedia python package. "
+ "Please install it with `pip install wikipedia`."
+ )
+ return values
+
+ def run(self, query: str) -> str:
+ """Run Wikipedia search and get page summaries."""
+ page_titles = self.wiki_client.search(
+ query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results
+ )
+ summaries = []
+ for page_title in page_titles[: self.top_k_results]:
+ if wiki_page := self._fetch_page(page_title):
+ if summary := self._formatted_page_summary(page_title, wiki_page):
+ summaries.append(summary)
+ if not summaries:
+ return "No good Wikipedia Search Result was found"
+ return "\n\n".join(summaries)[: self.doc_content_chars_max]
+
+ @staticmethod
+ def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
+ return f"Page: {page_title}\nSummary: {wiki_page.summary}"
+
+ def _page_to_document(self, page_title: str, wiki_page: Any) -> Document:
+ main_meta = {
+ "title": page_title,
+ "summary": wiki_page.summary,
+ "source": wiki_page.url,
+ }
+ add_meta = (
+ {
+ "categories": wiki_page.categories,
+ "page_url": wiki_page.url,
+ "image_urls": wiki_page.images,
+ "related_titles": wiki_page.links,
+ "parent_id": wiki_page.parent_id,
+ "references": wiki_page.references,
+ "revision_id": wiki_page.revision_id,
+ "sections": wiki_page.sections,
+ }
+ if self.load_all_available_meta
+ else {}
+ )
+ doc = Document(
+ page_content=wiki_page.content[: self.doc_content_chars_max],
+ metadata={
+ **main_meta,
+ **add_meta,
+ },
+ )
+ return doc
+
+ def _fetch_page(self, page: str) -> Optional[str]:
+ try:
+ return self.wiki_client.page(title=page, auto_suggest=False)
+ except (
+ self.wiki_client.exceptions.PageError,
+ self.wiki_client.exceptions.DisambiguationError,
+ ):
+ return None
+
+ def load(self, query: str) -> List[Document]:
+ """
+ Run Wikipedia search and get the article text plus the meta information.
+ See
+
+ Returns: a list of documents.
+
+ """
+ page_titles = self.wiki_client.search(
+ query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results
+ )
+ docs = []
+ for page_title in page_titles[: self.top_k_results]:
+ if wiki_page := self._fetch_page(page_title):
+ if doc := self._page_to_document(page_title, wiki_page):
+ docs.append(doc)
+ return docs
diff --git a/libs/community/langchain_community/utilities/wolfram_alpha.py b/libs/community/langchain_community/utilities/wolfram_alpha.py
new file mode 100644
index 00000000000..079453a0cb3
--- /dev/null
+++ b/libs/community/langchain_community/utilities/wolfram_alpha.py
@@ -0,0 +1,63 @@
+"""Util that calls WolframAlpha."""
+from typing import Any, Dict, Optional
+
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class WolframAlphaAPIWrapper(BaseModel):
+ """Wrapper for Wolfram Alpha.
+
+ Docs for using:
+
+    1. Go to Wolfram Alpha and sign up for a developer account
+ 2. Create an app and get your APP ID
+ 3. Save your APP ID into WOLFRAM_ALPHA_APPID env variable
+ 4. pip install wolframalpha
+
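+    Example:
+        .. code-block:: python
+
+            # Requires the WOLFRAM_ALPHA_APPID environment variable; the query
+            # below is illustrative.
+            wolfram = WolframAlphaAPIWrapper()
+            wolfram.run("What is 2x+5 = -3x + 7?")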
+ """
+
+ wolfram_client: Any #: :meta private:
+ wolfram_alpha_appid: Optional[str] = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key and python package exists in environment."""
+ wolfram_alpha_appid = get_from_dict_or_env(
+ values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID"
+ )
+ values["wolfram_alpha_appid"] = wolfram_alpha_appid
+
+ try:
+ import wolframalpha
+
+ except ImportError:
+ raise ImportError(
+ "wolframalpha is not installed. "
+ "Please install it with `pip install wolframalpha`"
+ )
+ client = wolframalpha.Client(wolfram_alpha_appid)
+ values["wolfram_client"] = client
+
+ return values
+
+ def run(self, query: str) -> str:
+ """Run query through WolframAlpha and parse result."""
+ res = self.wolfram_client.query(query)
+
+ try:
+ assumption = next(res.pods).text
+ answer = next(res.results).text
+ except StopIteration:
+ return "Wolfram Alpha wasn't able to answer it"
+
+ if answer is None or answer == "":
+ # We don't want to return the assumption alone if answer is empty
+ return "No good Wolfram Alpha Result was found"
+ else:
+ return f"Assumption: {assumption} \nAnswer: {answer}"
diff --git a/libs/community/langchain_community/utilities/zapier.py b/libs/community/langchain_community/utilities/zapier.py
new file mode 100644
index 00000000000..adf0a68f46b
--- /dev/null
+++ b/libs/community/langchain_community/utilities/zapier.py
@@ -0,0 +1,297 @@
+"""Util that can interact with Zapier NLA.
+
+Full docs here: https://nla.zapier.com/start/
+
+Note: this wrapper currently only implements the `api_key` auth method for testing
+and server-side production use cases (using the developer's connected accounts on
+Zapier.com)
+
+For use-cases where LangChain + Zapier NLA is powering a user-facing application, and
+LangChain needs access to the end-user's connected accounts on Zapier.com, you'll need
+to use oauth. Review the full docs above and reach out to nla@zapier.com for
+developer support.
+"""
+import json
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from requests import Request, Session
+
+
+class ZapierNLAWrapper(BaseModel):
+ """Wrapper for Zapier NLA.
+
+ Full docs here: https://nla.zapier.com/start/
+
+ This wrapper supports both API Key and OAuth Credential auth methods. API Key
+ is the fastest way to get started using this wrapper.
+
+ Call this wrapper with either `zapier_nla_api_key` or
+ `zapier_nla_oauth_access_token` arguments, or set the `ZAPIER_NLA_API_KEY`
+ environment variable. If both arguments are set, the Access Token will take
+ precedence.
+
+ For use-cases where LangChain + Zapier NLA is powering a user-facing application,
+ and LangChain needs access to the end-user's connected accounts on Zapier.com,
+ you'll need to use OAuth. Review the full docs above to learn how to create
+ your own provider and generate credentials.
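+
+    Example:
+        .. code-block:: python
+
+            # Assumes ZAPIER_NLA_API_KEY is set in the environment; the actions
+            # returned depend on what you have exposed in Zapier NLA.
+            zapier = ZapierNLAWrapper()
+            actions = zapier.list()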
+ """
+
+ zapier_nla_api_key: str
+ zapier_nla_oauth_access_token: str
+ zapier_nla_api_base: str = "https://nla.zapier.com/api/v1/"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _format_headers(self) -> Dict[str, str]:
+ """Format headers for requests."""
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ }
+
+ if self.zapier_nla_oauth_access_token:
+ headers.update(
+ {"Authorization": f"Bearer {self.zapier_nla_oauth_access_token}"}
+ )
+ else:
+ headers.update({"X-API-Key": self.zapier_nla_api_key})
+
+ return headers
+
+ def _get_session(self) -> Session:
+ session = requests.Session()
+ session.headers.update(self._format_headers())
+ return session
+
+ async def _arequest(self, method: str, url: str, **kwargs: Any) -> Dict[str, Any]:
+ """Make an async request."""
+ async with aiohttp.ClientSession(headers=self._format_headers()) as session:
+ async with session.request(method, url, **kwargs) as response:
+ response.raise_for_status()
+ return await response.json()
+
+ def _create_action_payload( # type: ignore[no-untyped-def]
+ self, instructions: str, params: Optional[Dict] = None, preview_only=False
+ ) -> Dict:
+ """Create a payload for an action."""
+ data = params if params else {}
+ data.update(
+ {
+ "instructions": instructions,
+ }
+ )
+ if preview_only:
+ data.update({"preview_only": True})
+ return data
+
+ def _create_action_url(self, action_id: str) -> str:
+ """Create a url for an action."""
+ return self.zapier_nla_api_base + f"exposed/{action_id}/execute/"
+
+ def _create_action_request( # type: ignore[no-untyped-def]
+ self,
+ action_id: str,
+ instructions: str,
+ params: Optional[Dict] = None,
+ preview_only=False,
+ ) -> Request:
+ data = self._create_action_payload(instructions, params, preview_only)
+ return Request(
+ "POST",
+ self._create_action_url(action_id),
+ json=data,
+ )
+
+ @root_validator(pre=True)
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that api key exists in environment."""
+
+ zapier_nla_api_key_default = None
+
+        # If an oauth_access_token is passed in the values,
+        # we don't need an nla_api_key; it can be blank.
+ if "zapier_nla_oauth_access_token" in values:
+ zapier_nla_api_key_default = ""
+ else:
+ values["zapier_nla_oauth_access_token"] = ""
+
+ # we require at least one API Key
+ zapier_nla_api_key = get_from_dict_or_env(
+ values,
+ "zapier_nla_api_key",
+ "ZAPIER_NLA_API_KEY",
+ zapier_nla_api_key_default,
+ )
+
+ values["zapier_nla_api_key"] = zapier_nla_api_key
+
+ return values
+
+ async def alist(self) -> List[Dict]:
+ """Returns a list of all exposed (enabled) actions associated with
+ current user (associated with the set api_key). Change your exposed
+ actions here: https://nla.zapier.com/demo/start/
+
+ The return list can be empty if no actions exposed. Else will contain
+ a list of action objects:
+
+ [{
+ "id": str,
+ "description": str,
+ "params": Dict[str, str]
+ }]
+
+ `params` will always contain an `instructions` key, the only required
+ param. All others are optional and, if provided, will override any AI guesses
+ (see "understanding the AI guessing flow" here:
+ https://nla.zapier.com/api/v1/docs)
+ """
+ response = await self._arequest("GET", self.zapier_nla_api_base + "exposed/")
+ return response["results"]
+
+ def list(self) -> List[Dict]:
+ """Returns a list of all exposed (enabled) actions associated with
+ current user (associated with the set api_key). Change your exposed
+ actions here: https://nla.zapier.com/demo/start/
+
+ The return list can be empty if no actions exposed. Else will contain
+ a list of action objects:
+
+ [{
+ "id": str,
+ "description": str,
+ "params": Dict[str, str]
+ }]
+
+ `params` will always contain an `instructions` key, the only required
+ param. All others are optional and, if provided, will override any AI guesses
+ (see "understanding the AI guessing flow" here:
+ https://nla.zapier.com/docs/using-the-api#ai-guessing)
+ """
+ session = self._get_session()
+ try:
+ response = session.get(self.zapier_nla_api_base + "exposed/")
+ response.raise_for_status()
+ except requests.HTTPError as http_err:
+ if response.status_code == 401:
+ if self.zapier_nla_oauth_access_token:
+ raise requests.HTTPError(
+ f"An unauthorized response occurred. Check that your "
+ f"access token is correct and doesn't need to be "
+ f"refreshed. Err: {http_err}",
+ response=response,
+ )
+ raise requests.HTTPError(
+ f"An unauthorized response occurred. Check that your api "
+ f"key is correct. Err: {http_err}",
+ response=response,
+ )
+ raise http_err
+ return response.json()["results"]
+
+ def run(
+ self, action_id: str, instructions: str, params: Optional[Dict] = None
+ ) -> Dict:
+ """Executes an action that is identified by action_id, must be exposed
+ (enabled) by the current user (associated with the set api_key). Change
+ your exposed actions here: https://nla.zapier.com/demo/start/
+
+ The return JSON is guaranteed to be less than ~500 words (350
+ tokens) making it safe to inject into the prompt of another LLM
+ call.
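+
+ Example (illustrative only; ``zapier`` is a ZapierNLAWrapper instance and
+ the action id is a placeholder obtained from `list()`):
+
+ .. code-block:: python
+
+ result = zapier.run(
+ "<exposed-action-id>",
+ "Send a short thank-you email to jane@example.com",
+ )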
+ """
+ session = self._get_session()
+ request = self._create_action_request(action_id, instructions, params)
+ response = session.send(session.prepare_request(request))
+ response.raise_for_status()
+ return response.json()["result"]
+
+ async def arun(
+ self, action_id: str, instructions: str, params: Optional[Dict] = None
+ ) -> Dict:
+ """Executes an action that is identified by action_id, must be exposed
+ (enabled) by the current user (associated with the set api_key). Change
+ your exposed actions here: https://nla.zapier.com/demo/start/
+
+ The return JSON is guaranteed to be less than ~500 words (350
+ tokens) making it safe to inject into the prompt of another LLM
+ call.
+ """
+ response = await self._arequest(
+ "POST",
+ self._create_action_url(action_id),
+ json=self._create_action_payload(instructions, params),
+ )
+ return response["result"]
+
+ def preview(
+ self, action_id: str, instructions: str, params: Optional[Dict] = None
+ ) -> Dict:
+ """Same as run, but instead of actually executing the action, will
+ instead return a preview of params that have been guessed by the AI in
+ case you need to explicitly review before executing."""
+ session = self._get_session()
+ params = params if params else {}
+ params.update({"preview_only": True})
+ request = self._create_action_request(action_id, instructions, params, True)
+ response = session.send(session.prepare_request(request))
+ response.raise_for_status()
+ return response.json()["input_params"]
+
+ async def apreview(
+ self, action_id: str, instructions: str, params: Optional[Dict] = None
+ ) -> Dict:
+ """Same as run, but instead of actually executing the action, will
+ instead return a preview of params that have been guessed by the AI in
+ case you need to explicitly review before executing."""
+ response = await self._arequest(
+ "POST",
+ self._create_action_url(action_id),
+ json=self._create_action_payload(instructions, params, preview_only=True),
+ )
+ return response["result"]
+
+ def run_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
+ """Same as run, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ data = self.run(*args, **kwargs)
+ return json.dumps(data)
+
+ async def arun_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
+ """Same as run, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ data = await self.arun(*args, **kwargs)
+ return json.dumps(data)
+
+ def preview_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
+ """Same as preview, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ data = self.preview(*args, **kwargs)
+ return json.dumps(data)
+
+ async def apreview_as_str( # type: ignore[no-untyped-def]
+ self, *args, **kwargs
+ ) -> str:
+ """Same as preview, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ data = await self.apreview(*args, **kwargs)
+ return json.dumps(data)
+
+ def list_as_str(self) -> str: # type: ignore[no-untyped-def]
+ """Same as list, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ actions = self.list()
+ return json.dumps(actions)
+
+ async def alist_as_str(self) -> str: # type: ignore[no-untyped-def]
+ """Same as list, but returns a stringified version of the JSON for
+ inserting back into an LLM."""
+ actions = await self.alist()
+ return json.dumps(actions)
diff --git a/libs/langchain/tests/unit_tests/chat_loaders/__init__.py b/libs/community/langchain_community/utils/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/chat_loaders/__init__.py
rename to libs/community/langchain_community/utils/__init__.py
diff --git a/libs/community/langchain_community/utils/math.py b/libs/community/langchain_community/utils/math.py
new file mode 100644
index 00000000000..99d47368197
--- /dev/null
+++ b/libs/community/langchain_community/utils/math.py
@@ -0,0 +1,75 @@
+"""Math utils."""
+import logging
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
+
+
+def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
+ """Row-wise cosine similarity between two equal-width matrices."""
+ if len(X) == 0 or len(Y) == 0:
+ return np.array([])
+
+ X = np.array(X)
+ Y = np.array(Y)
+ if X.shape[1] != Y.shape[1]:
+ raise ValueError(
+ f"Number of columns in X and Y must be the same. X has shape {X.shape} "
+ f"and Y has shape {Y.shape}."
+ )
+ try:
+ import simsimd as simd
+
+ X = np.array(X, dtype=np.float32)
+ Y = np.array(Y, dtype=np.float32)
+ Z = 1 - simd.cdist(X, Y, metric="cosine")
+ if isinstance(Z, float):
+ return np.array([Z])
+ return Z
+ except ImportError:
+ logger.info(
+ "Unable to import simsimd, defaulting to NumPy implementation. If you want "
+ "to use simsimd please install with `pip install simsimd`."
+ )
+ X_norm = np.linalg.norm(X, axis=1)
+ Y_norm = np.linalg.norm(Y, axis=1)
+ # Ignore divide-by-zero and invalid-value runtime warnings; those cases are handled below.
+ with np.errstate(divide="ignore", invalid="ignore"):
+ similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
+ similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
+ return similarity
+
+
+def cosine_similarity_top_k(
+ X: Matrix,
+ Y: Matrix,
+ top_k: Optional[int] = 5,
+ score_threshold: Optional[float] = None,
+) -> Tuple[List[Tuple[int, int]], List[float]]:
+ """Row-wise cosine similarity with optional top-k and score threshold filtering.
+
+ Args:
+ X: Matrix.
+ Y: Matrix, same width as X.
+ top_k: Max number of results to return.
+ score_threshold: Minimum cosine similarity of results.
+
+ Returns:
+ Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
+ second contains corresponding cosine similarities.
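+
+ Example (tiny illustrative matrices):
+
+ .. code-block:: python
+
+ idxs, scores = cosine_similarity_top_k(
+ [[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], top_k=1
+ )
+ # -> ([(0, 0)], [1.0])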
+ """
+ if len(X) == 0 or len(Y) == 0:
+ return [], []
+ score_array = cosine_similarity(X, Y)
+ score_threshold = score_threshold or -1.0
+ score_array[score_array < score_threshold] = 0
+ top_k = min(top_k or len(score_array), np.count_nonzero(score_array))
+ top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:]
+ top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
+ ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
+ scores = score_array.ravel()[top_k_idxs].tolist()
+ return list(zip(*ret_idxs)), scores # type: ignore
diff --git a/libs/community/langchain_community/utils/openai.py b/libs/community/langchain_community/utils/openai.py
new file mode 100644
index 00000000000..d404d82e354
--- /dev/null
+++ b/libs/community/langchain_community/utils/openai.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from importlib.metadata import version
+
+from packaging.version import parse
+
+
+def is_openai_v1() -> bool:
+ _version = parse(version("openai"))
+ return _version.major >= 1
diff --git a/libs/community/langchain_community/utils/openai_functions.py b/libs/community/langchain_community/utils/openai_functions.py
new file mode 100644
index 00000000000..308c14876b2
--- /dev/null
+++ b/libs/community/langchain_community/utils/openai_functions.py
@@ -0,0 +1,51 @@
+from typing import Literal, Optional, Type, TypedDict
+
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.utils.json_schema import dereference_refs
+
+
+class FunctionDescription(TypedDict):
+ """Representation of a callable function to the OpenAI API."""
+
+ name: str
+ """The name of the function."""
+ description: str
+ """A description of the function."""
+ parameters: dict
+ """The parameters of the function."""
+
+
+class ToolDescription(TypedDict):
+ """Representation of a callable function to the OpenAI API."""
+
+ type: Literal["function"]
+ function: FunctionDescription
+
+
+def convert_pydantic_to_openai_function(
+ model: Type[BaseModel],
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+) -> FunctionDescription:
+ """Converts a Pydantic model to a function description for the OpenAI API."""
+ schema = dereference_refs(model.schema())
+ schema.pop("definitions", None)
+ return {
+ "name": name or schema["title"],
+ "description": description or schema["description"],
+ "parameters": schema,
+ }
+
+
+def convert_pydantic_to_openai_tool(
+ model: Type[BaseModel],
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+) -> ToolDescription:
+ """Converts a Pydantic model to a function description for the OpenAI API."""
+ function = convert_pydantic_to_openai_function(
+ model, name=name, description=description
+ )
+ return {"type": "function", "function": function}
diff --git a/libs/community/langchain_community/vectorstores/__init__.py b/libs/community/langchain_community/vectorstores/__init__.py
new file mode 100644
index 00000000000..df64f984cc1
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/__init__.py
@@ -0,0 +1,657 @@
+"""**Vector store** stores embedded data and performs vector search.
+
+One of the most common ways to store and search over unstructured data is to
+embed it and store the resulting embedding vectors, and then query the store
+and retrieve the data that are 'most similar' to the embedded query.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ VectorStore --> # Examples: Annoy, FAISS, Milvus
+
+ BaseRetriever --> VectorStoreRetriever --> Retriever # Example: VespaRetriever
+
+**Main helpers:**
+
+.. code-block::
+
+ Embeddings, Document
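+
+**Usage:**
+
+Concrete vector stores are imported lazily via module ``__getattr__``, so the
+underlying integration module is only imported when the class is first accessed:
+
+.. code-block:: python
+
+ from langchain_community.vectorstores import FAISS # resolved lazily via __getattr__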
+""" # noqa: E501
+
+from typing import Any
+
+from langchain_core.vectorstores import VectorStore
+
+
+def _import_alibaba_cloud_open_search() -> Any:
+ from langchain_community.vectorstores.alibabacloud_opensearch import (
+ AlibabaCloudOpenSearch,
+ )
+
+ return AlibabaCloudOpenSearch
+
+
+def _import_alibaba_cloud_open_search_settings() -> Any:
+ from langchain_community.vectorstores.alibabacloud_opensearch import (
+ AlibabaCloudOpenSearchSettings,
+ )
+
+ return AlibabaCloudOpenSearchSettings
+
+
+def _import_azure_cosmos_db() -> Any:
+ from langchain_community.vectorstores.azure_cosmos_db import (
+ AzureCosmosDBVectorSearch,
+ )
+
+ return AzureCosmosDBVectorSearch
+
+
+def _import_elastic_knn_search() -> Any:
+ from langchain_community.vectorstores.elastic_vector_search import ElasticKnnSearch
+
+ return ElasticKnnSearch
+
+
+def _import_elastic_vector_search() -> Any:
+ from langchain_community.vectorstores.elastic_vector_search import (
+ ElasticVectorSearch,
+ )
+
+ return ElasticVectorSearch
+
+
+def _import_analyticdb() -> Any:
+ from langchain_community.vectorstores.analyticdb import AnalyticDB
+
+ return AnalyticDB
+
+
+def _import_annoy() -> Any:
+ from langchain_community.vectorstores.annoy import Annoy
+
+ return Annoy
+
+
+def _import_atlas() -> Any:
+ from langchain_community.vectorstores.atlas import AtlasDB
+
+ return AtlasDB
+
+
+def _import_awadb() -> Any:
+ from langchain_community.vectorstores.awadb import AwaDB
+
+ return AwaDB
+
+
+def _import_azuresearch() -> Any:
+ from langchain_community.vectorstores.azuresearch import AzureSearch
+
+ return AzureSearch
+
+
+def _import_bageldb() -> Any:
+ from langchain_community.vectorstores.bageldb import Bagel
+
+ return Bagel
+
+
+def _import_baiducloud_vector_search() -> Any:
+ from langchain_community.vectorstores.baiducloud_vector_search import BESVectorStore
+
+ return BESVectorStore
+
+
+def _import_cassandra() -> Any:
+ from langchain_community.vectorstores.cassandra import Cassandra
+
+ return Cassandra
+
+
+def _import_astradb() -> Any:
+ from langchain_community.vectorstores.astradb import AstraDB
+
+ return AstraDB
+
+
+def _import_chroma() -> Any:
+ from langchain_community.vectorstores.chroma import Chroma
+
+ return Chroma
+
+
+def _import_clarifai() -> Any:
+ from langchain_community.vectorstores.clarifai import Clarifai
+
+ return Clarifai
+
+
+def _import_clickhouse() -> Any:
+ from langchain_community.vectorstores.clickhouse import Clickhouse
+
+ return Clickhouse
+
+
+def _import_clickhouse_settings() -> Any:
+ from langchain_community.vectorstores.clickhouse import ClickhouseSettings
+
+ return ClickhouseSettings
+
+
+def _import_dashvector() -> Any:
+ from langchain_community.vectorstores.dashvector import DashVector
+
+ return DashVector
+
+
+def _import_databricks_vector_search() -> Any:
+ from langchain_community.vectorstores.databricks_vector_search import (
+ DatabricksVectorSearch,
+ )
+
+ return DatabricksVectorSearch
+
+
+def _import_deeplake() -> Any:
+ from langchain_community.vectorstores.deeplake import DeepLake
+
+ return DeepLake
+
+
+def _import_dingo() -> Any:
+ from langchain_community.vectorstores.dingo import Dingo
+
+ return Dingo
+
+
+def _import_docarray_hnsw() -> Any:
+ from langchain_community.vectorstores.docarray import DocArrayHnswSearch
+
+ return DocArrayHnswSearch
+
+
+def _import_docarray_inmemory() -> Any:
+ from langchain_community.vectorstores.docarray import DocArrayInMemorySearch
+
+ return DocArrayInMemorySearch
+
+
+def _import_elasticsearch() -> Any:
+ from langchain_community.vectorstores.elasticsearch import ElasticsearchStore
+
+ return ElasticsearchStore
+
+
+def _import_epsilla() -> Any:
+ from langchain_community.vectorstores.epsilla import Epsilla
+
+ return Epsilla
+
+
+def _import_faiss() -> Any:
+ from langchain_community.vectorstores.faiss import FAISS
+
+ return FAISS
+
+
+def _import_hologres() -> Any:
+ from langchain_community.vectorstores.hologres import Hologres
+
+ return Hologres
+
+
+def _import_lancedb() -> Any:
+ from langchain_community.vectorstores.lancedb import LanceDB
+
+ return LanceDB
+
+
+def _import_llm_rails() -> Any:
+ from langchain_community.vectorstores.llm_rails import LLMRails
+
+ return LLMRails
+
+
+def _import_marqo() -> Any:
+ from langchain_community.vectorstores.marqo import Marqo
+
+ return Marqo
+
+
+def _import_matching_engine() -> Any:
+ from langchain_community.vectorstores.matching_engine import MatchingEngine
+
+ return MatchingEngine
+
+
+def _import_meilisearch() -> Any:
+ from langchain_community.vectorstores.meilisearch import Meilisearch
+
+ return Meilisearch
+
+
+def _import_milvus() -> Any:
+ from langchain_community.vectorstores.milvus import Milvus
+
+ return Milvus
+
+
+def _import_momento_vector_index() -> Any:
+ from langchain_community.vectorstores.momento_vector_index import MomentoVectorIndex
+
+ return MomentoVectorIndex
+
+
+def _import_mongodb_atlas() -> Any:
+ from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
+
+ return MongoDBAtlasVectorSearch
+
+
+def _import_myscale() -> Any:
+ from langchain_community.vectorstores.myscale import MyScale
+
+ return MyScale
+
+
+def _import_myscale_settings() -> Any:
+ from langchain_community.vectorstores.myscale import MyScaleSettings
+
+ return MyScaleSettings
+
+
+def _import_neo4j_vector() -> Any:
+ from langchain_community.vectorstores.neo4j_vector import Neo4jVector
+
+ return Neo4jVector
+
+
+def _import_opensearch_vector_search() -> Any:
+ from langchain_community.vectorstores.opensearch_vector_search import (
+ OpenSearchVectorSearch,
+ )
+
+ return OpenSearchVectorSearch
+
+
+def _import_pgembedding() -> Any:
+ from langchain_community.vectorstores.pgembedding import PGEmbedding
+
+ return PGEmbedding
+
+
+def _import_pgvector() -> Any:
+ from langchain_community.vectorstores.pgvector import PGVector
+
+ return PGVector
+
+
+def _import_pinecone() -> Any:
+ from langchain_community.vectorstores.pinecone import Pinecone
+
+ return Pinecone
+
+
+def _import_qdrant() -> Any:
+ from langchain_community.vectorstores.qdrant import Qdrant
+
+ return Qdrant
+
+
+def _import_redis() -> Any:
+ from langchain_community.vectorstores.redis import Redis
+
+ return Redis
+
+
+def _import_rocksetdb() -> Any:
+ from langchain_community.vectorstores.rocksetdb import Rockset
+
+ return Rockset
+
+
+def _import_vespa() -> Any:
+ from langchain_community.vectorstores.vespa import VespaStore
+
+ return VespaStore
+
+
+def _import_scann() -> Any:
+ from langchain_community.vectorstores.scann import ScaNN
+
+ return ScaNN
+
+
+def _import_semadb() -> Any:
+ from langchain_community.vectorstores.semadb import SemaDB
+
+ return SemaDB
+
+
+def _import_singlestoredb() -> Any:
+ from langchain_community.vectorstores.singlestoredb import SingleStoreDB
+
+ return SingleStoreDB
+
+
+def _import_sklearn() -> Any:
+ from langchain_community.vectorstores.sklearn import SKLearnVectorStore
+
+ return SKLearnVectorStore
+
+
+def _import_sqlitevss() -> Any:
+ from langchain_community.vectorstores.sqlitevss import SQLiteVSS
+
+ return SQLiteVSS
+
+
+def _import_starrocks() -> Any:
+ from langchain_community.vectorstores.starrocks import StarRocks
+
+ return StarRocks
+
+
+def _import_supabase() -> Any:
+ from langchain_community.vectorstores.supabase import SupabaseVectorStore
+
+ return SupabaseVectorStore
+
+
+def _import_tair() -> Any:
+ from langchain_community.vectorstores.tair import Tair
+
+ return Tair
+
+
+def _import_tencentvectordb() -> Any:
+ from langchain_community.vectorstores.tencentvectordb import TencentVectorDB
+
+ return TencentVectorDB
+
+
+def _import_tiledb() -> Any:
+ from langchain_community.vectorstores.tiledb import TileDB
+
+ return TileDB
+
+
+def _import_tigris() -> Any:
+ from langchain_community.vectorstores.tigris import Tigris
+
+ return Tigris
+
+
+def _import_timescalevector() -> Any:
+ from langchain_community.vectorstores.timescalevector import TimescaleVector
+
+ return TimescaleVector
+
+
+def _import_typesense() -> Any:
+ from langchain_community.vectorstores.typesense import Typesense
+
+ return Typesense
+
+
+def _import_usearch() -> Any:
+ from langchain_community.vectorstores.usearch import USearch
+
+ return USearch
+
+
+def _import_vald() -> Any:
+ from langchain_community.vectorstores.vald import Vald
+
+ return Vald
+
+
+def _import_vearch() -> Any:
+ from langchain_community.vectorstores.vearch import Vearch
+
+ return Vearch
+
+
+def _import_vectara() -> Any:
+ from langchain_community.vectorstores.vectara import Vectara
+
+ return Vectara
+
+
+def _import_weaviate() -> Any:
+ from langchain_community.vectorstores.weaviate import Weaviate
+
+ return Weaviate
+
+
+def _import_yellowbrick() -> Any:
+ from langchain_community.vectorstores.yellowbrick import Yellowbrick
+
+ return Yellowbrick
+
+
+def _import_zep() -> Any:
+ from langchain_community.vectorstores.zep import ZepVectorStore
+
+ return ZepVectorStore
+
+
+def _import_zilliz() -> Any:
+ from langchain_community.vectorstores.zilliz import Zilliz
+
+ return Zilliz
+
+
+def __getattr__(name: str) -> Any:
+ if name == "AnalyticDB":
+ return _import_analyticdb()
+ elif name == "AlibabaCloudOpenSearch":
+ return _import_alibaba_cloud_open_search()
+ elif name == "AlibabaCloudOpenSearchSettings":
+ return _import_alibaba_cloud_open_search_settings()
+ elif name == "AzureCosmosDBVectorSearch":
+ return _import_azure_cosmos_db()
+ elif name == "ElasticKnnSearch":
+ return _import_elastic_knn_search()
+ elif name == "ElasticVectorSearch":
+ return _import_elastic_vector_search()
+ elif name == "Annoy":
+ return _import_annoy()
+ elif name == "AtlasDB":
+ return _import_atlas()
+ elif name == "AwaDB":
+ return _import_awadb()
+ elif name == "AzureSearch":
+ return _import_azuresearch()
+ elif name == "Bagel":
+ return _import_bageldb()
+ elif name == "BESVectorStore":
+ return _import_baiducloud_vector_search()
+ elif name == "Cassandra":
+ return _import_cassandra()
+ elif name == "AstraDB":
+ return _import_astradb()
+ elif name == "Chroma":
+ return _import_chroma()
+ elif name == "Clarifai":
+ return _import_clarifai()
+ elif name == "ClickhouseSettings":
+ return _import_clickhouse_settings()
+ elif name == "Clickhouse":
+ return _import_clickhouse()
+ elif name == "DashVector":
+ return _import_dashvector()
+ elif name == "DatabricksVectorSearch":
+ return _import_databricks_vector_search()
+ elif name == "DeepLake":
+ return _import_deeplake()
+ elif name == "Dingo":
+ return _import_dingo()
+ elif name == "DocArrayInMemorySearch":
+ return _import_docarray_inmemory()
+ elif name == "DocArrayHnswSearch":
+ return _import_docarray_hnsw()
+ elif name == "ElasticsearchStore":
+ return _import_elasticsearch()
+ elif name == "Epsilla":
+ return _import_epsilla()
+ elif name == "FAISS":
+ return _import_faiss()
+ elif name == "Hologres":
+ return _import_hologres()
+ elif name == "LanceDB":
+ return _import_lancedb()
+ elif name == "LLMRails":
+ return _import_llm_rails()
+ elif name == "Marqo":
+ return _import_marqo()
+ elif name == "MatchingEngine":
+ return _import_matching_engine()
+ elif name == "Meilisearch":
+ return _import_meilisearch()
+ elif name == "Milvus":
+ return _import_milvus()
+ elif name == "MomentoVectorIndex":
+ return _import_momento_vector_index()
+ elif name == "MongoDBAtlasVectorSearch":
+ return _import_mongodb_atlas()
+ elif name == "MyScaleSettings":
+ return _import_myscale_settings()
+ elif name == "MyScale":
+ return _import_myscale()
+ elif name == "Neo4jVector":
+ return _import_neo4j_vector()
+ elif name == "OpenSearchVectorSearch":
+ return _import_opensearch_vector_search()
+ elif name == "PGEmbedding":
+ return _import_pgembedding()
+ elif name == "PGVector":
+ return _import_pgvector()
+ elif name == "Pinecone":
+ return _import_pinecone()
+ elif name == "Qdrant":
+ return _import_qdrant()
+ elif name == "Redis":
+ return _import_redis()
+ elif name == "Rockset":
+ return _import_rocksetdb()
+ elif name == "ScaNN":
+ return _import_scann()
+ elif name == "SemaDB":
+ return _import_semadb()
+ elif name == "SingleStoreDB":
+ return _import_singlestoredb()
+ elif name == "SKLearnVectorStore":
+ return _import_sklearn()
+ elif name == "SQLiteVSS":
+ return _import_sqlitevss()
+ elif name == "StarRocks":
+ return _import_starrocks()
+ elif name == "SupabaseVectorStore":
+ return _import_supabase()
+ elif name == "Tair":
+ return _import_tair()
+ elif name == "TencentVectorDB":
+ return _import_tencentvectordb()
+ elif name == "TileDB":
+ return _import_tiledb()
+ elif name == "Tigris":
+ return _import_tigris()
+ elif name == "TimescaleVector":
+ return _import_timescalevector()
+ elif name == "Typesense":
+ return _import_typesense()
+ elif name == "USearch":
+ return _import_usearch()
+ elif name == "Vald":
+ return _import_vald()
+ elif name == "Vearch":
+ return _import_vearch()
+ elif name == "Vectara":
+ return _import_vectara()
+ elif name == "Weaviate":
+ return _import_weaviate()
+ elif name == "Yellowbrick":
+ return _import_yellowbrick()
+ elif name == "ZepVectorStore":
+ return _import_zep()
+ elif name == "Zilliz":
+ return _import_zilliz()
+ elif name == "VespaStore":
+ return _import_vespa()
+ else:
+ raise AttributeError(f"Could not find: {name}")
+
+
+__all__ = [
+ "AlibabaCloudOpenSearch",
+ "AlibabaCloudOpenSearchSettings",
+ "AnalyticDB",
+ "Annoy",
+ "AtlasDB",
+ "AwaDB",
+ "AzureSearch",
+ "Bagel",
+ "Cassandra",
+ "AstraDB",
+ "Chroma",
+ "Clarifai",
+ "Clickhouse",
+ "ClickhouseSettings",
+ "DashVector",
+ "DatabricksVectorSearch",
+ "DeepLake",
+ "Dingo",
+ "DocArrayHnswSearch",
+ "DocArrayInMemorySearch",
+ "ElasticKnnSearch",
+ "ElasticVectorSearch",
+ "ElasticsearchStore",
+ "Epsilla",
+ "FAISS",
+ "Hologres",
+ "LanceDB",
+ "LLMRails",
+ "Marqo",
+ "MatchingEngine",
+ "Meilisearch",
+ "Milvus",
+ "MomentoVectorIndex",
+ "MongoDBAtlasVectorSearch",
+ "MyScale",
+ "MyScaleSettings",
+ "Neo4jVector",
+ "OpenSearchVectorSearch",
+ "PGEmbedding",
+ "PGVector",
+ "Pinecone",
+ "Qdrant",
+ "Redis",
+ "Rockset",
+ "SKLearnVectorStore",
+ "ScaNN",
+ "SemaDB",
+ "SingleStoreDB",
+ "SQLiteVSS",
+ "StarRocks",
+ "SupabaseVectorStore",
+ "Tair",
+ "TileDB",
+ "Tigris",
+ "TimescaleVector",
+ "Typesense",
+ "USearch",
+ "Vald",
+ "Vearch",
+ "Vectara",
+ "VespaStore",
+ "Weaviate",
+ "Yellowbrick",
+ "ZepVectorStore",
+ "Zilliz",
+ "TencentVectorDB",
+ "AzureCosmosDBVectorSearch",
+ "VectorStore",
+]
diff --git a/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py b/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py
new file mode 100644
index 00000000000..c9d5b85a6bd
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py
@@ -0,0 +1,532 @@
+import json
+import logging
+import numbers
+from hashlib import sha1
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger()
+
+
+class AlibabaCloudOpenSearchSettings:
+ """Alibaba Cloud Opensearch` client configuration.
+
+ Attributes:
+ endpoint (str) : The endpoint of the OpenSearch instance; you can find it
+ in the console of Alibaba Cloud OpenSearch.
+ instance_id (str) : The identifier of the OpenSearch instance; you can find
+ it in the console of Alibaba Cloud OpenSearch.
+ username (str) : The username specified when purchasing the instance.
+ password (str) : The password specified when purchasing the instance.
+ After the instance is created, you can modify it in the console.
+ table_name (str): The table name specified during instance configuration.
+ field_name_mapping (Dict) : Using field name mapping between opensearch
+ vector store and opensearch instance configuration table field names:
+ {
+ 'id': 'The id field name map of index document.',
+ 'document': 'The text field name map of index document.',
+ 'embedding': 'In the embedding field of the opensearch instance,
+ the values must be in float type and separated by separator,
+ default is comma.',
+ 'metadata_field_x': 'Metadata field mapping includes the mapped
+ field name and operator in the mapping value, separated by a comma
+ between the mapped field name and the operator.',
+ }
+ protocol (str): Communication Protocol between SDK and Server, default is http.
+ namespace (str) : The instance data will be partitioned based on the
+ "namespace" field. If the namespace is enabled, you need to specify the
+ namespace field name during initialization; otherwise, queries cannot be
+ executed correctly.
+ embedding_field_separator(str): Delimiter specified for writing vector
+ field data, default is comma.
+ output_fields: Specify the field list returned when invoking OpenSearch;
+ by default it is the list of mapped field names from field_name_mapping.
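+
+ Example (a minimal sketch; all values below are placeholders for your own
+ instance configuration):
+
+ .. code-block:: python
+
+ settings = AlibabaCloudOpenSearchSettings(
+ endpoint="ha-cn-xxx.public.ha.aliyuncs.com",
+ instance_id="ha-cn-xxx",
+ username="<user>",
+ password="<password>",
+ table_name="<table>",
+ field_name_mapping={
+ "id": "id",
+ "document": "document",
+ "embedding": "embedding",
+ },
+ )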
+ """
+
+ def __init__(
+ self,
+ endpoint: str,
+ instance_id: str,
+ username: str,
+ password: str,
+ table_name: str,
+ field_name_mapping: Dict[str, str],
+ protocol: str = "http",
+ namespace: str = "",
+ embedding_field_separator: str = ",",
+ output_fields: Optional[List[str]] = None,
+ ) -> None:
+ self.endpoint = endpoint
+ self.instance_id = instance_id
+ self.protocol = protocol
+ self.username = username
+ self.password = password
+ self.namespace = namespace
+ self.table_name = table_name
+ self.opt_table_name = "_".join([self.instance_id, self.table_name])
+ self.field_name_mapping = field_name_mapping
+ self.embedding_field_separator = embedding_field_separator
+ self.output_fields = output_fields
+ if self.output_fields is None:
+ self.output_fields = [
+ field.split(",")[0] for field in self.field_name_mapping.values()
+ ]
+ self.inverse_field_name_mapping: Dict[str, str] = {}
+ for key, value in self.field_name_mapping.items():
+ self.inverse_field_name_mapping[value.split(",")[0]] = key
+
+ def __getitem__(self, item: str) -> Any:
+ return getattr(self, item)
+
+
+def create_metadata(fields: Dict[str, Any]) -> Dict[str, Any]:
+ """Create metadata from fields.
+
+ Args:
+ fields: The fields of the document. The fields must be a dict.
+
+ Returns:
+ metadata: The metadata of the document. The metadata must be a dict.
+ """
+ metadata: Dict[str, Any] = {}
+ for key, value in fields.items():
+ if key == "id" or key == "document" or key == "embedding":
+ continue
+ metadata[key] = value
+ return metadata
+
+
+class AlibabaCloudOpenSearch(VectorStore):
+ """`Alibaba Cloud OpenSearch` vector store."""
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ config: AlibabaCloudOpenSearchSettings,
+ **kwargs: Any,
+ ) -> None:
+ try:
+ from alibabacloud_ha3engine_vector import client, models
+ from alibabacloud_tea_util import models as util_models
+ except ImportError:
+ raise ImportError(
+ "Could not import alibaba cloud opensearch python package. "
+ "Please install it with `pip install alibabacloud-ha3engine-vector`."
+ )
+
+ self.config = config
+ self.embedding = embedding
+
+ self.runtime = util_models.RuntimeOptions(
+ connect_timeout=5000,
+ read_timeout=10000,
+ autoretry=False,
+ ignore_ssl=False,
+ max_idle_conns=50,
+ )
+ self.ha3_engine_client = client.Client(
+ models.Config(
+ endpoint=config.endpoint,
+ instance_id=config.instance_id,
+ protocol=config.protocol,
+ access_user_name=config.username,
+ access_pass_word=config.password,
+ )
+ )
+
+ self.options_headers: Dict[str, str] = {}
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Insert documents into the instance..
+ Args:
+ texts: The text segments to be inserted into the vector storage,
+ should not be empty.
+ metadatas: Metadata information.
+ Returns:
+ id_list: List of document IDs.
+ """
+
+ def _upsert(push_doc_list: List[Dict]) -> List[str]:
+ if push_doc_list is None or len(push_doc_list) == 0:
+ return []
+ try:
+ push_request = models.PushDocumentsRequest(
+ self.options_headers, push_doc_list
+ )
+ push_response = self.ha3_engine_client.push_documents(
+ self.config.opt_table_name, field_name_map["id"], push_request
+ )
+ json_response = json.loads(push_response.body)
+ if json_response["status"] == "OK":
+ return [
+ push_doc["fields"][field_name_map["id"]]
+ for push_doc in push_doc_list
+ ]
+ return []
+ except Exception as e:
+ logger.error(
+ f"add doc to endpoint:{self.config.endpoint} "
+ f"instance_id:{self.config.instance_id} failed.",
+ e,
+ )
+ raise e
+
+ from alibabacloud_ha3engine_vector import models
+
+ id_list = [sha1(t.encode("utf-8")).hexdigest() for t in texts]
+ embeddings = self.embedding.embed_documents(list(texts))
+ metadatas = metadatas or [{} for _ in texts]
+ field_name_map = self.config.field_name_mapping
+ add_doc_list = []
+ text_list = list(texts)
+ for idx, doc_id in enumerate(id_list):
+ embedding = embeddings[idx] if idx < len(embeddings) else None
+ metadata = metadatas[idx] if idx < len(metadatas) else None
+ text = text_list[idx] if idx < len(text_list) else None
+ add_doc: Dict[str, Any] = dict()
+ add_doc_fields: Dict[str, Any] = dict()
+ add_doc_fields.__setitem__(field_name_map["id"], doc_id)
+ add_doc_fields.__setitem__(field_name_map["document"], text)
+ if embedding is not None:
+ add_doc_fields.__setitem__(
+ field_name_map["embedding"],
+ self.config.embedding_field_separator.join(
+ str(unit) for unit in embedding
+ ),
+ )
+ if metadata is not None:
+ for md_key, md_value in metadata.items():
+ add_doc_fields.__setitem__(
+ field_name_map[md_key].split(",")[0], md_value
+ )
+ add_doc.__setitem__("fields", add_doc_fields)
+ add_doc.__setitem__("cmd", "add")
+ add_doc_list.append(add_doc)
+ return _upsert(add_doc_list)
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ search_filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform similarity retrieval based on text.
+ Args:
+ query: Vectorize text for retrieval; should not be empty.
+ k: top n.
+ search_filter: Additional filtering conditions.
+ Returns:
+ document_list: List of documents.
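+
+ Example (sketch; the `search_filter` key must be one of the metadata keys
+ declared in `field_name_mapping`):
+
+ .. code-block:: python
+
+ docs = store.similarity_search(
+ "What is Alibaba Cloud OpenSearch?",
+ k=4,
+ search_filter={"metadata_field_x": "some value"},
+ )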
+ """
+ embedding = self.embedding.embed_query(query)
+ return self.create_results(
+ self.inner_embedding_query(
+ embedding=embedding, search_filter=search_filter, k=k
+ )
+ )
+
+ def similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ search_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Perform similarity retrieval based on text with scores.
+ Args:
+ query: Vectorize text for retrieval; should not be empty.
+ k: top n.
+ search_filter: Additional filtering conditions.
+ Returns:
+ document_list: List of documents.
+ """
+ embedding: List[float] = self.embedding.embed_query(query)
+ return self.create_results_with_score(
+ self.inner_embedding_query(
+ embedding=embedding, search_filter=search_filter, k=k
+ )
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ search_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform retrieval directly using vectors.
+ Args:
+ embedding: vectors.
+ k: top n.
+ search_filter: Additional filtering conditions.
+ Returns:
+ document_list: List of documents.
+ """
+ return self.create_results(
+ self.inner_embedding_query(
+ embedding=embedding, search_filter=search_filter, k=k
+ )
+ )
+
+ def inner_embedding_query(
+ self,
+ embedding: List[float],
+ search_filter: Optional[Dict[str, Any]] = None,
+ k: int = 4,
+ ) -> Dict[str, Any]:
+ def generate_filter_query() -> str:
+ if search_filter is None:
+ return ""
+ filter_clause = " AND ".join(
+ [
+ create_filter(md_key, md_value)
+ for md_key, md_value in search_filter.items()
+ ]
+ )
+ return filter_clause
+
+ def create_filter(md_key: str, md_value: Any) -> str:
+ md_filter_expr = self.config.field_name_mapping[md_key]
+ if md_filter_expr is None:
+ return ""
+ expr = md_filter_expr.split(",")
+ if len(expr) != 2:
+ logger.error(
+ f"filter {md_filter_expr} express is not correct, "
+ f"must contain mapping field and operator."
+ )
+ return ""
+ md_filter_key = expr[0].strip()
+ md_filter_operator = expr[1].strip()
+ if isinstance(md_value, numbers.Number):
+ return f"{md_filter_key} {md_filter_operator} {md_value}"
+ return f'{md_filter_key}{md_filter_operator}"{md_value}"'
+
+ def search_data() -> Dict[str, Any]:
+ request = QueryRequest(
+ table_name=self.config.table_name,
+ namespace=self.config.namespace,
+ vector=embedding,
+ include_vector=True,
+ output_fields=self.config.output_fields,
+ filter=generate_filter_query(),
+ top_k=k,
+ )
+
+ query_result = self.ha3_engine_client.query(request)
+ return json.loads(query_result.body)
+
+ from alibabacloud_ha3engine_vector.models import QueryRequest
+
+ try:
+ json_response = search_data()
+ if (
+ "errorCode" in json_response
+ and "errorMsg" in json_response
+ and len(json_response["errorMsg"]) > 0
+ ):
+ logger.error(
+ f"query {self.config.endpoint} {self.config.instance_id} "
+ f"failed:{json_response['errorMsg']}."
+ )
+ else:
+ return json_response
+ except Exception as e:
+ logger.error(
+ f"query instance endpoint:{self.config.endpoint} "
+ f"instance_id:{self.config.instance_id} failed.",
+ e,
+ )
+ return {}
+
+ def create_results(self, json_result: Dict[str, Any]) -> List[Document]:
+ """Assemble documents."""
+ items = json_result["result"]
+ query_result_list: List[Document] = []
+ for item in items:
+ if (
+ "fields" not in item
+ or self.config.field_name_mapping["document"] not in item["fields"]
+ ):
+ query_result_list.append(Document(page_content=""))
+ else:
+ fields = item["fields"]
+ query_result_list.append(
+ Document(
+ page_content=fields[self.config.field_name_mapping["document"]],
+ metadata=self.create_inverse_metadata(fields),
+ )
+ )
+ return query_result_list
+
+ def create_inverse_metadata(self, fields: Dict[str, Any]) -> Dict[str, Any]:
+ """Create metadata from fields.
+
+ Args:
+ fields: The fields of the document. The fields must be a dict.
+
+ Returns:
+ metadata: The metadata of the document. The metadata must be a dict.
+ """
+ metadata: Dict[str, Any] = {}
+ for key, value in fields.items():
+ if key == "id" or key == "document" or key == "embedding":
+ continue
+ metadata[self.config.inverse_field_name_mapping[key]] = value
+ return metadata
+
+ def create_results_with_score(
+ self, json_result: Dict[str, Any]
+ ) -> List[Tuple[Document, float]]:
+ """Parsing the returned results with scores.
+ Args:
+ json_result: Results from OpenSearch query.
+ Returns:
+ query_result_list: Results with scores.
+ """
+ items = json_result["result"]
+ query_result_list: List[Tuple[Document, float]] = []
+ for item in items:
+ fields = item["fields"]
+ query_result_list.append(
+ (
+ Document(
+ page_content=fields[self.config.field_name_mapping["document"]],
+ metadata=self.create_inverse_metadata(fields),
+ ),
+ float(item["score"]),
+ )
+ )
+ return query_result_list
+
+ def delete_documents_with_texts(self, texts: List[str]) -> bool:
+ """Delete documents based on their page content.
+
+ Args:
+ texts: List of document page content.
+ Returns:
+ Whether the deletion was successful or not.
+ """
+ id_list = [sha1(t.encode("utf-8")).hexdigest() for t in texts]
+ return self.delete_documents_with_document_id(id_list)
+
+ def delete_documents_with_document_id(self, id_list: List[str]) -> bool:
+ """Delete documents based on their IDs.
+
+ Args:
+ id_list: List of document IDs.
+ Returns:
+ Whether the deletion was successful or not.
+ """
+ if id_list is None or len(id_list) == 0:
+ return True
+
+ from alibabacloud_ha3engine_vector import models
+
+ delete_doc_list = []
+ for doc_id in id_list:
+ delete_doc_list.append(
+ {
+ "fields": {self.config.field_name_mapping["id"]: doc_id},
+ "cmd": "delete",
+ }
+ )
+
+ delete_request = models.PushDocumentsRequest(
+ self.options_headers, delete_doc_list
+ )
+ try:
+ delete_response = self.ha3_engine_client.push_documents(
+ self.config.opt_table_name,
+ self.config.field_name_mapping["id"],
+ delete_request,
+ )
+ json_response = json.loads(delete_response.body)
+ return json_response["status"] == "OK"
+ except Exception as e:
+ logger.error(
+ f"delete doc from :{self.config.endpoint} "
+ f"instance_id:{self.config.instance_id} failed.",
+ e,
+ )
+ raise e
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ config: Optional[AlibabaCloudOpenSearchSettings] = None,
+ **kwargs: Any,
+ ) -> "AlibabaCloudOpenSearch":
+ """Create alibaba cloud opensearch vector store instance.
+
+ Args:
+ texts: The text segments to be inserted into the vector storage,
+ should not be empty.
+ embedding: Embedding function.
+ config: Alibaba OpenSearch instance configuration.
+ metadatas: Metadata information.
+ Returns:
+ AlibabaCloudOpenSearch: Alibaba cloud opensearch vector store instance.
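+
+ Example (sketch; `embeddings` is any Embeddings implementation and
+ `settings` is an AlibabaCloudOpenSearchSettings instance):
+
+ .. code-block:: python
+
+ store = AlibabaCloudOpenSearch.from_texts(
+ texts=["foo", "bar"],
+ embedding=embeddings,
+ config=settings,
+ )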
+ """
+ if texts is None or len(texts) == 0:
+ raise Exception("the inserted text segments, should not be empty.")
+
+ if embedding is None:
+ raise Exception("the embeddings should not be empty.")
+
+ if config is None:
+ raise Exception("config should not be none.")
+
+ ctx = cls(embedding, config, **kwargs)
+ ctx.add_texts(texts=texts, metadatas=metadatas)
+ return ctx
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Embeddings,
+ config: Optional[AlibabaCloudOpenSearchSettings] = None,
+ **kwargs: Any,
+ ) -> "AlibabaCloudOpenSearch":
+ """Create alibaba cloud opensearch vector store instance.
+
+ Args:
+ documents: Documents to be inserted into the vector storage,
+ should not be empty.
+ embedding: Embedding function.
+ config: Alibaba OpenSearch instance configuration.
+ Returns:
+ AlibabaCloudOpenSearch: Alibaba cloud opensearch vector store instance.
+ """
+ if documents is None or len(documents) == 0:
+ raise Exception("the inserted documents, should not be empty.")
+
+ if embedding is None:
+ raise Exception("the embeddings should not be empty.")
+
+ if config is None:
+ raise Exception("config can't be none")
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ config=config,
+ **kwargs,
+ )
diff --git a/libs/community/langchain_community/vectorstores/analyticdb.py b/libs/community/langchain_community/vectorstores/analyticdb.py
new file mode 100644
index 00000000000..e72383eae47
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/analyticdb.py
@@ -0,0 +1,452 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type
+
+from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
+from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"
+
+Base = declarative_base() # type: Any
+
+
+class AnalyticDB(VectorStore):
+ """`AnalyticDB` (distributed PostgreSQL) vector store.
+
+ AnalyticDB is a cloud-native, distributed database that is fully compatible
+ with PostgreSQL syntax.
+ - `connection_string` is a postgres connection string.
+ - `embedding_function` is any embedding function implementing the
+ `langchain.embeddings.base.Embeddings` interface.
+ - `collection_name` is the name of the collection to use.
+ (default: langchain_document)
+ - NOTE: This is not the name of the table, but the name of the collection.
+ The tables will be created when initializing the store (if they do not
+ exist), so make sure the user has the right permissions to create tables.
+ - `pre_delete_collection` if True, will delete the collection if it exists.
+ (default: False)
+ - Useful for testing.
+
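+ Example (sketch; the connection string, embeddings object, and collection
+ name below are placeholders):
+
+ .. code-block:: python
+
+ from langchain_community.vectorstores import AnalyticDB
+
+ store = AnalyticDB(
+ connection_string="postgresql+psycopg2://user:pass@host:5432/db",
+ embedding_function=embeddings,
+ collection_name="langchain_document",
+ )
+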
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ engine_args: Optional[dict] = None,
+ ) -> None:
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self.embedding_dimension = embedding_dimension
+ self.collection_name = collection_name
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.__post_init__(engine_args)
+
+ def __post_init__(
+ self,
+ engine_args: Optional[dict] = None,
+ ) -> None:
+ """
+ Initialize the store.
+ """
+
+ _engine_args = engine_args or {}
+
+ if (
+ "pool_recycle" not in _engine_args
+ ): # Check if pool_recycle is not in _engine_args
+ _engine_args[
+ "pool_recycle"
+ ] = 3600 # Set pool_recycle to 3600s if not present
+
+ self.engine = create_engine(self.connection_string, **_engine_args)
+ self.create_collection()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ return self._euclidean_relevance_score_fn
+
+ def create_table_if_not_exists(self) -> None:
+ # Define the dynamic table
+ Table(
+ self.collection_name,
+ Base.metadata,
+ Column("id", TEXT, primary_key=True, default=uuid.uuid4),
+ Column("embedding", ARRAY(REAL)),
+ Column("document", String, nullable=True),
+ Column("metadata", JSON, nullable=True),
+ extend_existing=True,
+ )
+ with self.engine.connect() as conn:
+ with conn.begin():
+ # Create the table
+ Base.metadata.create_all(conn)
+
+ # Check if the index exists
+ index_name = f"{self.collection_name}_embedding_idx"
+ index_query = text(
+ f"""
+ SELECT 1
+ FROM pg_indexes
+ WHERE indexname = '{index_name}';
+ """
+ )
+ result = conn.execute(index_query).scalar()
+
+ # Create the index if it doesn't exist
+ if not result:
+ index_statement = text(
+ f"""
+ CREATE INDEX {index_name}
+ ON {self.collection_name} USING ann(embedding)
+ WITH (
+ "dim" = {self.embedding_dimension},
+ "hnsw_m" = 100
+ );
+ """
+ )
+ conn.execute(index_statement)
+
+ def create_collection(self) -> None:
+ if self.pre_delete_collection:
+ self.delete_collection()
+ self.create_table_if_not_exists()
+
+ def delete_collection(self) -> None:
+ self.logger.debug("Trying to delete collection")
+ drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
+ with self.engine.connect() as conn:
+ with conn.begin():
+ conn.execute(drop_statement)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ embeddings = self.embedding_function.embed_documents(list(texts))
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ # Define the table schema
+ chunks_table = Table(
+ self.collection_name,
+ Base.metadata,
+ Column("id", TEXT, primary_key=True),
+ Column("embedding", ARRAY(REAL)),
+ Column("document", String, nullable=True),
+ Column("metadata", JSON, nullable=True),
+ extend_existing=True,
+ )
+
+ chunks_table_data = []
+ with self.engine.connect() as conn:
+ with conn.begin():
+ for document, metadata, chunk_id, embedding in zip(
+ texts, metadatas, ids, embeddings
+ ):
+ chunks_table_data.append(
+ {
+ "id": chunk_id,
+ "embedding": embedding,
+ "document": document,
+ "metadata": metadata,
+ }
+ )
+
+ # Execute the batch insert when the batch size is reached
+ if len(chunks_table_data) == batch_size:
+ conn.execute(insert(chunks_table).values(chunks_table_data))
+ # Clear the chunks_table_data list for the next batch
+ chunks_table_data.clear()
+
+ # Insert any remaining records that didn't make up a full batch
+ if chunks_table_data:
+ conn.execute(insert(chunks_table).values(chunks_table_data))
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with AnalyticDB with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query.
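+
+ Example (sketch; the filter keys refer to metadata keys stored alongside
+ the documents):
+
+ .. code-block:: python
+
+ docs = store.similarity_search(
+ "What is AnalyticDB?",
+ k=4,
+ filter={"source": "docs"},
+ )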
+ """
+ embedding = self.embedding_function.embed_query(text=query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ # Add the filter if provided
+ try:
+ from sqlalchemy.engine import Row
+ except ImportError:
+ raise ImportError(
+ "Could not import Row from sqlalchemy.engine. "
+ "Please 'pip install sqlalchemy>=1.4'."
+ )
+
+ filter_condition = ""
+ if filter is not None:
+ conditions = [
+ f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
+ ]
+ filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+ # Define the base query
+ sql_query = f"""
+ SELECT *, l2_distance(embedding, :embedding) as distance
+ FROM {self.collection_name}
+ {filter_condition}
+ ORDER BY embedding <-> :embedding
+ LIMIT :k
+ """
+
+ # Set up the query parameters
+ params = {"embedding": embedding, "k": k}
+
+ # Execute the query and fetch the results
+ with self.engine.connect() as conn:
+ results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()
+
+ documents_with_scores = [
+ (
+ Document(
+ page_content=result.document,
+ metadata=result.metadata,
+ ),
+ result.distance if self.embedding_function is not None else None,
+ )
+ for result in results
+ ]
+ return documents_with_scores
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ # Define the table schema
+ chunks_table = Table(
+ self.collection_name,
+ Base.metadata,
+ Column("id", TEXT, primary_key=True),
+ Column("embedding", ARRAY(REAL)),
+ Column("document", String, nullable=True),
+ Column("metadata", JSON, nullable=True),
+ extend_existing=True,
+ )
+
+ try:
+ with self.engine.connect() as conn:
+ with conn.begin():
+ delete_condition = chunks_table.c.id.in_(ids)
+ conn.execute(chunks_table.delete().where(delete_condition))
+ return True
+ except Exception as e:
+ print("Delete operation failed:", str(e))
+ return False
+
+ @classmethod
+ def from_texts(
+ cls: Type[AnalyticDB],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ engine_args: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> AnalyticDB:
+ """
+ Return VectorStore initialized from texts and embeddings.
+ A Postgres connection string is required: either pass it as the
+ `connection_string` parameter or set the PG_CONNECTION_STRING
+ environment variable.
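+
+ Example (sketch; values are placeholders):
+
+ .. code-block:: python
+
+ store = AnalyticDB.from_texts(
+ texts=["foo", "bar"],
+ embedding=embeddings,
+ connection_string="postgresql+psycopg2://user:pass@host:5432/db",
+ )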
+ """
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ embedding_dimension=embedding_dimension,
+ pre_delete_collection=pre_delete_collection,
+ engine_args=engine_args,
+ )
+
+ store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
+ return store
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="PG_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+ "Postgres connection string is required"
+ "Either pass it as a parameter"
+ "or set the PG_CONNECTION_STRING environment variable."
+ )
+
+ return connection_string
+
+ @classmethod
+ def from_documents(
+ cls: Type[AnalyticDB],
+ documents: List[Document],
+ embedding: Embeddings,
+ embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ engine_args: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> AnalyticDB:
+ """
+ Return VectorStore initialized from documents and embeddings.
+ A Postgres connection string is required: either pass it as the
+ `connection_string` parameter or set the PG_CONNECTION_STRING
+ environment variable.
+ """
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ connection_string = cls.get_connection_string(kwargs)
+
+ kwargs["connection_string"] = connection_string
+
+ return cls.from_texts(
+ texts=texts,
+ pre_delete_collection=pre_delete_collection,
+ embedding=embedding,
+ embedding_dimension=embedding_dimension,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ engine_args=engine_args,
+ **kwargs,
+ )
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ driver: str,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
diff --git a/libs/community/langchain_community/vectorstores/annoy.py b/libs/community/langchain_community/vectorstores/annoy.py
new file mode 100644
index 00000000000..b797fcf9bda
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/annoy.py
@@ -0,0 +1,455 @@
+from __future__ import annotations
+
+import os
+import pickle
+import uuid
+from configparser import ConfigParser
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.docstore.base import Docstore
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+INDEX_METRICS = frozenset(["angular", "euclidean", "manhattan", "hamming", "dot"])
+DEFAULT_METRIC = "angular"
+
+
+def dependable_annoy_import() -> Any:
+ """Import annoy if available, otherwise raise error."""
+ try:
+ import annoy
+ except ImportError:
+ raise ImportError(
+ "Could not import annoy python package. "
+ "Please install it with `pip install --user annoy` "
+ )
+ return annoy
+
+
+class Annoy(VectorStore):
+ """`Annoy` vector store.
+
+ To use, you should have the ``annoy`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Annoy
+ db = Annoy(embedding_function, index, docstore, index_to_docstore_id)
+
+ """
+
+ def __init__(
+ self,
+ embedding_function: Callable,
+ index: Any,
+ metric: str,
+ docstore: Docstore,
+ index_to_docstore_id: Dict[int, str],
+ ):
+ """Initialize with necessary components."""
+ self.embedding_function = embedding_function
+ self.index = index
+ self.metric = metric
+ self.docstore = docstore
+ self.index_to_docstore_id = index_to_docstore_id
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ # TODO: Accept embedding object directly
+ return None
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ raise NotImplementedError(
+ "Annoy does not allow to add new data once the index is build."
+ )
+
+ def process_index_results(
+ self, idxs: List[int], dists: List[float]
+ ) -> List[Tuple[Document, float]]:
+ """Turns annoy results into a list of documents and scores.
+
+ Args:
+ idxs: List of indices of the documents in the index.
+ dists: List of distances of the documents in the index.
+ Returns:
+ List of Documents and scores.
+ """
+ docs = []
+ for idx, dist in zip(idxs, dists):
+ _id = self.index_to_docstore_id[idx]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ docs.append((doc, dist))
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self, embedding: List[float], k: int = 4, search_k: int = -1
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ idxs, dists = self.index.get_nns_by_vector(
+ embedding, k, search_k=search_k, include_distances=True
+ )
+ return self.process_index_results(idxs, dists)
+
+ def similarity_search_with_score_by_index(
+ self, docstore_index: int, k: int = 4, search_k: int = -1
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ idxs, dists = self.index.get_nns_by_item(
+ docstore_index, k, search_k=search_k, include_distances=True
+ )
+ return self.process_index_results(idxs, dists)
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, search_k: int = -1
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedding = self.embedding_function(query)
+ docs = self.similarity_search_with_score_by_vector(embedding, k, search_k)
+ return docs
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, search_k: int = -1, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding, k, search_k
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_by_index(
+ self, docstore_index: int, k: int = 4, search_k: int = -1, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to docstore_index.
+
+ Args:
+ docstore_index: Index of document in docstore
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_index(
+ docstore_index, k, search_k
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search(
+ self, query: str, k: int = 4, search_k: int = -1, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_k: inspect up to search_k nodes which defaults
+ to n_trees * n if not provided
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k, search_k)
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ k: Number of Documents to return. Defaults to 4.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ idxs = self.index.get_nns_by_vector(
+ embedding, fetch_k, search_k=-1, include_distances=False
+ )
+ embeddings = [self.index.get_item_vector(i) for i in idxs]
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ # ignore the -1's if not enough docs are returned/indexed
+ selected_indices = [idxs[i] for i in mmr_selected if i != -1]
+
+ docs = []
+ for i in selected_indices:
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ docs.append(doc)
+ return docs
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding = self.embedding_function(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult=lambda_mult
+ )
+ return docs
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ metric: str = DEFAULT_METRIC,
+ trees: int = 100,
+ n_jobs: int = -1,
+ **kwargs: Any,
+ ) -> Annoy:
+ if metric not in INDEX_METRICS:
+ raise ValueError(
+ (
+ f"Unsupported distance metric: {metric}. "
+ f"Expected one of {list(INDEX_METRICS)}"
+ )
+ )
+ annoy = dependable_annoy_import()
+ if not embeddings:
+ raise ValueError("embeddings must be provided to build AnnoyIndex")
+ f = len(embeddings[0])
+ index = annoy.AnnoyIndex(f, metric=metric)
+ for i, emb in enumerate(embeddings):
+ index.add_item(i, emb)
+ index.build(trees, n_jobs=n_jobs)
+
+ documents = []
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ documents.append(Document(page_content=text, metadata=metadata))
+ index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+ docstore = InMemoryDocstore(
+ {index_to_id[i]: doc for i, doc in enumerate(documents)}
+ )
+ return cls(embedding.embed_query, index, metric, docstore, index_to_id)
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ metric: str = DEFAULT_METRIC,
+ trees: int = 100,
+ n_jobs: int = -1,
+ **kwargs: Any,
+ ) -> Annoy:
+ """Construct Annoy wrapper from raw documents.
+
+ Args:
+ texts: List of documents to index.
+ embedding: Embedding function to use.
+ metadatas: List of metadata dictionaries to associate with documents.
+ metric: Metric to use for indexing. Defaults to "angular".
+ trees: Number of trees to use for indexing. Defaults to 100.
+ n_jobs: Number of jobs to use for indexing. Defaults to -1.
+
+ This is a user friendly interface that:
+ 1. Embeds documents.
+        2. Creates an in-memory docstore.
+        3. Initializes the Annoy database.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Annoy
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ index = Annoy.from_texts(texts, embeddings)
+ """
+ embeddings = embedding.embed_documents(texts)
+ return cls.__from(
+ texts, embeddings, embedding, metadatas, metric, trees, n_jobs, **kwargs
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ metric: str = DEFAULT_METRIC,
+ trees: int = 100,
+ n_jobs: int = -1,
+ **kwargs: Any,
+ ) -> Annoy:
+ """Construct Annoy wrapper from embeddings.
+
+ Args:
+ text_embeddings: List of tuples of (text, embedding)
+ embedding: Embedding function to use.
+ metadatas: List of metadata dictionaries to associate with documents.
+ metric: Metric to use for indexing. Defaults to "angular".
+ trees: Number of trees to use for indexing. Defaults to 100.
+ n_jobs: Number of jobs to use for indexing. Defaults to -1
+
+ This is a user friendly interface that:
+        1. Creates an in-memory docstore with the provided embeddings.
+        2. Initializes the Annoy database.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Annoy
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ db = Annoy.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts, embeddings, embedding, metadatas, metric, trees, n_jobs, **kwargs
+ )
+
+ def save_local(self, folder_path: str, prefault: bool = False) -> None:
+ """Save Annoy index, docstore, and index_to_docstore_id to disk.
+
+ Args:
+ folder_path: folder path to save index, docstore,
+ and index_to_docstore_id to.
+ prefault: Whether to pre-load the index into memory.
+ """
+ path = Path(folder_path)
+ os.makedirs(path, exist_ok=True)
+ # save index, index config, docstore and index_to_docstore_id
+ config_object = ConfigParser()
+ config_object["ANNOY"] = {
+ "f": self.index.f,
+ "metric": self.metric,
+ }
+ self.index.save(str(path / "index.annoy"), prefault=prefault)
+ with open(path / "index.pkl", "wb") as file:
+ pickle.dump((self.docstore, self.index_to_docstore_id, config_object), file)
+
+ @classmethod
+ def load_local(
+ cls,
+ folder_path: str,
+ embeddings: Embeddings,
+ ) -> Annoy:
+ """Load Annoy index, docstore, and index_to_docstore_id to disk.
+
+ Args:
+ folder_path: folder path to load index, docstore,
+ and index_to_docstore_id from.
+ embeddings: Embeddings to use when generating queries.
+ """
+ path = Path(folder_path)
+ # load index separately since it is not picklable
+ annoy = dependable_annoy_import()
+ # load docstore and index_to_docstore_id
+ with open(path / "index.pkl", "rb") as file:
+ docstore, index_to_docstore_id, config_object = pickle.load(file)
+
+ f = int(config_object["ANNOY"]["f"])
+ metric = config_object["ANNOY"]["metric"]
+
+ index = annoy.AnnoyIndex(f, metric=metric)
+ index.load(str(path / "index.annoy"))
+
+ return cls(
+ embeddings.embed_query, index, metric, docstore, index_to_docstore_id
+ )
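+
+
+# A minimal save/load round-trip sketch (illustrative; assumes `texts` and an
+# `embeddings` object such as OpenAIEmbeddings are already defined):
+#
+#     db = Annoy.from_texts(texts, embeddings)
+#     db.save_local("annoy_index_dir")
+#     db2 = Annoy.load_local("annoy_index_dir", embeddings)
+#     docs = db2.similarity_search("some query", k=2)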
diff --git a/libs/community/langchain_community/vectorstores/astradb.py b/libs/community/langchain_community/vectorstores/astradb.py
new file mode 100644
index 00000000000..efb544f9afd
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/astradb.py
@@ -0,0 +1,776 @@
+from __future__ import annotations
+
+import uuid
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils.iter import batch_iterate
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+ADBVST = TypeVar("ADBVST", bound="AstraDB")
+T = TypeVar("T")
+U = TypeVar("U")
+DocDict = Dict[str, Any] # dicts expressing entries to insert
+
+# Batch/concurrency default values (if parameters not provided):
+# Size of batches for bulk insertions:
+# (20 is the max batch size for the HTTP API at the time of writing)
+DEFAULT_BATCH_SIZE = 20
+# Number of threads to insert batches concurrently:
+DEFAULT_BULK_INSERT_BATCH_CONCURRENCY = 16
+# Number of threads in a batch to insert pre-existing entries:
+DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY = 10
+# Number of threads (for deleting multiple rows concurrently):
+DEFAULT_BULK_DELETE_CONCURRENCY = 20
+
+
+def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
+ visited_keys: Set[U] = set()
+ new_lst = []
+ for item in lst:
+ item_key = key(item)
+ if item_key not in visited_keys:
+ visited_keys.add(item_key)
+ new_lst.append(item)
+ return new_lst
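+
+
+# Illustrative behaviour of _unique_list: the *first* occurrence of each key wins,
+# e.g. _unique_list([{"_id": "a"}, {"_id": "a", "x": 1}], lambda d: d["_id"])
+# returns [{"_id": "a"}]. add_texts below reverses the list before and after the
+# call so that, for duplicate ids in the input, the *last* provided document is kept.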
+
+
+class AstraDB(VectorStore):
+ """Wrapper around DataStax Astra DB for vector-store workloads.
+
+ To use it, you need a recent installation of the `astrapy` library
+ and an Astra DB cloud database.
+
+ For quickstart and details, visit:
+ docs.datastax.com/en/astra/home/astra.html
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import AstraDB
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = AstraDB(
+ embedding=embeddings,
+ collection_name="my_store",
+ token="AstraCS:...",
+ api_endpoint="https://-us-east1.apps.astra.datastax.com"
+ )
+
+ vectorstore.add_texts(["Giraffes", "All good here"])
+ results = vectorstore.similarity_search("Everything's ok", k=1)
+
+ Constructor Args (only keyword-arguments accepted):
+ embedding (Embeddings): embedding function to use.
+ collection_name (str): name of the Astra DB collection to create/use.
+ token (Optional[str]): API token for Astra DB usage.
+ api_endpoint (Optional[str]): full URL to the API endpoint,
+ such as "https://-us-east1.apps.astra.datastax.com".
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
+ you can pass an already-created 'astrapy.db.AstraDB' instance.
+ namespace (Optional[str]): namespace (aka keyspace) where the
+ collection is created. Defaults to the database's "default namespace".
+ metric (Optional[str]): similarity function to use out of those
+ available in Astra DB. If left out, it will use Astra DB API's
+ defaults (i.e. "cosine" - but, for performance reasons,
+ "dot_product" is suggested if embeddings are normalized to one).
+
+ Advanced arguments (coming with sensible defaults):
+ batch_size (Optional[int]): Size of batches for bulk insertions.
+ bulk_insert_batch_concurrency (Optional[int]): Number of threads
+ to insert batches concurrently.
+ bulk_insert_overwrite_concurrency (Optional[int]): Number of
+ threads in a batch to insert pre-existing entries.
+ bulk_delete_concurrency (Optional[int]): Number of threads
+ (for deleting multiple rows concurrently).
+ pre_delete_collection (Optional[bool]): whether to delete the collection
+ before creating it. If False and the collection already exists,
+ the collection will be used as is.
+
+ A note on concurrency: as a rule of thumb, on a typical client machine
+ it is suggested to keep the quantity
+ bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency
+ much below 1000 to avoid exhausting the client multithreading/networking
+ resources. The hardcoded defaults are somewhat conservative to meet
+ most machines' specs, but a sensible choice to test may be:
+ bulk_insert_batch_concurrency = 80
+ bulk_insert_overwrite_concurrency = 10
+ A bit of experimentation is required to nail the best results here,
+ depending on both the machine/network specs and the expected workload
+ (specifically, how often a write is an update of an existing id).
+ Remember you can pass concurrency settings to individual calls to
+ add_texts and add_documents as well.
+ """
+
+ @staticmethod
+ def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
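+        # Illustrative example: a filter of {"color": "red"} becomes
+        # {"metadata.color": "red"}, matching how metadata fields are nested
+        # in the documents this store writes (see add_texts).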
+ if filter_dict is None:
+ return {}
+ else:
+ return {f"metadata.{mdk}": mdv for mdk, mdv in filter_dict.items()}
+
+ def __init__(
+ self,
+ *,
+ embedding: Embeddings,
+ collection_name: str,
+ token: Optional[str] = None,
+ api_endpoint: Optional[str] = None,
+ astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
+ namespace: Optional[str] = None,
+ metric: Optional[str] = None,
+ batch_size: Optional[int] = None,
+ bulk_insert_batch_concurrency: Optional[int] = None,
+ bulk_insert_overwrite_concurrency: Optional[int] = None,
+ bulk_delete_concurrency: Optional[int] = None,
+ pre_delete_collection: bool = False,
+ ) -> None:
+ """
+ Create an AstraDB vector store object. See class docstring for help.
+ """
+ try:
+ from astrapy.db import (
+ AstraDB as LibAstraDB,
+ )
+ from astrapy.db import (
+ AstraDBCollection as LibAstraDBCollection,
+ )
+ except (ImportError, ModuleNotFoundError):
+ raise ImportError(
+ "Could not import a recent astrapy python package. "
+ "Please install it with `pip install --upgrade astrapy`."
+ )
+
+ # Conflicting-arg checks:
+ if astra_db_client is not None:
+ if token is not None or api_endpoint is not None:
+ raise ValueError(
+ "You cannot pass 'astra_db_client' to AstraDB if passing "
+ "'token' and 'api_endpoint'."
+ )
+
+ self.embedding = embedding
+ self.collection_name = collection_name
+ self.token = token
+ self.api_endpoint = api_endpoint
+ self.namespace = namespace
+ # Concurrency settings
+ self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
+ self.bulk_insert_batch_concurrency: int = (
+ bulk_insert_batch_concurrency or DEFAULT_BULK_INSERT_BATCH_CONCURRENCY
+ )
+ self.bulk_insert_overwrite_concurrency: int = (
+ bulk_insert_overwrite_concurrency
+ or DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY
+ )
+ self.bulk_delete_concurrency: int = (
+ bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
+ )
+ # "vector-related" settings
+ self._embedding_dimension: Optional[int] = None
+ self.metric = metric
+
+ if astra_db_client is not None:
+ self.astra_db = astra_db_client
+ else:
+ self.astra_db = LibAstraDB(
+ token=self.token,
+ api_endpoint=self.api_endpoint,
+ namespace=self.namespace,
+ )
+ if not pre_delete_collection:
+ self._provision_collection()
+ else:
+ self.clear()
+
+ self.collection = LibAstraDBCollection(
+ collection_name=self.collection_name,
+ astra_db=self.astra_db,
+ )
+
+ def _get_embedding_dimension(self) -> int:
+ if self._embedding_dimension is None:
+ self._embedding_dimension = len(
+ self.embedding.embed_query("This is a sample sentence.")
+ )
+ return self._embedding_dimension
+
+ def _drop_collection(self) -> None:
+ """
+ Drop the collection from storage.
+
+        This is meant as an internal-usage method; it does not modify any
+        object members, it only performs the deletion on the backend.
+ """
+ _ = self.astra_db.delete_collection(
+ collection_name=self.collection_name,
+ )
+ return None
+
+ def _provision_collection(self) -> None:
+ """
+ Run the API invocation to create the collection on the backend.
+
+        Internal-usage method; it does not set any object members,
+        it only operates on the underlying storage.
+ """
+ _ = self.astra_db.create_collection(
+ dimension=self._get_embedding_dimension(),
+ collection_name=self.collection_name,
+ metric=self.metric,
+ )
+ return None
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ @staticmethod
+ def _dont_flip_the_cos_score(similarity0to1: float) -> float:
+ """Keep similarity from client unchanged ad it's in [0:1] already."""
+ return similarity0to1
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+        The underlying API calls already return a "score proper",
+        i.e. one in [0, 1] where higher means more *similar*,
+        so the final score transformation here does not reverse the interval:
+ """
+ return self._dont_flip_the_cos_score
+
+ def clear(self) -> None:
+ """Empty the collection of all its stored entries."""
+ self._drop_collection()
+ self._provision_collection()
+ return None
+
+ def delete_by_document_id(self, document_id: str) -> bool:
+ """
+ Remove a single document from the store, given its document_id (str).
+ Return True if a document has indeed been deleted, False if ID not found.
+ """
+ deletion_response = self.collection.delete(document_id)
+ return ((deletion_response or {}).get("status") or {}).get(
+ "deletedCount", 0
+ ) == 1
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ concurrency: Optional[int] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete by vector ids.
+
+ Args:
+ ids (Optional[List[str]]): List of ids to delete.
+ concurrency (Optional[int]): max number of threads issuing
+ single-doc delete requests. Defaults to instance-level setting.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+
+ if kwargs:
+ warnings.warn(
+ "Method 'delete' of AstraDB vector store invoked with "
+ f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
+ "which will be ignored."
+ )
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ _max_workers = concurrency or self.bulk_delete_concurrency
+ with ThreadPoolExecutor(max_workers=_max_workers) as tpe:
+ _ = list(
+ tpe.map(
+ self.delete_by_document_id,
+ ids,
+ )
+ )
+ return True
+
+ def delete_collection(self) -> None:
+ """
+ Completely delete the collection from the database (as opposed
+ to 'clear()', which empties it only).
+ Stored data is lost and unrecoverable, resources are freed.
+ Use with caution.
+ """
+ self._drop_collection()
+ return None
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ *,
+ batch_size: Optional[int] = None,
+ batch_concurrency: Optional[int] = None,
+ overwrite_concurrency: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run texts through the embeddings and add them to the vectorstore.
+
+ If passing explicit ids, those entries whose id is in the store already
+ will be replaced.
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of ids.
+ batch_size (Optional[int]): Number of documents in each API call.
+ Check the underlying Astra DB HTTP API specs for the max value
+ (20 at the time of writing this). If not provided, defaults
+ to the instance-level setting.
+ batch_concurrency (Optional[int]): number of threads to process
+ insertion batches concurrently. Defaults to instance-level
+ setting if not provided.
+ overwrite_concurrency (Optional[int]): number of threads to process
+ pre-existing documents in each batch (which require individual
+ API calls). Defaults to instance-level setting if not provided.
+
+ A note on metadata: there are constraints on the allowed field names
+ in this dictionary, coming from the underlying Astra DB API.
+ For instance, the `$` (dollar sign) cannot be used in the dict keys.
+ See this document for details:
+ docs.datastax.com/en/astra-serverless/docs/develop/dev-with-json.html
+
+ Returns:
+ List[str]: List of ids of the added texts.
+ """
+
+ if kwargs:
+ warnings.warn(
+ "Method 'add_texts' of AstraDB vector store invoked with "
+ f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
+ "which will be ignored."
+ )
+
+ _texts = list(texts)
+ if ids is None:
+ ids = [uuid.uuid4().hex for _ in _texts]
+ if metadatas is None:
+ metadatas = [{} for _ in _texts]
+ #
+ embedding_vectors = self.embedding.embed_documents(_texts)
+
+ documents_to_insert = [
+ {
+ "content": b_txt,
+ "_id": b_id,
+ "$vector": b_emb,
+ "metadata": b_md,
+ }
+ for b_txt, b_emb, b_id, b_md in zip(
+ _texts,
+ embedding_vectors,
+ ids,
+ metadatas,
+ )
+ ]
+ # make unique by id, keeping the last
+ uniqued_documents_to_insert = _unique_list(
+ documents_to_insert[::-1],
+ lambda document: document["_id"],
+ )[::-1]
+
+ all_ids = []
+
+ def _handle_batch(document_batch: List[DocDict]) -> List[str]:
+ im_result = self.collection.insert_many(
+ documents=document_batch,
+ options={"ordered": False},
+ partial_failures_allowed=True,
+ )
+ if "status" not in im_result:
+ raise ValueError(
+ f"API Exception while running bulk insertion: {str(im_result)}"
+ )
+
+ batch_inserted = im_result["status"]["insertedIds"]
+ # estimation of the preexisting documents that failed
+ missed_inserted_ids = {
+ document["_id"] for document in document_batch
+ } - set(batch_inserted)
+ errors = im_result.get("errors", [])
+ # careful for other sources of error other than "doc already exists"
+ num_errors = len(errors)
+ unexpected_errors = any(
+ error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
+ )
+ if num_errors != len(missed_inserted_ids) or unexpected_errors:
+ raise ValueError(
+ f"API Exception while running bulk insertion: {str(errors)}"
+ )
+
+ # deal with the missing insertions as upserts
+ missing_from_batch = [
+ document
+ for document in document_batch
+ if document["_id"] in missed_inserted_ids
+ ]
+
+ def _handle_missing_document(missing_document: DocDict) -> str:
+ replacement_result = self.collection.find_one_and_replace(
+ filter={"_id": missing_document["_id"]},
+ replacement=missing_document,
+ )
+ return replacement_result["data"]["document"]["_id"]
+
+ _u_max_workers = (
+ overwrite_concurrency or self.bulk_insert_overwrite_concurrency
+ )
+ with ThreadPoolExecutor(max_workers=_u_max_workers) as tpe2:
+ batch_replaced = list(
+ tpe2.map(
+ _handle_missing_document,
+ missing_from_batch,
+ )
+ )
+
+ upsert_ids = batch_inserted + batch_replaced
+ return upsert_ids
+
+ _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
+ with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe:
+ all_ids_nested = tpe.map(
+ _handle_batch,
+ batch_iterate(
+ batch_size or self.batch_size,
+ uniqued_documents_to_insert,
+ ),
+ )
+
+ all_ids = [iid for id_list in all_ids_nested for iid in id_list]
+
+ return all_ids
+
+ def similarity_search_with_score_id_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float, str]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ Returns:
+ List of (Document, score, id), the most similar to the query vector.
+ """
+ metadata_parameter = self._filter_to_metadata(filter)
+ #
+ hits = list(
+ self.collection.paginated_find(
+ filter=metadata_parameter,
+ sort={"$vector": embedding},
+ options={"limit": k, "includeSimilarity": True},
+ projection={
+ "_id": 1,
+ "content": 1,
+ "metadata": 1,
+ },
+ )
+ )
+ #
+ return [
+ (
+ Document(
+ page_content=hit["content"],
+ metadata=hit["metadata"],
+ ),
+ hit["$similarity"],
+ hit["_id"],
+ )
+ for hit in hits
+ ]
+
+ def similarity_search_with_score_id(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float, str]]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_with_score_id_by_vector(
+ embedding=embedding_vector,
+ k=k,
+ filter=filter,
+ )
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ Returns:
+ List of (Document, score), the most similar to the query vector.
+ """
+ return [
+ (doc, score)
+ for (doc, score, doc_id) in self.similarity_search_with_score_id_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+ ]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_by_vector(
+ embedding_vector,
+ k,
+ filter=filter,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ return [
+ doc
+ for doc, _ in self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ )
+ ]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float]]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_with_score_by_vector(
+ embedding_vector,
+ k,
+ filter=filter,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ metadata_parameter = self._filter_to_metadata(filter)
+
+ prefetch_hits = list(
+ self.collection.paginated_find(
+ filter=metadata_parameter,
+ sort={"$vector": embedding},
+ options={"limit": fetch_k, "includeSimilarity": True},
+ projection={
+ "_id": 1,
+ "content": 1,
+ "metadata": 1,
+ "$vector": 1,
+ },
+ )
+ )
+
+ mmr_chosen_indices = maximal_marginal_relevance(
+ np.array(embedding, dtype=np.float32),
+ [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ mmr_hits = [
+ prefetch_hit
+ for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
+ if prefetch_index in mmr_chosen_indices
+ ]
+ return [
+ Document(
+ page_content=hit["content"],
+ metadata=hit["metadata"],
+ )
+ for hit in mmr_hits
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ query (str): Text to look up documents similar to.
+ k (int = 4): Number of Documents to return.
+ fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Optional.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding_vector = self.embedding.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding_vector,
+ k,
+ fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[ADBVST],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> ADBVST:
+ """Create an Astra DB vectorstore from raw texts.
+
+ Args:
+ texts (List[str]): the texts to insert.
+ embedding (Embeddings): the embedding function to use in the store.
+ metadatas (Optional[List[dict]]): metadata dicts for the texts.
+ ids (Optional[List[str]]): ids to associate to the texts.
+ *Additional arguments*: you can pass any argument that you would
+                pass to 'add_texts' and/or to the 'AstraDB' class constructor
+ (see these methods for details). These arguments will be
+ routed to the respective methods as they are.
+
+ Returns:
+            an `AstraDB` vectorstore.
+ """
+
+ known_kwargs = {
+ "collection_name",
+ "token",
+ "api_endpoint",
+ "astra_db_client",
+ "namespace",
+ "metric",
+ "batch_size",
+ "bulk_insert_batch_concurrency",
+ "bulk_insert_overwrite_concurrency",
+ "bulk_delete_concurrency",
+ "batch_concurrency",
+ "overwrite_concurrency",
+ }
+ if kwargs:
+ unknown_kwargs = set(kwargs.keys()) - known_kwargs
+ if unknown_kwargs:
+ warnings.warn(
+ "Method 'from_texts' of AstraDB vector store invoked with "
+ f"unsupported arguments ({', '.join(sorted(unknown_kwargs))}), "
+ "which will be ignored."
+ )
+
+ collection_name: str = kwargs["collection_name"]
+ token = kwargs.get("token")
+ api_endpoint = kwargs.get("api_endpoint")
+ astra_db_client = kwargs.get("astra_db_client")
+ namespace = kwargs.get("namespace")
+ metric = kwargs.get("metric")
+
+ astra_db_store = cls(
+ embedding=embedding,
+ collection_name=collection_name,
+ token=token,
+ api_endpoint=api_endpoint,
+ astra_db_client=astra_db_client,
+ namespace=namespace,
+ metric=metric,
+ batch_size=kwargs.get("batch_size"),
+ bulk_insert_batch_concurrency=kwargs.get("bulk_insert_batch_concurrency"),
+ bulk_insert_overwrite_concurrency=kwargs.get(
+ "bulk_insert_overwrite_concurrency"
+ ),
+ bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"),
+ )
+ astra_db_store.add_texts(
+ texts=texts,
+ metadatas=metadatas,
+ ids=ids,
+ batch_size=kwargs.get("batch_size"),
+ batch_concurrency=kwargs.get("batch_concurrency"),
+ overwrite_concurrency=kwargs.get("overwrite_concurrency"),
+ )
+ return astra_db_store
+
+ @classmethod
+ def from_documents(
+ cls: Type[ADBVST],
+ documents: List[Document],
+ embedding: Embeddings,
+ **kwargs: Any,
+ ) -> ADBVST:
+ """Create an Astra DB vectorstore from a document list.
+
+ Utility method that defers to 'from_texts' (see that one).
+
+ Args: see 'from_texts', except here you have to supply 'documents'
+ in place of 'texts' and 'metadatas'.
+
+ Returns:
+ an `AstraDB` vectorstore.
+ """
+ return super().from_documents(documents, embedding, **kwargs)
diff --git a/libs/community/langchain_community/vectorstores/atlas.py b/libs/community/langchain_community/vectorstores/atlas.py
new file mode 100644
index 00000000000..88b2da8f7c9
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/atlas.py
@@ -0,0 +1,326 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Iterable, List, Optional, Type
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+class AtlasDB(VectorStore):
+ """`Atlas` vector store.
+
+    Atlas is `Nomic's` neural database and `rhizomatic` instrument.
+
+ To use, you should have the ``nomic`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import AtlasDB
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = AtlasDB("my_project", embeddings.embed_query)
+ """
+
+ _ATLAS_DEFAULT_ID_FIELD = "atlas_id"
+
+ def __init__(
+ self,
+ name: str,
+ embedding_function: Optional[Embeddings] = None,
+ api_key: Optional[str] = None,
+ description: str = "A description for your project",
+ is_public: bool = True,
+ reset_project_if_exists: bool = False,
+ ) -> None:
+ """
+ Initialize the Atlas Client
+
+ Args:
+ name (str): The name of your project. If the project already exists,
+ it will be loaded.
+ embedding_function (Optional[Embeddings]): An optional function used for
+ embedding your data. If None, data will be embedded with
+ Nomic's embed model.
+ api_key (str): Your nomic API key
+ description (str): A description for your project.
+ is_public (bool): Whether your project is publicly accessible.
+ True by default.
+ reset_project_if_exists (bool): Whether to reset this project if it
+ already exists. Default False.
+ Generally useful during development and testing.
+ """
+ try:
+ import nomic
+ from nomic import AtlasProject
+ except ImportError:
+ raise ImportError(
+ "Could not import nomic python package. "
+ "Please install it with `pip install nomic`."
+ )
+
+ if api_key is None:
+ raise ValueError("No API key provided. Sign up at atlas.nomic.ai!")
+ nomic.login(api_key)
+
+ self._embedding_function = embedding_function
+ modality = "text"
+ if self._embedding_function is not None:
+ modality = "embedding"
+
+ # Check if the project exists, create it if not
+ self.project = AtlasProject(
+ name=name,
+ description=description,
+ modality=modality,
+ is_public=is_public,
+ reset_project_if_exists=reset_project_if_exists,
+ unique_id_field=AtlasDB._ATLAS_DEFAULT_ID_FIELD,
+ )
+ self.project._latest_project_state()
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding_function
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ refresh: bool = True,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]]): An optional list of ids.
+            refresh (bool): Whether or not to refresh indices with the updated data.
+ Default True.
+ Returns:
+ List[str]: List of IDs of the added texts.
+ """
+
+ if (
+ metadatas is not None
+ and len(metadatas) > 0
+ and "text" in metadatas[0].keys()
+ ):
+ raise ValueError("Cannot accept key text in metadata!")
+
+ texts = list(texts)
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ # Embedding upload case
+ if self._embedding_function is not None:
+ _embeddings = self._embedding_function.embed_documents(texts)
+ embeddings = np.stack(_embeddings)
+ if metadatas is None:
+ data = [
+ {AtlasDB._ATLAS_DEFAULT_ID_FIELD: ids[i], "text": texts[i]}
+ for i, _ in enumerate(texts)
+ ]
+ else:
+ for i in range(len(metadatas)):
+ metadatas[i][AtlasDB._ATLAS_DEFAULT_ID_FIELD] = ids[i]
+ metadatas[i]["text"] = texts[i]
+ data = metadatas
+
+ self.project._validate_map_data_inputs(
+ [], id_field=AtlasDB._ATLAS_DEFAULT_ID_FIELD, data=data
+ )
+ with self.project.wait_for_project_lock():
+ self.project.add_embeddings(embeddings=embeddings, data=data)
+ # Text upload case
+ else:
+ if metadatas is None:
+ data = [
+ {"text": text, AtlasDB._ATLAS_DEFAULT_ID_FIELD: ids[i]}
+ for i, text in enumerate(texts)
+ ]
+ else:
+ for i, text in enumerate(texts):
+ metadatas[i]["text"] = texts
+ metadatas[i][AtlasDB._ATLAS_DEFAULT_ID_FIELD] = ids[i]
+ data = metadatas
+
+ self.project._validate_map_data_inputs(
+ [], id_field=AtlasDB._ATLAS_DEFAULT_ID_FIELD, data=data
+ )
+
+ with self.project.wait_for_project_lock():
+ self.project.add_text(data)
+
+ if refresh:
+ if len(self.project.indices) > 0:
+ with self.project.wait_for_project_lock():
+ self.project.rebuild_maps()
+
+ return ids
+
+ def create_index(self, **kwargs: Any) -> Any:
+ """Creates an index in your project.
+
+ See
+ https://docs.nomic.ai/atlas_api.html#nomic.project.AtlasProject.create_index
+ for full detail.
+ """
+ with self.project.wait_for_project_lock():
+ return self.project.create_index(**kwargs)
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with AtlasDB
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+
+ Returns:
+ List[Document]: List of documents most similar to the query text.
+ """
+ if self._embedding_function is None:
+ raise NotImplementedError(
+ "AtlasDB requires an embedding_function for text similarity search!"
+ )
+
+ _embedding = self._embedding_function.embed_documents([query])[0]
+ embedding = np.array(_embedding).reshape(1, -1)
+ with self.project.wait_for_project_lock():
+ neighbors, _ = self.project.projections[0].vector_search(
+ queries=embedding, k=k
+ )
+ data = self.project.get_data(ids=neighbors[0])
+
+ docs = [
+ Document(page_content=data[i]["text"], metadata=data[i])
+            for i, neighbor in enumerate(neighbors[0])
+ ]
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls: Type[AtlasDB],
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ name: Optional[str] = None,
+ api_key: Optional[str] = None,
+ description: str = "A description for your project",
+ is_public: bool = True,
+ reset_project_if_exists: bool = False,
+ index_kwargs: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> AtlasDB:
+ """Create an AtlasDB vectorstore from a raw documents.
+
+ Args:
+ texts (List[str]): The list of texts to ingest.
+ name (str): Name of the project to create.
+            api_key (str): Your nomic API key.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
+ ids (Optional[List[str]]): Optional list of document IDs. If None,
+ ids will be auto created
+ description (str): A description for your project.
+ is_public (bool): Whether your project is publicly accessible.
+ True by default.
+ reset_project_if_exists (bool): Whether to reset this project if it
+ already exists. Default False.
+ Generally useful during development and testing.
+ index_kwargs (Optional[dict]): Dict of kwargs for index creation.
+ See https://docs.nomic.ai/atlas_api.html
+
+ Returns:
+ AtlasDB: Nomic's neural database and finest rhizomatic instrument
+ """
+ if name is None or api_key is None:
+ raise ValueError("`name` and `api_key` cannot be None.")
+
+ # Inject relevant kwargs
+ all_index_kwargs = {"name": name + "_index", "indexed_field": "text"}
+ if index_kwargs is not None:
+ for k, v in index_kwargs.items():
+ all_index_kwargs[k] = v
+
+ # Build project
+ atlasDB = cls(
+ name,
+ embedding_function=embedding,
+ api_key=api_key,
+ description="A description for your project",
+ is_public=is_public,
+ reset_project_if_exists=reset_project_if_exists,
+ )
+ with atlasDB.project.wait_for_project_lock():
+ atlasDB.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+ atlasDB.create_index(**all_index_kwargs)
+ return atlasDB
+
+ @classmethod
+ def from_documents(
+ cls: Type[AtlasDB],
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ ids: Optional[List[str]] = None,
+ name: Optional[str] = None,
+ api_key: Optional[str] = None,
+ persist_directory: Optional[str] = None,
+ description: str = "A description for your project",
+ is_public: bool = True,
+ reset_project_if_exists: bool = False,
+ index_kwargs: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> AtlasDB:
+ """Create an AtlasDB vectorstore from a list of documents.
+
+ Args:
+ name (str): Name of the collection to create.
+            api_key (str): Your nomic API key.
+ documents (List[Document]): List of documents to add to the vectorstore.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ ids (Optional[List[str]]): Optional list of document IDs. If None,
+ ids will be auto created
+ description (str): A description for your project.
+ is_public (bool): Whether your project is publicly accessible.
+ True by default.
+ reset_project_if_exists (bool): Whether to reset this project if
+ it already exists. Default False.
+ Generally useful during development and testing.
+ index_kwargs (Optional[dict]): Dict of kwargs for index creation.
+ See https://docs.nomic.ai/atlas_api.html
+
+ Returns:
+ AtlasDB: Nomic's neural database and finest rhizomatic instrument
+ """
+ if name is None or api_key is None:
+ raise ValueError("`name` and `api_key` cannot be None.")
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ return cls.from_texts(
+ name=name,
+ api_key=api_key,
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ description=description,
+ is_public=is_public,
+ reset_project_if_exists=reset_project_if_exists,
+ index_kwargs=index_kwargs,
+ )
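+
+
+# A minimal end-to-end sketch (illustrative; requires a real Nomic API key and,
+# for similarity_search, an embedding object such as OpenAIEmbeddings):
+#
+#     db = AtlasDB.from_texts(
+#         texts=["hello", "world"],
+#         embedding=embeddings,
+#         name="my_project",
+#         api_key="<NOMIC_API_KEY>",
+#     )
+#     docs = db.similarity_search("hello", k=1)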
diff --git a/libs/community/langchain_community/vectorstores/awadb.py b/libs/community/langchain_community/vectorstores/awadb.py
new file mode 100644
index 00000000000..96555558edd
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/awadb.py
@@ -0,0 +1,627 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ import awadb
+
+logger = logging.getLogger()
+DEFAULT_TOPN = 4
+
+
+class AwaDB(VectorStore):
+ """`AwaDB` vector store."""
+
+ _DEFAULT_TABLE_NAME = "langchain_awadb"
+
+ def __init__(
+ self,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ embedding: Optional[Embeddings] = None,
+ log_and_data_dir: Optional[str] = None,
+ client: Optional[awadb.Client] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize with AwaDB client.
+ If table_name is not specified,
+ a random table name of `_DEFAULT_TABLE_NAME + last segment of uuid`
+        will be created automatically.
+
+ Args:
+ table_name: Name of the table created, default _DEFAULT_TABLE_NAME.
+ embedding: Optional Embeddings initially set.
+ log_and_data_dir: Optional the root directory of log and data.
+ client: Optional AwaDB client.
+ kwargs: Any possible extend parameters in the future.
+
+ Returns:
+ None.
+ """
+ try:
+ import awadb
+ except ImportError:
+ raise ImportError(
+ "Could not import awadb python package. "
+ "Please install it with `pip install awadb`."
+ )
+
+ if client is not None:
+ self.awadb_client = client
+ else:
+ if log_and_data_dir is not None:
+ self.awadb_client = awadb.Client(log_and_data_dir)
+ else:
+ self.awadb_client = awadb.Client()
+
+ if table_name == self._DEFAULT_TABLE_NAME:
+ table_name += "_"
+ table_name += str(uuid.uuid4()).split("-")[-1]
+
+ self.awadb_client.Create(table_name)
+ self.table2embeddings: dict[str, Embeddings] = {}
+ if embedding is not None:
+ self.table2embeddings[table_name] = embedding
+ self.using_table_name = table_name
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ if self.using_table_name in self.table2embeddings:
+ return self.table2embeddings[self.using_table_name]
+ return None
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ is_duplicate_texts: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+            is_duplicate_texts: Optional; whether to add duplicate texts.
+                Defaults to True.
+ kwargs: any possible extend parameters in the future.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ embeddings = None
+ if self.using_table_name in self.table2embeddings:
+ embeddings = self.table2embeddings[self.using_table_name].embed_documents(
+ list(texts)
+ )
+
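+        # Note: "embedding_text" and "text_embedding" are the AwaDB field names
+        # used for the raw text and its vector; the same names are assumed in
+        # get() and similarity_search_by_vector() in this class.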
+ return self.awadb_client.AddTexts(
+ "embedding_text",
+ "text_embedding",
+ texts,
+ embeddings,
+ metadatas,
+ is_duplicate_texts,
+ )
+
+ def load_local(
+ self,
+ table_name: str,
+ **kwargs: Any,
+ ) -> bool:
+ """Load the local specified table.
+
+ Args:
+ table_name: Table name
+ kwargs: Any possible extend parameters in the future.
+
+ Returns:
+ Success or failure of loading the local specified table
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ return self.awadb_client.Load(table_name)
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = DEFAULT_TOPN,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text query.
+ k: The maximum number of documents to return.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter (Optional[dict]): Filter by metadata. Defaults to None.
+ E.g. `{"color" : "red", "price": 4.20}`. Optional.
+ E.g. `{"max_price" : 15.66, "min_price": 4.20}`
+ `price` is the metadata field, means range filter(4.20<'price'<15.66).
+ E.g. `{"maxe_price" : 15.66, "mine_price": 4.20}`
+ `price` is the metadata field, means range filter(4.20<='price'<=15.66).
+ kwargs: Any possible extend parameters in the future.
+
+ Returns:
+ Returns the k most similar documents to the specified text query.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ embedding = None
+ if self.using_table_name in self.table2embeddings:
+ embedding = self.table2embeddings[self.using_table_name].embed_query(query)
+ else:
+ from awadb import AwaEmbedding
+
+ embedding = AwaEmbedding().Embedding(query)
+
+ not_include_fields: Set[str] = {"text_embedding", "_id", "score"}
+ return self.similarity_search_by_vector(
+ embedding,
+ k,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ not_include_fields_in_metadata=not_include_fields,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = DEFAULT_TOPN,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """The most k similar documents and scores of the specified query.
+
+ Args:
+ query: Text query.
+ k: The k most similar documents to the text query.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter: Filter by metadata. Defaults to None.
+ kwargs: Any possible extend parameters in the future.
+
+ Returns:
+ The k most similar documents to the specified text query.
+ 0 is dissimilar, 1 is the most similar.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ embedding = None
+ if self.using_table_name in self.table2embeddings:
+ embedding = self.table2embeddings[self.using_table_name].embed_query(query)
+ else:
+ from awadb import AwaEmbedding
+
+ embedding = AwaEmbedding().Embedding(query)
+
+ results: List[Tuple[Document, float]] = []
+
+ not_include_fields: Set[str] = {"text_embedding", "_id"}
+ retrieval_docs = self.similarity_search_by_vector(
+ embedding,
+ k,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ not_include_fields_in_metadata=not_include_fields,
+ )
+
+ for doc in retrieval_docs:
+ score = doc.metadata["score"]
+ del doc.metadata["score"]
+ doc_tuple = (doc, score)
+ results.append(doc_tuple)
+
+ return results
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ return self.similarity_search_with_score(query, k, **kwargs)
+
+ def similarity_search_by_vector(
+ self,
+ embedding: Optional[List[float]] = None,
+ k: int = DEFAULT_TOPN,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ not_include_fields_in_metadata: Optional[Set[str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter: Filter by metadata. Defaults to None.
+            not_include_fields_in_metadata: Metadata fields to exclude
+                from each returned document.
+
+ Returns:
+ List of Documents which are the most similar to the query vector.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ results: List[Document] = []
+
+ if embedding is None:
+ return results
+
+ show_results = self.awadb_client.Search(
+ embedding,
+ k,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ not_include_fields=not_include_fields_in_metadata,
+ )
+
+        if len(show_results) == 0:
+ return results
+
+ for item_detail in show_results[0]["ResultItems"]:
+ content = ""
+ meta_data = {}
+ for item_key in item_detail:
+ if item_key == "embedding_text":
+ content = item_detail[item_key]
+ continue
+ elif not_include_fields_in_metadata is not None:
+ if item_key in not_include_fields_in_metadata:
+ continue
+ meta_data[item_key] = item_detail[item_key]
+ results.append(Document(page_content=content, metadata=meta_data))
+ return results
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter (Optional[dict]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ embedding: List[float] = []
+ if self.using_table_name in self.table2embeddings:
+ embedding = self.table2embeddings[self.using_table_name].embed_query(query)
+ else:
+ from awadb import AwaEmbedding
+
+ embedding = AwaEmbedding().Embedding(query)
+
+        if len(embedding) == 0:
+ return []
+
+ results = self.max_marginal_relevance_search_by_vector(
+ embedding,
+ k,
+ fetch_k,
+ lambda_mult=lambda_mult,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ )
+ return results
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter (Optional[dict]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ results: List[Document] = []
+
+ if embedding is None:
+ return results
+
+ not_include_fields: set = {"_id", "score"}
+ retrieved_docs = self.similarity_search_by_vector(
+ embedding,
+ fetch_k,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ not_include_fields_in_metadata=not_include_fields,
+ )
+
+ top_embeddings = []
+
+ for doc in retrieved_docs:
+ top_embeddings.append(doc.metadata["text_embedding"])
+
+ selected_docs = maximal_marginal_relevance(
+ np.array(embedding, dtype=np.float32), embedding_list=top_embeddings
+ )
+
+ for s_id in selected_docs:
+ if "text_embedding" in retrieved_docs[s_id].metadata:
+ del retrieved_docs[s_id].metadata["text_embedding"]
+ results.append(retrieved_docs[s_id])
+ return results
+
+ def get(
+ self,
+ ids: Optional[List[str]] = None,
+ text_in_page_content: Optional[str] = None,
+ meta_filter: Optional[dict] = None,
+ not_include_fields: Optional[Set[str]] = None,
+ limit: Optional[int] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Document]:
+ """Return docs according ids.
+
+ Args:
+ ids: The ids of the embedding vectors.
+ text_in_page_content: Filter by the text in page_content of Document.
+ meta_filter: Filter by any metadata of the document.
+            not_include_fields: Fields to exclude from each returned document.
+ limit: The number of documents to return. Defaults to 5. Optional.
+
+ Returns:
+ Documents which satisfy the input conditions.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ docs_detail = self.awadb_client.Get(
+ ids=ids,
+ text_in_page_content=text_in_page_content,
+ meta_filter=meta_filter,
+ not_include_fields=not_include_fields,
+ limit=limit,
+ )
+
+ results: Dict[str, Document] = {}
+ for doc_detail in docs_detail:
+ content = ""
+ meta_info = {}
+ for field in doc_detail:
+ if field == "embedding_text":
+ content = doc_detail[field]
+ continue
+ elif field == "text_embedding" or field == "_id":
+ continue
+
+ meta_info[field] = doc_detail[field]
+
+ doc = Document(page_content=content, metadata=meta_info)
+ results[doc_detail["_id"]] = doc
+ return results
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete the documents which have the specified ids.
+
+ Args:
+ ids: The ids of the embedding vectors.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful.
+ False otherwise, None if not implemented.
+ """
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+ ret: Optional[bool] = None
+        if ids is None or len(ids) == 0:
+ return ret
+ ret = self.awadb_client.Delete(ids)
+ return ret
+
+ def update(
+ self,
+ ids: List[str],
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Update the documents which have the specified ids.
+
+ Args:
+            ids: The ids of the documents to update.
+            texts: The new texts of the documents.
+            metadatas: The new metadatas of the documents.
+
+        Returns:
+            The ids of the updated documents.
+ """
+
+ if self.awadb_client is None:
+ raise ValueError("AwaDB client is None!!!")
+
+ return self.awadb_client.UpdateTexts(
+ ids=ids, text_field_name="embedding_text", texts=texts, metadatas=metadatas
+ )
+
+ def create_table(
+ self,
+ table_name: str,
+ **kwargs: Any,
+ ) -> bool:
+ """Create a new table."""
+
+ if self.awadb_client is None:
+ return False
+
+ ret = self.awadb_client.Create(table_name)
+
+ if ret:
+ self.using_table_name = table_name
+ return ret
+
+ def use(
+ self,
+ table_name: str,
+ **kwargs: Any,
+ ) -> bool:
+ """Use the specified table. Don't know the tables, please invoke list_tables."""
+
+ if self.awadb_client is None:
+ return False
+
+ ret = self.awadb_client.Use(table_name)
+ if ret:
+ self.using_table_name = table_name
+
+ return ret
+
+ def list_tables(
+ self,
+ **kwargs: Any,
+ ) -> List[str]:
+ """List all the tables created by the client."""
+
+ if self.awadb_client is None:
+ return []
+
+ return self.awadb_client.ListAllTables()
+
+ def get_current_table(
+ self,
+ **kwargs: Any,
+ ) -> str:
+ """Get the current table."""
+
+ return self.using_table_name
+
+ @classmethod
+ def from_texts(
+ cls: Type[AwaDB],
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ log_and_data_dir: Optional[str] = None,
+ client: Optional[awadb.Client] = None,
+ **kwargs: Any,
+ ) -> AwaDB:
+ """Create an AwaDB vectorstore from a raw documents.
+
+ Args:
+ texts (List[str]): List of texts to add to the table.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
+ table_name (str): Name of the table to create.
+ log_and_data_dir (Optional[str]): Directory of logging and persistence.
+ client (Optional[awadb.Client]): AwaDB client
+
+ Returns:
+ AwaDB: AwaDB vectorstore.
+ """
+ awadb_client = cls(
+ table_name=table_name,
+ embedding=embedding,
+ log_and_data_dir=log_and_data_dir,
+ client=client,
+ )
+ awadb_client.add_texts(texts=texts, metadatas=metadatas)
+ return awadb_client
+
+ @classmethod
+ def from_documents(
+ cls: Type[AwaDB],
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ log_and_data_dir: Optional[str] = None,
+ client: Optional[awadb.Client] = None,
+ **kwargs: Any,
+ ) -> AwaDB:
+ """Create an AwaDB vectorstore from a list of documents.
+
+        If a log_and_data_dir is specified, the table will be persisted there.
+
+ Args:
+ documents (List[Document]): List of documents to add to the vectorstore.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ table_name (str): Name of the table to create.
+ log_and_data_dir (Optional[str]): Directory to persist the table.
+ client (Optional[awadb.Client]): AwaDB client.
+            **kwargs: Additional keyword arguments reserved for future use.
+
+ Returns:
+ AwaDB: AwaDB vectorstore.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ table_name=table_name,
+ log_and_data_dir=log_and_data_dir,
+ client=client,
+ )
diff --git a/libs/community/langchain_community/vectorstores/azure_cosmos_db.py b/libs/community/langchain_community/vectorstores/azure_cosmos_db.py
new file mode 100644
index 00000000000..e65ea1d2806
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/azure_cosmos_db.py
@@ -0,0 +1,425 @@
+from __future__ import annotations
+
+import logging
+from enum import Enum
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ from langchain_core.embeddings import Embeddings
+ from pymongo.collection import Collection
+
+
+# Before Python 3.11 native StrEnum is not available
+class CosmosDBSimilarityType(str, Enum):
+ """Cosmos DB Similarity Type as enumerator."""
+
+ COS = "COS"
+ """CosineSimilarity"""
+ IP = "IP"
+ """inner - product"""
+ L2 = "L2"
+ """Euclidean distance"""
+
+
+CosmosDBDocumentType = TypeVar("CosmosDBDocumentType", bound=Dict[str, Any])
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_INSERT_BATCH_SIZE = 128
+
+
+class AzureCosmosDBVectorSearch(VectorStore):
+ """`Azure Cosmos DB for MongoDB vCore` vector store.
+
+ To use, you should have both:
+ - the ``pymongo`` python package installed
+ - a connection string associated with a MongoDB VCore Cluster
+
+ Example:
+        .. code-block:: python
+
+ from langchain_community.vectorstores import AzureCosmosDBVectorSearch
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from pymongo import MongoClient
+
+ mongo_client = MongoClient("")
+ collection = mongo_client[""][""]
+ embeddings = OpenAIEmbeddings()
+ vectorstore = AzureCosmosDBVectorSearch(collection, embeddings)
+ """
+
+ def __init__(
+ self,
+ collection: Collection[CosmosDBDocumentType],
+ embedding: Embeddings,
+ *,
+ index_name: str = "vectorSearchIndex",
+ text_key: str = "textContent",
+ embedding_key: str = "vectorContent",
+ ):
+ """Constructor for AzureCosmosDBVectorSearch
+
+ Args:
+ collection: MongoDB collection to add the texts to.
+ embedding: Text embedding model to use.
+            index_name: Name of the vector search index.
+ text_key: MongoDB field that will contain the text
+ for each document.
+ embedding_key: MongoDB field that will contain the embedding
+ for each document.
+ """
+ self._collection = collection
+ self._embedding = embedding
+ self._index_name = index_name
+ self._text_key = text_key
+ self._embedding_key = embedding_key
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def get_index_name(self) -> str:
+ """Returns the index name
+
+ Returns:
+ Returns the index name
+
+ """
+ return self._index_name
+
+ @classmethod
+ def from_connection_string(
+ cls,
+ connection_string: str,
+ namespace: str,
+ embedding: Embeddings,
+ **kwargs: Any,
+ ) -> AzureCosmosDBVectorSearch:
+ """Creates an Instance of AzureCosmosDBVectorSearch from a Connection String
+
+ Args:
+ connection_string: The MongoDB vCore instance connection string
+ namespace: The namespace (database.collection)
+ embedding: The embedding utility
+ **kwargs: Dynamic keyword arguments
+
+ Returns:
+ an instance of the vector store
+
+ """
+ try:
+ from pymongo import MongoClient
+ except ImportError:
+ raise ImportError(
+ "Could not import pymongo, please install it with "
+ "`pip install pymongo`."
+ )
+ client: MongoClient = MongoClient(connection_string)
+ db_name, collection_name = namespace.split(".")
+ collection = client[db_name][collection_name]
+ return cls(collection, embedding, **kwargs)
+
+ def index_exists(self) -> bool:
+ """Verifies if the specified index name during instance
+ construction exists on the collection
+
+ Returns:
+ Returns True on success and False if no such index exists
+ on the collection
+ """
+ cursor = self._collection.list_indexes()
+ index_name = self._index_name
+
+ for res in cursor:
+ current_index_name = res.pop("name")
+ if current_index_name == index_name:
+ return True
+
+ return False
+
+ def delete_index(self) -> None:
+ """Deletes the index specified during instance construction if it exists"""
+ if self.index_exists():
+ self._collection.drop_index(self._index_name)
+ # Raises OperationFailure on an error (e.g. trying to drop
+ # an index that does not exist)
+
+ def create_index(
+ self,
+ num_lists: int = 100,
+ dimensions: int = 1536,
+ similarity: CosmosDBSimilarityType = CosmosDBSimilarityType.COS,
+ ) -> dict[str, Any]:
+ """Creates an index using the index name specified at
+ instance construction
+
+ Setting the numLists parameter correctly is important for achieving
+ good accuracy and performance.
+ Since the vector store uses IVF as the indexing strategy,
+ you should create the index only after you
+ have loaded a large enough sample documents to ensure that the
+ centroids for the respective buckets are
+ faily distributed.
+
+ We recommend that numLists is set to documentCount/1000 for up
+ to 1 million documents
+ and to sqrt(documentCount) for more than 1 million documents.
+ As the number of items in your database grows, you should
+ tune numLists to be larger
+ in order to achieve good latency performance for vector search.
+
+ If you're experimenting with a new scenario or creating a
+ small demo, you can start with numLists
+ set to 1 to perform a brute-force search across all vectors.
+ This should provide you with the most
+ accurate results from the vector search, however be aware that
+ the search speed and latency will be slow.
+ After your initial setup, you should go ahead and tune
+ the numLists parameter using the above guidance.
+
+ Args:
+ num_lists: This integer is the number of clusters that the
+ inverted file (IVF) index uses to group the vector data.
+ We recommend that numLists is set to documentCount/1000
+ for up to 1 million documents and to sqrt(documentCount)
+ for more than 1 million documents.
+ Using a numLists value of 1 is akin to performing
+ brute-force search, which has limited performance
+ dimensions: Number of dimensions for vector similarity.
+ The maximum number of supported dimensions is 2000
+ similarity: Similarity metric to use with the IVF index.
+
+ Possible options are:
+ - CosmosDBSimilarityType.COS (cosine distance),
+ - CosmosDBSimilarityType.L2 (Euclidean distance), and
+ - CosmosDBSimilarityType.IP (inner product).
+
+ Returns:
+ An object describing the created index
+
+ """
+ # prepare the command
+ create_index_commands = {
+ "createIndexes": self._collection.name,
+ "indexes": [
+ {
+ "name": self._index_name,
+ "key": {self._embedding_key: "cosmosSearch"},
+ "cosmosSearchOptions": {
+ "kind": "vector-ivf",
+ "numLists": num_lists,
+ "similarity": similarity,
+ "dimensions": dimensions,
+ },
+ }
+ ],
+ }
+
+ # retrieve the database object
+ current_database = self._collection.database
+
+ # invoke the command from the database object
+ create_index_responses: dict[str, Any] = current_database.command(
+ create_index_commands
+ )
+
+ return create_index_responses
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ **kwargs: Any,
+ ) -> List:
+ batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
+ _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
+ texts_batch = []
+ metadatas_batch = []
+ result_ids = []
+ for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
+ texts_batch.append(text)
+ metadatas_batch.append(metadata)
+ if (i + 1) % batch_size == 0:
+ result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
+ texts_batch = []
+ metadatas_batch = []
+ if texts_batch:
+ result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
+ return result_ids
+
+ def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
+ """Used to Load Documents into the collection
+
+ Args:
+ texts: The list of documents strings to load
+ metadatas: The list of metadata objects associated with each document
+
+ Returns:
+
+ """
+ # If the text is empty, then exit early
+ if not texts:
+ return []
+
+ # Embed and create the documents
+ embeddings = self._embedding.embed_documents(texts)
+ to_insert = [
+ {self._text_key: t, self._embedding_key: embedding, **m}
+ for t, m, embedding in zip(texts, metadatas, embeddings)
+ ]
+ # insert the documents in Cosmos DB
+ insert_result = self._collection.insert_many(to_insert) # type: ignore
+ return insert_result.inserted_ids
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection: Optional[Collection[CosmosDBDocumentType]] = None,
+ **kwargs: Any,
+ ) -> AzureCosmosDBVectorSearch:
+ if collection is None:
+ raise ValueError("Must provide 'collection' named parameter.")
+ vectorstore = cls(collection, embedding, **kwargs)
+ vectorstore.add_texts(texts, metadatas=metadatas)
+ return vectorstore
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ if ids is None:
+ raise ValueError("No document ids provided to delete.")
+
+ for document_id in ids:
+ self.delete_document_by_id(document_id)
+ return True
+
+ def delete_document_by_id(self, document_id: Optional[str] = None) -> None:
+ """Removes a Specific Document by Id
+
+ Args:
+ document_id: The document identifier
+ """
+ try:
+ from bson.objectid import ObjectId
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import bson, please install with `pip install bson`."
+ ) from e
+ if document_id is None:
+ raise ValueError("No document id provided to delete.")
+
+ self._collection.delete_one({"_id": ObjectId(document_id)})
+
+ def _similarity_search_with_score(
+ self, embeddings: List[float], k: int = 4
+ ) -> List[Tuple[Document, float]]:
+ """Returns a list of documents with their scores
+
+ Args:
+ embeddings: The query vector
+            k: The number of documents to return
+
+        Returns:
+            A list of (document, score) tuples closest to the query vector
+ """
+ pipeline: List[dict[str, Any]] = [
+ {
+ "$search": {
+ "cosmosSearch": {
+ "vector": embeddings,
+ "path": self._embedding_key,
+ "k": k,
+ },
+ "returnStoredSource": True,
+ }
+ },
+ {
+ "$project": {
+ "similarityScore": {"$meta": "searchScore"},
+ "document": "$$ROOT",
+ }
+ },
+ ]
+
+ cursor = self._collection.aggregate(pipeline)
+
+ docs = []
+
+ for res in cursor:
+ score = res.pop("similarityScore")
+ document_object_field = res.pop("document")
+ text = document_object_field.pop(self._text_key)
+ docs.append(
+ (Document(page_content=text, metadata=document_object_field), score)
+ )
+
+ return docs
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4
+ ) -> List[Tuple[Document, float]]:
+ embeddings = self._embedding.embed_query(query)
+ docs = self._similarity_search_with_score(embeddings=embeddings, k=k)
+ return docs
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ docs_and_scores = self.similarity_search_with_score(query, k=k)
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ # Retrieves the docs with similarity scores
+ # sorted by similarity scores in DESC order
+ docs = self._similarity_search_with_score(embedding, k=fetch_k)
+
+ # Re-ranks the docs using MMR
+ mmr_doc_indexes = maximal_marginal_relevance(
+ np.array(embedding),
+ [doc.metadata[self._embedding_key] for doc, _ in docs],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
+ return mmr_docs
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ # compute the embeddings vector from the query string
+ embeddings = self._embedding.embed_query(query)
+
+ docs = self.max_marginal_relevance_search_by_vector(
+ embeddings, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult
+ )
+ return docs
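
A usage sketch of the Cosmos DB vCore store above, including the numLists
heuristic from the ``create_index`` docstring. The connection string,
namespace, and document count are placeholders, and ``OpenAIEmbeddings``
assumes an OpenAI API key is configured.

.. code-block:: python

    from langchain_community.embeddings.openai import OpenAIEmbeddings
    from langchain_community.vectorstores.azure_cosmos_db import (
        AzureCosmosDBVectorSearch,
        CosmosDBSimilarityType,
    )

    vectorstore = AzureCosmosDBVectorSearch.from_connection_string(
        connection_string="<mongodb-vcore-connection-string>",
        namespace="<db_name>.<collection_name>",
        embedding=OpenAIEmbeddings(),
    )
    vectorstore.add_texts(
        ["Cosmos DB vCore supports vector search"], metadatas=[{"source": "docs"}]
    )

    # numLists heuristic from the create_index docstring; the document count
    # is made up for illustration.
    doc_count = 250_000
    num_lists = doc_count // 1000 if doc_count <= 1_000_000 else int(doc_count**0.5)
    if not vectorstore.index_exists():
        vectorstore.create_index(
            num_lists=num_lists,
            dimensions=1536,  # default; matches OpenAI ada-002 embeddings
            similarity=CosmosDBSimilarityType.COS,
        )

    docs = vectorstore.similarity_search("vector search", k=2)
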
diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py
new file mode 100644
index 00000000000..0b449688bd8
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/azuresearch.py
@@ -0,0 +1,739 @@
+from __future__ import annotations
+
+import base64
+import json
+import logging
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+import numpy as np
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForRetrieverRun,
+ CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from azure.search.documents import SearchClient
+ from azure.search.documents.indexes.models import (
+ CorsOptions,
+ ScoringProfile,
+ SearchField,
+ VectorSearch,
+ )
+
+ try:
+ from azure.search.documents.indexes.models import SemanticSearch
+ except ImportError:
+ from azure.search.documents.indexes.models import SemanticSettings # <11.4.0
+
+# Allow overriding field names for Azure Search
+FIELDS_ID = get_from_env(
+ key="AZURESEARCH_FIELDS_ID", env_key="AZURESEARCH_FIELDS_ID", default="id"
+)
+FIELDS_CONTENT = get_from_env(
+ key="AZURESEARCH_FIELDS_CONTENT",
+ env_key="AZURESEARCH_FIELDS_CONTENT",
+ default="content",
+)
+FIELDS_CONTENT_VECTOR = get_from_env(
+ key="AZURESEARCH_FIELDS_CONTENT_VECTOR",
+ env_key="AZURESEARCH_FIELDS_CONTENT_VECTOR",
+ default="content_vector",
+)
+FIELDS_METADATA = get_from_env(
+ key="AZURESEARCH_FIELDS_TAG", env_key="AZURESEARCH_FIELDS_TAG", default="metadata"
+)
+
+MAX_UPLOAD_BATCH_SIZE = 1000
+
+
+def _get_search_client(
+ endpoint: str,
+ key: str,
+ index_name: str,
+ semantic_configuration_name: Optional[str] = None,
+ fields: Optional[List[SearchField]] = None,
+ vector_search: Optional[VectorSearch] = None,
+ semantic_settings: Optional[Union[SemanticSearch, SemanticSettings]] = None,
+ scoring_profiles: Optional[List[ScoringProfile]] = None,
+ default_scoring_profile: Optional[str] = None,
+ default_fields: Optional[List[SearchField]] = None,
+ user_agent: Optional[str] = "langchain",
+ cors_options: Optional[CorsOptions] = None,
+) -> SearchClient:
+ from azure.core.credentials import AzureKeyCredential
+ from azure.core.exceptions import ResourceNotFoundError
+ from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
+ from azure.search.documents import SearchClient
+ from azure.search.documents.indexes import SearchIndexClient
+ from azure.search.documents.indexes.models import (
+ SearchIndex,
+ SemanticConfiguration,
+ SemanticField,
+ VectorSearch,
+ )
+
+ # class names changed for versions >= 11.4.0
+ try:
+ from azure.search.documents.indexes.models import (
+ HnswAlgorithmConfiguration, # HnswVectorSearchAlgorithmConfiguration is old
+ SemanticPrioritizedFields, # PrioritizedFields outdated
+ SemanticSearch, # SemanticSettings outdated
+ )
+
+ NEW_VERSION = True
+ except ImportError:
+ from azure.search.documents.indexes.models import (
+ HnswVectorSearchAlgorithmConfiguration,
+ PrioritizedFields,
+ SemanticSettings,
+ )
+
+ NEW_VERSION = False
+
+ default_fields = default_fields or []
+ if key is None:
+ credential = DefaultAzureCredential()
+ elif key.upper() == "INTERACTIVE":
+ credential = InteractiveBrowserCredential()
+ credential.get_token("https://search.azure.com/.default")
+ else:
+ credential = AzureKeyCredential(key)
+ index_client: SearchIndexClient = SearchIndexClient(
+ endpoint=endpoint, credential=credential, user_agent=user_agent
+ )
+ try:
+ index_client.get_index(name=index_name)
+ except ResourceNotFoundError:
+ # Fields configuration
+ if fields is not None:
+ # Check mandatory fields
+ fields_types = {f.name: f.type for f in fields}
+ mandatory_fields = {df.name: df.type for df in default_fields}
+ # Check for missing keys
+ missing_fields = {
+ key: mandatory_fields[key]
+ for key, value in set(mandatory_fields.items())
+ - set(fields_types.items())
+ }
+ if len(missing_fields) > 0:
+ # Helper for formatting field information for each missing field.
+ def fmt_err(x: str) -> str:
+ return (
+ f"{x} current type: '{fields_types.get(x, 'MISSING')}'. "
+ f"It has to be '{mandatory_fields.get(x)}' or you can point "
+ f"to a different '{mandatory_fields.get(x)}' field name by "
+ f"using the env variable 'AZURESEARCH_FIELDS_{x.upper()}'"
+ )
+
+ error = "\n".join([fmt_err(x) for x in missing_fields])
+ raise ValueError(
+ f"You need to specify at least the following fields "
+ f"{missing_fields} or provide alternative field names in the env "
+ f"variables.\n\n{error}"
+ )
+ else:
+ fields = default_fields
+ # Vector search configuration
+ if vector_search is None:
+ if NEW_VERSION:
+ # >= 11.4.0:
+ # VectorSearch(algorithm_configuration) --> VectorSearch(algorithms)
+ # HnswVectorSearchAlgorithmConfiguration --> HnswAlgorithmConfiguration
+ vector_search = VectorSearch(
+ algorithms=[
+ HnswAlgorithmConfiguration(
+ name="default",
+ kind="hnsw",
+ parameters={ # type: ignore
+ "m": 4,
+ "efConstruction": 400,
+ "efSearch": 500,
+ "metric": "cosine",
+ },
+ )
+ ]
+ )
+ else: # < 11.4.0
+ vector_search = VectorSearch(
+ algorithm_configurations=[
+ HnswVectorSearchAlgorithmConfiguration(
+ name="default",
+ kind="hnsw",
+ parameters={ # type: ignore
+ "m": 4,
+ "efConstruction": 400,
+ "efSearch": 500,
+ "metric": "cosine",
+ },
+ )
+ ]
+ )
+
+ # Create the semantic settings with the configuration
+ if semantic_settings is None and semantic_configuration_name is not None:
+ if NEW_VERSION:
+                # >= 11.4.0: SemanticSettings --> SemanticSearch
+ # PrioritizedFields(prioritized_content_fields)
+ # --> SemanticPrioritizedFields(content_fields)
+ semantic_settings = SemanticSearch(
+ configurations=[
+ SemanticConfiguration(
+ name=semantic_configuration_name,
+ prioritized_fields=SemanticPrioritizedFields(
+ content_fields=[
+ SemanticField(field_name=FIELDS_CONTENT)
+ ],
+ ),
+ )
+ ]
+ )
+ else: # < 11.4.0
+ semantic_settings = SemanticSettings(
+ configurations=[
+ SemanticConfiguration(
+ name=semantic_configuration_name,
+ prioritized_fields=PrioritizedFields(
+ prioritized_content_fields=[
+ SemanticField(field_name=FIELDS_CONTENT)
+ ],
+ ),
+ )
+ ]
+ )
+ # Create the search index with the semantic settings and vector search
+ index = SearchIndex(
+ name=index_name,
+ fields=fields,
+ vector_search=vector_search,
+ semantic_settings=semantic_settings,
+ scoring_profiles=scoring_profiles,
+ default_scoring_profile=default_scoring_profile,
+ cors_options=cors_options,
+ )
+ index_client.create_index(index)
+ # Create the search client
+ return SearchClient(
+ endpoint=endpoint,
+ index_name=index_name,
+ credential=credential,
+ user_agent=user_agent,
+ )
+
+
+class AzureSearch(VectorStore):
+ """`Azure Cognitive Search` vector store."""
+
+ def __init__(
+ self,
+ azure_search_endpoint: str,
+ azure_search_key: str,
+ index_name: str,
+ embedding_function: Callable,
+ search_type: str = "hybrid",
+ semantic_configuration_name: Optional[str] = None,
+ semantic_query_language: str = "en-us",
+ fields: Optional[List[SearchField]] = None,
+ vector_search: Optional[VectorSearch] = None,
+ semantic_settings: Optional[Union[SemanticSearch, SemanticSettings]] = None,
+ scoring_profiles: Optional[List[ScoringProfile]] = None,
+ default_scoring_profile: Optional[str] = None,
+ cors_options: Optional[CorsOptions] = None,
+ **kwargs: Any,
+ ):
+        """Initialize with necessary components."""
+        from azure.search.documents.indexes.models import (
+            SearchableField,
+            SearchField,
+            SearchFieldDataType,
+            SimpleField,
+        )
+
+        # Initialize base class
+ self.embedding_function = embedding_function
+ default_fields = [
+ SimpleField(
+ name=FIELDS_ID,
+ type=SearchFieldDataType.String,
+ key=True,
+ filterable=True,
+ ),
+ SearchableField(
+ name=FIELDS_CONTENT,
+ type=SearchFieldDataType.String,
+ ),
+ SearchField(
+ name=FIELDS_CONTENT_VECTOR,
+ type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+ searchable=True,
+ vector_search_dimensions=len(embedding_function("Text")),
+ vector_search_configuration="default",
+ ),
+ SearchableField(
+ name=FIELDS_METADATA,
+ type=SearchFieldDataType.String,
+ ),
+ ]
+ user_agent = "langchain"
+ if "user_agent" in kwargs and kwargs["user_agent"]:
+ user_agent += " " + kwargs["user_agent"]
+ self.client = _get_search_client(
+ azure_search_endpoint,
+ azure_search_key,
+ index_name,
+ semantic_configuration_name=semantic_configuration_name,
+ fields=fields,
+ vector_search=vector_search,
+ semantic_settings=semantic_settings,
+ scoring_profiles=scoring_profiles,
+ default_scoring_profile=default_scoring_profile,
+ default_fields=default_fields,
+ user_agent=user_agent,
+ cors_options=cors_options,
+ )
+ self.search_type = search_type
+ self.semantic_configuration_name = semantic_configuration_name
+ self.semantic_query_language = semantic_query_language
+ self.fields = fields if fields else default_fields
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ # TODO: Support embedding object directly
+ return None
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts data to an existing index."""
+ keys = kwargs.get("keys")
+ ids = []
+ # Write data to index
+ data = []
+ for i, text in enumerate(texts):
+ # Use provided key otherwise use default key
+ key = keys[i] if keys else str(uuid.uuid4())
+ # Encoding key for Azure Search valid characters
+ key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
+ metadata = metadatas[i] if metadatas else {}
+ # Add data to index
+ # Additional metadata to fields mapping
+ doc = {
+ "@search.action": "upload",
+ FIELDS_ID: key,
+ FIELDS_CONTENT: text,
+ FIELDS_CONTENT_VECTOR: np.array(
+ self.embedding_function(text), dtype=np.float32
+ ).tolist(),
+ FIELDS_METADATA: json.dumps(metadata),
+ }
+ if metadata:
+ additional_fields = {
+ k: v
+ for k, v in metadata.items()
+ if k in [x.name for x in self.fields]
+ }
+ doc.update(additional_fields)
+ data.append(doc)
+ ids.append(key)
+ # Upload data in batches
+ if len(data) == MAX_UPLOAD_BATCH_SIZE:
+ response = self.client.upload_documents(documents=data)
+ # Check if all documents were successfully uploaded
+ if not all([r.succeeded for r in response]):
+ raise Exception(response)
+ # Reset data
+ data = []
+
+ # Considering case where data is an exact multiple of batch-size entries
+ if len(data) == 0:
+ return ids
+
+ # Upload data to index
+ response = self.client.upload_documents(documents=data)
+ # Check if all documents were successfully uploaded
+ if all([r.succeeded for r in response]):
+ return ids
+ else:
+ raise Exception(response)
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ search_type = kwargs.get("search_type", self.search_type)
+ if search_type == "similarity":
+ docs = self.vector_search(query, k=k, **kwargs)
+ elif search_type == "hybrid":
+ docs = self.hybrid_search(query, k=k, **kwargs)
+ elif search_type == "semantic_hybrid":
+ docs = self.semantic_hybrid_search(query, k=k, **kwargs)
+ else:
+ raise ValueError(f"search_type of {search_type} not allowed.")
+ return docs
+
+ def similarity_search_with_relevance_scores(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ score_threshold = kwargs.pop("score_threshold", None)
+ result = self.vector_search_with_score(query, k=k, **kwargs)
+ return (
+ result
+ if score_threshold is None
+ else [r for r in result if r[1] >= score_threshold]
+ )
+
+ def vector_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
+ """
+ Returns the most similar indexed documents to the query text.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text.
+ """
+ docs_and_scores = self.vector_search_with_score(
+ query, k=k, filters=kwargs.get("filters", None)
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def vector_search_with_score(
+ self, query: str, k: int = 4, filters: Optional[str] = None
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ from azure.search.documents.models import Vector
+
+ results = self.client.search(
+ search_text="",
+ vectors=[
+ Vector(
+ value=np.array(
+ self.embedding_function(query), dtype=np.float32
+ ).tolist(),
+ k=k,
+ fields=FIELDS_CONTENT_VECTOR,
+ )
+ ],
+ filter=filters,
+ )
+ # Convert results to Document objects
+ docs = [
+ (
+ Document(
+ page_content=result.pop(FIELDS_CONTENT),
+ metadata={
+ **(
+ {FIELDS_ID: result.pop(FIELDS_ID)}
+ if FIELDS_ID in result
+ else {}
+ ),
+ **(
+ json.loads(result[FIELDS_METADATA])
+ if FIELDS_METADATA in result
+ else {
+ k: v
+ for k, v in result.items()
+ if k != FIELDS_CONTENT_VECTOR
+ }
+ ),
+ },
+ ),
+ float(result["@search.score"]),
+ )
+ for result in results
+ ]
+ return docs
+
+ def hybrid_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
+ """
+ Returns the most similar indexed documents to the query text.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text.
+ """
+ docs_and_scores = self.hybrid_search_with_score(
+ query, k=k, filters=kwargs.get("filters", None)
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def hybrid_search_with_score(
+ self, query: str, k: int = 4, filters: Optional[str] = None
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query with an hybrid query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ from azure.search.documents.models import Vector
+
+ results = self.client.search(
+ search_text=query,
+ vectors=[
+ Vector(
+ value=np.array(
+ self.embedding_function(query), dtype=np.float32
+ ).tolist(),
+ k=k,
+ fields=FIELDS_CONTENT_VECTOR,
+ )
+ ],
+ filter=filters,
+ top=k,
+ )
+ # Convert results to Document objects
+ docs = [
+ (
+ Document(
+ page_content=result.pop(FIELDS_CONTENT),
+ metadata={
+ **(
+ {FIELDS_ID: result.pop(FIELDS_ID)}
+ if FIELDS_ID in result
+ else {}
+ ),
+ **(
+ json.loads(result[FIELDS_METADATA])
+ if FIELDS_METADATA in result
+ else {
+ k: v
+ for k, v in result.items()
+ if k != FIELDS_CONTENT_VECTOR
+ }
+ ),
+ },
+ ),
+ float(result["@search.score"]),
+ )
+ for result in results
+ ]
+ return docs
+
+ def semantic_hybrid_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """
+ Returns the most similar indexed documents to the query text.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text.
+ """
+ docs_and_scores = self.semantic_hybrid_search_with_score_and_rerank(
+ query, k=k, filters=kwargs.get("filters", None)
+ )
+ return [doc for doc, _, _ in docs_and_scores]
+
+ def semantic_hybrid_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """
+ Returns the most similar indexed documents to the query text.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+
+ Returns:
+            List[Tuple[Document, float]]: A list of (document, score) tuples
+                that are most similar to the query text.
+ """
+ docs_and_scores = self.semantic_hybrid_search_with_score_and_rerank(
+ query, k=k, filters=kwargs.get("filters", None)
+ )
+ return [(doc, score) for doc, score, _ in docs_and_scores]
+
+ def semantic_hybrid_search_with_score_and_rerank(
+ self, query: str, k: int = 4, filters: Optional[str] = None
+ ) -> List[Tuple[Document, float, float]]:
+ """Return docs most similar to query with an hybrid query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+            List of Documents most similar to the query, with the similarity
+            score and the semantic reranker score for each
+ """
+ from azure.search.documents.models import Vector
+
+ results = self.client.search(
+ search_text=query,
+ vectors=[
+ Vector(
+ value=np.array(
+ self.embedding_function(query), dtype=np.float32
+ ).tolist(),
+ k=50,
+ fields=FIELDS_CONTENT_VECTOR,
+ )
+ ],
+ filter=filters,
+ query_type="semantic",
+ query_language=self.semantic_query_language,
+ semantic_configuration_name=self.semantic_configuration_name,
+ query_caption="extractive",
+ query_answer="extractive",
+ top=k,
+ )
+ # Get Semantic Answers
+ semantic_answers = results.get_answers() or []
+ semantic_answers_dict: Dict = {}
+ for semantic_answer in semantic_answers:
+ semantic_answers_dict[semantic_answer.key] = {
+ "text": semantic_answer.text,
+ "highlights": semantic_answer.highlights,
+ }
+ # Convert results to Document objects
+ docs = [
+ (
+ Document(
+ page_content=result.pop(FIELDS_CONTENT),
+ metadata={
+ **(
+ {FIELDS_ID: result.pop(FIELDS_ID)}
+ if FIELDS_ID in result
+ else {}
+ ),
+ **(
+ json.loads(result[FIELDS_METADATA])
+ if FIELDS_METADATA in result
+ else {
+ k: v
+ for k, v in result.items()
+ if k != FIELDS_CONTENT_VECTOR
+ }
+ ),
+ **{
+ "captions": {
+ "text": result.get("@search.captions", [{}])[0].text,
+ "highlights": result.get("@search.captions", [{}])[
+ 0
+ ].highlights,
+ }
+ if result.get("@search.captions")
+ else {},
+ "answers": semantic_answers_dict.get(
+ json.loads(result["metadata"]).get("key"), ""
+ ),
+ },
+ },
+ ),
+ float(result["@search.score"]),
+ float(result["@search.reranker_score"]),
+ )
+ for result in results
+ ]
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls: Type[AzureSearch],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ azure_search_endpoint: str = "",
+ azure_search_key: str = "",
+ index_name: str = "langchain-index",
+ **kwargs: Any,
+ ) -> AzureSearch:
+ # Creating a new Azure Search instance
+ azure_search = cls(
+ azure_search_endpoint,
+ azure_search_key,
+ index_name,
+ embedding.embed_query,
+ )
+ azure_search.add_texts(texts, metadatas, **kwargs)
+ return azure_search
+
+
+class AzureSearchVectorStoreRetriever(BaseRetriever):
+ """Retriever that uses `Azure Cognitive Search`."""
+
+ vectorstore: AzureSearch
+ """Azure Search instance used to find similar documents."""
+ search_type: str = "hybrid"
+ """Type of search to perform. Options are "similarity", "hybrid",
+ "semantic_hybrid"."""
+ k: int = 4
+ """Number of documents to return."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ @root_validator()
+ def validate_search_type(cls, values: Dict) -> Dict:
+ """Validate search type."""
+ if "search_type" in values:
+ search_type = values["search_type"]
+ if search_type not in ("similarity", "hybrid", "semantic_hybrid"):
+ raise ValueError(f"search_type of {search_type} not allowed.")
+ return values
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ run_manager: CallbackManagerForRetrieverRun,
+ **kwargs: Any,
+ ) -> List[Document]:
+ if self.search_type == "similarity":
+ docs = self.vectorstore.vector_search(query, k=self.k, **kwargs)
+ elif self.search_type == "hybrid":
+ docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs)
+ elif self.search_type == "semantic_hybrid":
+ docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs)
+ else:
+ raise ValueError(f"search_type of {self.search_type} not allowed.")
+ return docs
+
+ async def _aget_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: AsyncCallbackManagerForRetrieverRun,
+ ) -> List[Document]:
+ raise NotImplementedError(
+ "AzureSearchVectorStoreRetriever does not support async"
+ )
diff --git a/libs/community/langchain_community/vectorstores/bageldb.py b/libs/community/langchain_community/vectorstores/bageldb.py
new file mode 100644
index 00000000000..fee0cdf1f98
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/bageldb.py
@@ -0,0 +1,431 @@
+from __future__ import annotations
+
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+if TYPE_CHECKING:
+ import bagel
+ import bagel.config
+ from bagel.api.types import ID, OneOrMany, Where, WhereDocument
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import xor_args
+from langchain_core.vectorstores import VectorStore
+
+DEFAULT_K = 5
+
+
+def _results_to_docs(results: Any) -> List[Document]:
+ return [doc for doc, _ in _results_to_docs_and_scores(results)]
+
+
+def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+ return [
+ (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+ for result in zip(
+ results["documents"][0],
+ results["metadatas"][0],
+ results["distances"][0],
+ )
+ ]
+
+
+class Bagel(VectorStore):
+ """``BagelDB.ai`` vector store.
+
+ To use, you should have the ``betabageldb`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Bagel
+ vectorstore = Bagel(cluster_name="langchain_store")
+ """
+
+ _LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"
+
+ def __init__(
+ self,
+ cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
+ client_settings: Optional[bagel.config.Settings] = None,
+ embedding_function: Optional[Embeddings] = None,
+ cluster_metadata: Optional[Dict] = None,
+ client: Optional[bagel.Client] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ ) -> None:
+ """Initialize with bagel client"""
+ try:
+ import bagel
+ import bagel.config
+ except ImportError:
+ raise ImportError("Please install bagel `pip install betabageldb`.")
+ if client is not None:
+ self._client_settings = client_settings
+ self._client = client
+ else:
+ if client_settings:
+ _client_settings = client_settings
+ else:
+ _client_settings = bagel.config.Settings(
+ bagel_api_impl="rest",
+ bagel_server_host="api.bageldb.ai",
+ )
+ self._client_settings = _client_settings
+ self._client = bagel.Client(_client_settings)
+
+ self._cluster = self._client.get_or_create_cluster(
+ name=cluster_name,
+ metadata=cluster_metadata,
+ )
+ self.override_relevance_score_fn = relevance_score_fn
+ self._embedding_function = embedding_function
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding_function
+
+ @xor_args(("query_texts", "query_embeddings"))
+ def __query_cluster(
+ self,
+ query_texts: Optional[List[str]] = None,
+ query_embeddings: Optional[List[List[float]]] = None,
+ n_results: int = 4,
+ where: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Query the BagelDB cluster based on the provided parameters."""
+ try:
+ import bagel # noqa: F401
+ except ImportError:
+ raise ImportError("Please install bagel `pip install betabageldb`.")
+ return self._cluster.find(
+ query_texts=query_texts,
+ query_embeddings=query_embeddings,
+ n_results=n_results,
+ where=where,
+ **kwargs,
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ embeddings: Optional[List[List[float]]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Add texts along with their corresponding embeddings and optional
+ metadata to the BagelDB cluster.
+
+ Args:
+ texts (Iterable[str]): Texts to be added.
+            embeddings (Optional[List[List[float]]]): List of embedding vectors.
+            metadatas (Optional[List[dict]]): Optional list of metadatas.
+            ids (Optional[List[str]]): List of unique IDs for the texts.
+
+        Returns:
+            List[str]: List of unique IDs representing the added texts.
+ """
+ # creating unique ids if None
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ texts = list(texts)
+ if self._embedding_function and embeddings is None and texts:
+ embeddings = self._embedding_function.embed_documents(texts)
+ if metadatas:
+ length_diff = len(texts) - len(metadatas)
+ if length_diff:
+ metadatas = metadatas + [{}] * length_diff
+ empty_ids = []
+ non_empty_ids = []
+ for idx, metadata in enumerate(metadatas):
+ if metadata:
+ non_empty_ids.append(idx)
+ else:
+ empty_ids.append(idx)
+ if non_empty_ids:
+ metadatas = [metadatas[idx] for idx in non_empty_ids]
+ texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
+ embeddings_with_metadatas = (
+ [embeddings[idx] for idx in non_empty_ids] if embeddings else None
+ )
+ ids_with_metadata = [ids[idx] for idx in non_empty_ids]
+ self._cluster.upsert(
+ embeddings=embeddings_with_metadatas,
+ metadatas=metadatas,
+ documents=texts_with_metadatas,
+ ids=ids_with_metadata,
+ )
+ if empty_ids:
+ texts_without_metadatas = [texts[j] for j in empty_ids]
+ embeddings_without_metadatas = (
+ [embeddings[j] for j in empty_ids] if embeddings else None
+ )
+ ids_without_metadatas = [ids[j] for j in empty_ids]
+ self._cluster.upsert(
+ embeddings=embeddings_without_metadatas,
+ documents=texts_without_metadatas,
+ ids=ids_without_metadatas,
+ )
+ else:
+ metadatas = [{}] * len(texts)
+ self._cluster.upsert(
+ embeddings=embeddings,
+ documents=texts,
+ metadatas=metadatas,
+ ids=ids,
+ )
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ where: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Run a similarity search with BagelDB.
+
+ Args:
+ query (str): The query text to search for similar documents/texts.
+ k (int): The number of results to return.
+            where (Optional[Dict[str, str]]): Metadata filters to narrow down the search.
+
+ Returns:
+ List[Document]: List of documents objects representing
+ the documents most similar to the query text.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k, where=where)
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ where: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Run a similarity search with BagelDB and return documents with their
+ corresponding similarity scores.
+
+ Args:
+ query (str): The query text to search for similar documents.
+ k (int): The number of results to return.
+ where (Optional[Dict[str, str]]): Filter using metadata.
+
+ Returns:
+ List[Tuple[Document, float]]: List of tuples, each containing a
+ Document object representing a similar document and its
+ corresponding similarity score.
+
+ """
+ results = self.__query_cluster(query_texts=[query], n_results=k, where=where)
+ return _results_to_docs_and_scores(results)
+
+ @classmethod
+ def from_texts(
+ cls: Type[Bagel],
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
+ client_settings: Optional[bagel.config.Settings] = None,
+ cluster_metadata: Optional[Dict] = None,
+ client: Optional[bagel.Client] = None,
+ text_embeddings: Optional[List[List[float]]] = None,
+ **kwargs: Any,
+ ) -> Bagel:
+ """
+        Create and initialize a Bagel instance from a list of texts.
+
+ Args:
+ texts (List[str]): List of text content to be added.
+ cluster_name (str): The name of the BagelDB cluster.
+ client_settings (Optional[bagel.config.Settings]): Client settings.
+ cluster_metadata (Optional[Dict]): Metadata of the cluster.
+            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            metadatas (Optional[List[dict]]): List of metadatas.
+            ids (Optional[List[str]]): List of unique IDs. Defaults to None.
+ client (Optional[bagel.Client]): Bagel client instance.
+
+ Returns:
+ Bagel: Bagel vectorstore.
+ """
+ bagel_cluster = cls(
+ cluster_name=cluster_name,
+ embedding_function=embedding,
+ client_settings=client_settings,
+ client=client,
+ cluster_metadata=cluster_metadata,
+ **kwargs,
+ )
+ _ = bagel_cluster.add_texts(
+ texts=texts, embeddings=text_embeddings, metadatas=metadatas, ids=ids
+ )
+ return bagel_cluster
+
+ def delete_cluster(self) -> None:
+ """Delete the cluster."""
+ self._client.delete_cluster(self._cluster.name)
+
+ def similarity_search_by_vector_with_relevance_scores(
+ self,
+ query_embeddings: List[float],
+ k: int = DEFAULT_K,
+ where: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+        Return docs most similar to the embedding vector, with similarity scores.
+ """
+ results = self.__query_cluster(
+ query_embeddings=query_embeddings, n_results=k, where=where
+ )
+ return _results_to_docs_and_scores(results)
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_K,
+ where: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector."""
+ results = self.__query_cluster(
+ query_embeddings=embedding, n_results=k, where=where
+ )
+ return _results_to_docs(results)
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ Select and return the appropriate relevance score function based
+ on the distance metric used in the BagelDB cluster.
+ """
+ if self.override_relevance_score_fn:
+ return self.override_relevance_score_fn
+
+ distance = "l2"
+ distance_key = "hnsw:space"
+ metadata = self._cluster.metadata
+
+ if metadata and distance_key in metadata:
+ distance = metadata[distance_key]
+
+ if distance == "cosine":
+ return self._cosine_relevance_score_fn
+ elif distance == "l2":
+ return self._euclidean_relevance_score_fn
+ elif distance == "ip":
+ return self._max_inner_product_relevance_score_fn
+ else:
+ raise ValueError(
+ "No supported normalization function for distance"
+ f" metric of type: {distance}. Consider providing"
+ " relevance_score_fn to Bagel constructor."
+ )
+
+ @classmethod
+ def from_documents(
+ cls: Type[Bagel],
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ ids: Optional[List[str]] = None,
+ cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
+ client_settings: Optional[bagel.config.Settings] = None,
+ client: Optional[bagel.Client] = None,
+ cluster_metadata: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> Bagel:
+ """
+ Create a Bagel vectorstore from a list of documents.
+
+ Args:
+ documents (List[Document]): List of Document objects to add to the
+ Bagel vectorstore.
+            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ ids (Optional[List[str]]): List of IDs. Defaults to None.
+ cluster_name (str): The name of the BagelDB cluster.
+ client_settings (Optional[bagel.config.Settings]): Client settings.
+ client (Optional[bagel.Client]): Bagel client instance.
+ cluster_metadata (Optional[Dict]): Metadata associated with the
+ Bagel cluster. Defaults to None.
+
+ Returns:
+ Bagel: Bagel vectorstore.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ cluster_name=cluster_name,
+ client_settings=client_settings,
+ client=client,
+ cluster_metadata=cluster_metadata,
+ **kwargs,
+ )
+
+ def update_document(self, document_id: str, document: Document) -> None:
+ """Update a document in the cluster.
+
+ Args:
+ document_id (str): ID of the document to update.
+ document (Document): Document to update.
+ """
+ text = document.page_content
+ metadata = document.metadata
+ self._cluster.update(
+ ids=[document_id],
+ documents=[text],
+ metadatas=[metadata],
+ )
+
+ def get(
+ self,
+ ids: Optional[OneOrMany[ID]] = None,
+ where: Optional[Where] = None,
+ limit: Optional[int] = None,
+ offset: Optional[int] = None,
+ where_document: Optional[WhereDocument] = None,
+ include: Optional[List[str]] = None,
+ ) -> Dict[str, Any]:
+ """Gets the collection."""
+ kwargs = {
+ "ids": ids,
+ "where": where,
+ "limit": limit,
+ "offset": offset,
+ "where_document": where_document,
+ }
+
+ if include is not None:
+ kwargs["include"] = include
+
+ return self._cluster.get(**kwargs)
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """
+ Delete by IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+ self._cluster.delete(ids=ids)
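
A minimal usage sketch of the BagelDB store above, assuming the
``betabageldb`` client is installed. The cluster name and texts are
hypothetical; no local embedding function is passed, mirroring the class
docstring example.

.. code-block:: python

    from langchain_community.vectorstores import Bagel

    cluster = Bagel.from_texts(
        texts=["hello bagel", "hello langchain"],
        cluster_name="example_cluster",
    )

    docs_and_scores = cluster.similarity_search_with_score("hello bagel", k=2)

    # Remove the cluster once done experimenting.
    cluster.delete_cluster()
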
diff --git a/libs/community/langchain_community/vectorstores/baiducloud_vector_search.py b/libs/community/langchain_community/vectorstores/baiducloud_vector_search.py
new file mode 100644
index 00000000000..71c9694e090
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/baiducloud_vector_search.py
@@ -0,0 +1,491 @@
+import logging
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from elasticsearch import Elasticsearch
+
+logger = logging.getLogger(__name__)
+
+
+class BESVectorStore(VectorStore):
+ """`Baidu Elasticsearch` vector store.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import BESVectorStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = BESVectorStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+ bes_url="http://localhost:9200"
+ )
+
+ Args:
+ index_name: Name of the Elasticsearch index to create.
+ bes_url: URL of the Baidu Elasticsearch instance to connect to.
+ user: Username to use when connecting to Elasticsearch.
+ password: Password to use when connecting to Elasticsearch.
+
+ More information can be obtained from:
+ https://cloud.baidu.com/doc/BES/s/8llyn0hh4
+
+ """
+
+ def __init__(
+ self,
+ index_name: str,
+ bes_url: str,
+ user: Optional[str] = None,
+ password: Optional[str] = None,
+ embedding: Optional[Embeddings] = None,
+ **kwargs: Optional[dict],
+ ) -> None:
+ self.embedding = embedding
+ self.index_name = index_name
+ self.query_field = kwargs.get("query_field", "text")
+ self.vector_query_field = kwargs.get("vector_query_field", "vector")
+ self.space_type = kwargs.get("space_type", "cosine")
+ self.index_type = kwargs.get("index_type", "linear")
+ self.index_params = kwargs.get("index_params") or {}
+
+ if bes_url is not None:
+ self.client = BESVectorStore.bes_client(
+ bes_url=bes_url, username=user, password=password
+ )
+ else:
+ raise ValueError("""Please specified a bes connection url.""")
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self.embedding
+
+ @staticmethod
+ def bes_client(
+ *,
+ bes_url: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ ) -> "Elasticsearch":
+ try:
+ import elasticsearch
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ connection_params: Dict[str, Any] = {}
+
+ connection_params["hosts"] = [bes_url]
+ if username and password:
+ connection_params["basic_auth"] = (username, password)
+
+ es_client = elasticsearch.Elasticsearch(**connection_params)
+ try:
+ es_client.info()
+ except Exception as e:
+ logger.error(f"Error connecting to Elasticsearch: {e}")
+ raise e
+ return es_client
+
+ def _create_index_if_not_exists(self, dims_length: Optional[int] = None) -> None:
+ """Create the index if it doesn't already exist.
+
+ Args:
+ dims_length: Length of the embedding vectors.
+ """
+
+ if self.client.indices.exists(index=self.index_name):
+ logger.info(f"Index {self.index_name} already exists. Skipping creation.")
+
+ else:
+ if dims_length is None:
+ raise ValueError(
+ "Cannot create index without specifying dims_length "
+ + "when the index doesn't already exist. "
+ )
+
+ indexMapping = self._index_mapping(dims_length=dims_length)
+
+ logger.debug(
+ f"Creating index {self.index_name} with mappings {indexMapping}"
+ )
+
+ self.client.indices.create(
+ index=self.index_name,
+ body={
+ "settings": {"index": {"knn": True}},
+ "mappings": {"properties": indexMapping},
+ },
+ )
+
+ def _index_mapping(self, dims_length: Union[int, None]) -> Dict:
+ """
+        Builds the index mapping used when the index is created.
+
+        Args:
+            dims_length: Numeric length of the embedding vectors,
+                or None if not using a vector-based query.
+            index_params: The extra parameters for creating the index.
+
+ Returns:
+ Dict: The Elasticsearch settings and mappings for the strategy.
+ """
+ if "linear" == self.index_type:
+ return {
+ self.vector_query_field: {
+ "type": "bpack_vector",
+ "dims": dims_length,
+ "build_index": self.index_params.get("build_index", False),
+ }
+ }
+
+ elif "hnsw" == self.index_type:
+ return {
+ self.vector_query_field: {
+ "type": "bpack_vector",
+ "dims": dims_length,
+ "index_type": "hnsw",
+ "space_type": self.space_type,
+ "parameters": {
+ "ef_construction": self.index_params.get(
+ "hnsw_ef_construction", 200
+ ),
+ "m": self.index_params.get("hnsw_m", 4),
+ },
+ }
+ }
+ else:
+ return {
+ self.vector_query_field: {
+ "type": "bpack_vector",
+ "model_id": self.index_params.get("model_id", ""),
+ }
+ }
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete documents from the index.
+
+ Args:
+ ids: List of ids of documents to delete
+ """
+ try:
+ from elasticsearch.helpers import BulkIndexError, bulk
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ body = []
+
+ if ids is None:
+ raise ValueError("ids must be provided.")
+
+ for _id in ids:
+ body.append({"_op_type": "delete", "_index": self.index_name, "_id": _id})
+
+ if len(body) > 0:
+ try:
+ bulk(
+ self.client,
+ body,
+ refresh=kwargs.get("refresh_indices", True),
+ ignore_status=404,
+ )
+ logger.debug(f"Deleted {len(body)} texts from index")
+ return True
+ except BulkIndexError as e:
+ logger.error(f"Error deleting texts: {e}")
+ raise e
+ else:
+ logger.info("No documents to delete")
+ return False
+
+ def _query_body(
+ self,
+ query_vector: Union[List[float], None],
+ filter: Optional[dict] = None,
+ search_params: Dict = {},
+ ) -> Dict:
+ query_vector_body = {"vector": query_vector, "k": search_params.get("k", 2)}
+
+ if filter is not None and len(filter) != 0:
+ query_vector_body["filter"] = filter
+
+ if "linear" == self.index_type:
+ query_vector_body["linear"] = True
+ else:
+ query_vector_body["ef"] = search_params.get("ef", 10)
+
+ return {
+ "size": search_params.get("size", 4),
+ "query": {"knn": {self.vector_query_field: query_vector_body}},
+ }
+
+ def _search(
+ self,
+ query: Optional[str] = None,
+ query_vector: Union[List[float], None] = None,
+ filter: Optional[dict] = None,
+ custom_query: Optional[Callable[[Dict, Union[str, None]], Dict]] = None,
+ search_params: Dict = {},
+ ) -> List[Tuple[Document, float]]:
+ """Return searched documents result from BES
+
+ Args:
+ query: Text to look up documents similar to.
+ query_vector: Embedding to look up documents similar to.
+ filter: Array of Baidu ElasticSearch filter clauses to apply to the query.
+ custom_query: Function to modify the query body before it is sent to BES.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+
+ if self.embedding and query is not None:
+ query_vector = self.embedding.embed_query(query)
+
+ query_body = self._query_body(
+ query_vector=query_vector, filter=filter, search_params=search_params
+ )
+
+ if custom_query is not None:
+ query_body = custom_query(query_body, query)
+ logger.debug(f"Calling custom_query, Query body now: {query_body}")
+
+ logger.debug(f"Query body: {query_body}")
+
+ # Perform the kNN search on the BES index and return the results.
+ response = self.client.search(index=self.index_name, body=query_body)
+ logger.debug(f"response={response}")
+
+        hits = response["hits"]["hits"]
+ docs_and_scores = [
+ (
+ Document(
+ page_content=hit["_source"][self.query_field],
+ metadata=hit["_source"]["metadata"],
+ ),
+ hit["_score"],
+ )
+ for hit in hits
+ ]
+
+ return docs_and_scores
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return documents most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+
+ Returns:
+ List of Documents most similar to the query,
+ in descending order of similarity.
+ """
+
+ results = self.similarity_search_with_score(
+ query=query, k=k, filter=filter, **kwargs
+ )
+ return [doc for doc, _ in results]
+
+ def similarity_search_with_score(
+ self, query: str, k: int, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+            k: Number of Documents to return.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+        search_params = kwargs.pop("search_params", None) or {}
+
+        if search_params.get("size") is None:
+            search_params["size"] = k
+
+        return self._search(
+            query=query, filter=filter, search_params=search_params, **kwargs
+        )
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ **kwargs: Any,
+ ) -> "BESVectorStore":
+ """Construct BESVectorStore wrapper from documents.
+
+ Args:
+ documents: List of documents to add to the Elasticsearch index.
+ embedding: Embedding function to use to embed the texts.
+ Do not provide if using a strategy
+ that doesn't require inference.
+            kwargs: Keyword arguments used when creating the index.
+ """
+
+ vectorStore = BESVectorStore._bes_vector_store(embedding=embedding, **kwargs)
+ # Encode the provided texts and add them to the newly created index.
+ vectorStore.add_documents(documents)
+
+ return vectorStore
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ **kwargs: Any,
+ ) -> "BESVectorStore":
+ """Construct BESVectorStore wrapper from raw documents.
+
+ Args:
+ texts: List of texts to add to the Elasticsearch index.
+ embedding: Embedding function to use to embed the texts.
+ metadatas: Optional list of metadatas associated with the texts.
+ index_name: Name of the Elasticsearch index to create.
+            kwargs: Keyword arguments used when creating the index.
+ """
+
+ vectorStore = BESVectorStore._bes_vector_store(embedding=embedding, **kwargs)
+
+ # Encode the provided texts and add them to the newly created index.
+ vectorStore.add_texts(texts, metadatas=metadatas, **kwargs)
+
+ return vectorStore
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ try:
+ from elasticsearch.helpers import BulkIndexError, bulk
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ embeddings = []
+ create_index_if_not_exists = kwargs.get("create_index_if_not_exists", True)
+ ids = kwargs.get("ids", [str(uuid.uuid4()) for _ in texts])
+ refresh_indices = kwargs.get("refresh_indices", True)
+ requests = []
+
+ if self.embedding is not None:
+ embeddings = self.embedding.embed_documents(list(texts))
+ dims_length = len(embeddings[0])
+
+ if create_index_if_not_exists:
+ self._create_index_if_not_exists(dims_length=dims_length)
+
+ for i, (text, vector) in enumerate(zip(texts, embeddings)):
+ metadata = metadatas[i] if metadatas else {}
+
+ requests.append(
+ {
+ "_op_type": "index",
+ "_index": self.index_name,
+ self.query_field: text,
+ self.vector_query_field: vector,
+ "metadata": metadata,
+ "_id": ids[i],
+ }
+ )
+
+ else:
+ if create_index_if_not_exists:
+ self._create_index_if_not_exists()
+
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+
+ requests.append(
+ {
+ "_op_type": "index",
+ "_index": self.index_name,
+ self.query_field: text,
+ "metadata": metadata,
+ "_id": ids[i],
+ }
+ )
+
+ if len(requests) > 0:
+ try:
+ success, failed = bulk(
+ self.client, requests, stats_only=True, refresh=refresh_indices
+ )
+ logger.debug(
+ f"Added {success} and failed to add {failed} texts to index"
+ )
+
+ logger.debug(f"added texts {ids} to index")
+ return ids
+ except BulkIndexError as e:
+ logger.error(f"Error adding texts: {e}")
+ firstError = e.errors[0].get("index", {}).get("error", {})
+ logger.error(f"First error reason: {firstError.get('reason')}")
+ raise e
+
+ else:
+ logger.debug("No texts to add to index")
+ return []
+
+ @staticmethod
+ def _bes_vector_store(
+ embedding: Optional[Embeddings] = None, **kwargs: Any
+ ) -> "BESVectorStore":
+ index_name = kwargs.get("index_name")
+
+ if index_name is None:
+ raise ValueError("Please provide an index_name.")
+
+ bes_url = kwargs.get("bes_url")
+ if bes_url is None:
+ raise ValueError("Please provided a valid bes connection url")
+
+ return BESVectorStore(embedding=embedding, **kwargs)
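For context, here is a minimal usage sketch of the store above. It assumes a reachable Baidu Elasticsearch (BES) instance and configured OpenAI credentials; the module path, index name and URL below are illustrative assumptions, not values taken from this patch.

```python
# Hypothetical usage sketch (not part of the patch): index a few texts and query them.
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores.baiducloud_vector_search import BESVectorStore

embeddings = OpenAIEmbeddings()

# from_texts routes **kwargs to _bes_vector_store, which requires index_name and bes_url.
store = BESVectorStore.from_texts(
    texts=["hello world", "goodbye world"],
    embedding=embeddings,
    index_name="my_bes_index",        # placeholder
    bes_url="http://localhost:9200",  # placeholder BES endpoint
)

docs = store.similarity_search("hello", k=2)
```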
diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py
new file mode 100644
index 00000000000..2a062014339
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/cassandra.py
@@ -0,0 +1,457 @@
+from __future__ import annotations
+
+import typing
+import uuid
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import numpy as np
+
+if typing.TYPE_CHECKING:
+ from cassandra.cluster import Session
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+CVST = TypeVar("CVST", bound="Cassandra")
+
+
+class Cassandra(VectorStore):
+ """Wrapper around Apache Cassandra(R) for vector-store workloads.
+
+ To use it, you need a recent installation of the `cassio` library
+ and a Cassandra cluster / Astra DB instance supporting vector capabilities.
+
+ Visit the cassio.org website for extensive quickstarts and code examples.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Cassandra
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ session = ... # create your Cassandra session object
+ keyspace = 'my_keyspace' # the keyspace should exist already
+ table_name = 'my_vector_store'
+ vectorstore = Cassandra(embeddings, session, keyspace, table_name)
+ """
+
+ _embedding_dimension: Union[int, None]
+
+ @staticmethod
+ def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
+ if filter_dict is None:
+ return {}
+ else:
+ return filter_dict
+
+ def _get_embedding_dimension(self) -> int:
+ if self._embedding_dimension is None:
+ self._embedding_dimension = len(
+ self.embedding.embed_query("This is a sample sentence.")
+ )
+ return self._embedding_dimension
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ session: Session,
+ keyspace: str,
+ table_name: str,
+ ttl_seconds: Optional[int] = None,
+ ) -> None:
+        """Create a vector table."""
+        try:
+            from cassio.vector import VectorTable
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                "Could not import cassio python package. "
+                "Please install it with `pip install cassio`."
+            )
+ self.embedding = embedding
+ self.session = session
+ self.keyspace = keyspace
+ self.table_name = table_name
+ self.ttl_seconds = ttl_seconds
+ #
+ self._embedding_dimension = None
+ #
+ self.table = VectorTable(
+ session=session,
+ keyspace=keyspace,
+ table=table_name,
+ embedding_dimension=self._get_embedding_dimension(),
+ primary_key_type="TEXT",
+ )
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ @staticmethod
+ def _dont_flip_the_cos_score(distance: float) -> float:
+ # the identity
+ return distance
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The underlying VectorTable already returns a "score proper",
+ i.e. one in [0, 1] where higher means more *similar*,
+ so here the final score transformation is not reversing the interval:
+ """
+ return self._dont_flip_the_cos_score
+
+ def delete_collection(self) -> None:
+ """
+ Just an alias for `clear`
+ (to better align with other VectorStore implementations).
+ """
+ self.clear()
+
+ def clear(self) -> None:
+ """Empty the collection."""
+ self.table.clear()
+
+ def delete_by_document_id(self, document_id: str) -> None:
+ return self.table.delete(document_id)
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector IDs.
+
+
+ Args:
+ ids: List of ids to delete.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ for document_id in ids:
+ self.delete_by_document_id(document_id)
+ return True
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 16,
+ ttl_seconds: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of IDs.
+ batch_size (int): Number of concurrent requests to send to the server.
+ ttl_seconds (Optional[int], optional): Optional time-to-live
+ for the added texts.
+
+ Returns:
+ List[str]: List of IDs of the added texts.
+ """
+        _texts = list(texts)  # materialize in case a generator was passed
+ if ids is None:
+ ids = [uuid.uuid4().hex for _ in _texts]
+ if metadatas is None:
+ metadatas = [{} for _ in _texts]
+ #
+ ttl_seconds = ttl_seconds or self.ttl_seconds
+ #
+ embedding_vectors = self.embedding.embed_documents(_texts)
+ #
+ for i in range(0, len(_texts), batch_size):
+ batch_texts = _texts[i : i + batch_size]
+ batch_embedding_vectors = embedding_vectors[i : i + batch_size]
+ batch_ids = ids[i : i + batch_size]
+ batch_metadatas = metadatas[i : i + batch_size]
+
+ futures = [
+ self.table.put_async(
+ text, embedding_vector, text_id, metadata, ttl_seconds
+ )
+ for text, embedding_vector, text_id, metadata in zip(
+ batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
+ )
+ ]
+ for future in futures:
+ future.result()
+ return ids
+
+ # id-returning search facilities
+ def similarity_search_with_score_id_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float, str]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ Returns:
+ List of (Document, score, id), the most similar to the query vector.
+ """
+ search_metadata = self._filter_to_metadata(filter)
+ #
+ hits = self.table.search(
+ embedding_vector=embedding,
+ top_k=k,
+ metric="cos",
+ metric_threshold=None,
+ metadata=search_metadata,
+ )
+ # We stick to 'cos' distance as it can be normalized on a 0-1 axis
+ # (1=most relevant), as required by this class' contract.
+ return [
+ (
+ Document(
+ page_content=hit["document"],
+ metadata=hit["metadata"],
+ ),
+ 0.5 + 0.5 * hit["distance"],
+ hit["document_id"],
+ )
+ for hit in hits
+ ]
+
+ def similarity_search_with_score_id(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float, str]]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_with_score_id_by_vector(
+ embedding=embedding_vector,
+ k=k,
+ filter=filter,
+ )
+
+ # id-unaware search facilities
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ Returns:
+ List of (Document, score), the most similar to the query vector.
+ """
+ return [
+ (doc, score)
+ for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+ ]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_by_vector(
+ embedding_vector,
+ k,
+ filter=filter,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ return [
+ doc
+ for doc, _ in self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ )
+ ]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Tuple[Document, float]]:
+ embedding_vector = self.embedding.embed_query(query)
+ return self.similarity_search_with_score_by_vector(
+ embedding_vector,
+ k,
+ filter=filter,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ search_metadata = self._filter_to_metadata(filter)
+
+ prefetchHits = self.table.search(
+ embedding_vector=embedding,
+ top_k=fetch_k,
+ metric="cos",
+ metric_threshold=None,
+ metadata=search_metadata,
+ )
+ # let the mmr utility pick the *indices* in the above array
+ mmrChosenIndices = maximal_marginal_relevance(
+ np.array(embedding, dtype=np.float32),
+ [pfHit["embedding_vector"] for pfHit in prefetchHits],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ mmrHits = [
+ pfHit
+ for pfIndex, pfHit in enumerate(prefetchHits)
+ if pfIndex in mmrChosenIndices
+ ]
+ return [
+ Document(
+ page_content=hit["document"],
+ metadata=hit["metadata"],
+ )
+ for hit in mmrHits
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Optional.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding_vector = self.embedding.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding_vector,
+ k,
+ fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[CVST],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ batch_size: int = 16,
+ **kwargs: Any,
+ ) -> CVST:
+ """Create a Cassandra vectorstore from raw texts.
+
+ No support for specifying text IDs
+
+ Returns:
+ a Cassandra vectorstore.
+ """
+ session: Session = kwargs["session"]
+ keyspace: str = kwargs["keyspace"]
+ table_name: str = kwargs["table_name"]
+ cassandraStore = cls(
+ embedding=embedding,
+ session=session,
+ keyspace=keyspace,
+ table_name=table_name,
+ )
+ cassandraStore.add_texts(texts=texts, metadatas=metadatas)
+ return cassandraStore
+
+ @classmethod
+ def from_documents(
+ cls: Type[CVST],
+ documents: List[Document],
+ embedding: Embeddings,
+ batch_size: int = 16,
+ **kwargs: Any,
+ ) -> CVST:
+ """Create a Cassandra vectorstore from a document list.
+
+ No support for specifying text IDs
+
+ Returns:
+ a Cassandra vectorstore.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ session: Session = kwargs["session"]
+ keyspace: str = kwargs["keyspace"]
+ table_name: str = kwargs["table_name"]
+ return cls.from_texts(
+ texts=texts,
+ metadatas=metadatas,
+ embedding=embedding,
+ session=session,
+ keyspace=keyspace,
+ table_name=table_name,
+ )
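A minimal usage sketch for the Cassandra store above, assuming a reachable cluster with vector support, the `cassio` and `cassandra-driver` packages, and OpenAI credentials; the contact point, keyspace and table names are placeholders. Note that `from_texts` reads `session`, `keyspace` and `table_name` from `**kwargs`.

```python
# Hypothetical usage sketch (not part of the patch).
from cassandra.cluster import Cluster
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import Cassandra

session = Cluster(["127.0.0.1"]).connect()  # placeholder contact point

store = Cassandra.from_texts(
    texts=["doc one", "doc two"],
    embedding=OpenAIEmbeddings(),
    session=session,               # required, consumed via **kwargs
    keyspace="my_keyspace",        # must already exist
    table_name="my_vector_store",
)

hits = store.similarity_search_with_score("doc", k=2)
```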
diff --git a/libs/community/langchain_community/vectorstores/chroma.py b/libs/community/langchain_community/vectorstores/chroma.py
new file mode 100644
index 00000000000..3773b380748
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/chroma.py
@@ -0,0 +1,790 @@
+from __future__ import annotations
+
+import base64
+import logging
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import xor_args
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ import chromadb
+ import chromadb.config
+ from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
+
+logger = logging.getLogger()
+DEFAULT_K = 4 # Number of Documents to return.
+
+
+def _results_to_docs(results: Any) -> List[Document]:
+ return [doc for doc, _ in _results_to_docs_and_scores(results)]
+
+
+def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+ return [
+ # TODO: Chroma can do batch querying,
+ # we shouldn't hard code to the 1st result
+ (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+ for result in zip(
+ results["documents"][0],
+ results["metadatas"][0],
+ results["distances"][0],
+ )
+ ]
+
+
+class Chroma(VectorStore):
+ """`ChromaDB` vector store.
+
+ To use, you should have the ``chromadb`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Chroma("langchain_store", embeddings)
+ """
+
+ _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+ def __init__(
+ self,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ embedding_function: Optional[Embeddings] = None,
+ persist_directory: Optional[str] = None,
+ client_settings: Optional[chromadb.config.Settings] = None,
+ collection_metadata: Optional[Dict] = None,
+ client: Optional[chromadb.Client] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ ) -> None:
+ """Initialize with a Chroma client."""
+ try:
+ import chromadb
+ import chromadb.config
+ except ImportError:
+ raise ImportError(
+ "Could not import chromadb python package. "
+ "Please install it with `pip install chromadb`."
+ )
+
+ if client is not None:
+ self._client_settings = client_settings
+ self._client = client
+ self._persist_directory = persist_directory
+ else:
+ if client_settings:
+ # If client_settings is provided with persist_directory specified,
+ # then it is "in-memory and persisting to disk" mode.
+ client_settings.persist_directory = (
+ persist_directory or client_settings.persist_directory
+ )
+ if client_settings.persist_directory is not None:
+ # Maintain backwards compatibility with chromadb < 0.4.0
+ major, minor, _ = chromadb.__version__.split(".")
+ if int(major) == 0 and int(minor) < 4:
+ client_settings.chroma_db_impl = "duckdb+parquet"
+
+ _client_settings = client_settings
+ elif persist_directory:
+ # Maintain backwards compatibility with chromadb < 0.4.0
+ major, minor, _ = chromadb.__version__.split(".")
+ if int(major) == 0 and int(minor) < 4:
+ _client_settings = chromadb.config.Settings(
+ chroma_db_impl="duckdb+parquet",
+ )
+ else:
+ _client_settings = chromadb.config.Settings(is_persistent=True)
+ _client_settings.persist_directory = persist_directory
+ else:
+ _client_settings = chromadb.config.Settings()
+ self._client_settings = _client_settings
+ self._client = chromadb.Client(_client_settings)
+ self._persist_directory = (
+ _client_settings.persist_directory or persist_directory
+ )
+
+ self._embedding_function = embedding_function
+ self._collection = self._client.get_or_create_collection(
+ name=collection_name,
+ embedding_function=None,
+ metadata=collection_metadata,
+ )
+ self.override_relevance_score_fn = relevance_score_fn
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding_function
+
+ @xor_args(("query_texts", "query_embeddings"))
+ def __query_collection(
+ self,
+ query_texts: Optional[List[str]] = None,
+ query_embeddings: Optional[List[List[float]]] = None,
+ n_results: int = 4,
+ where: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+    ) -> Dict[str, Any]:
+ """Query the chroma collection."""
+ try:
+ import chromadb # noqa: F401
+ except ImportError:
+            raise ImportError(
+ "Could not import chromadb python package. "
+ "Please install it with `pip install chromadb`."
+ )
+ return self._collection.query(
+ query_texts=query_texts,
+ query_embeddings=query_embeddings,
+ n_results=n_results,
+ where=where,
+ where_document=where_document,
+ **kwargs,
+ )
+
+ def encode_image(self, uri: str) -> str:
+ """Get base64 string from image URI."""
+ with open(uri, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
+
+ def add_images(
+ self,
+ uris: List[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more images through the embeddings and add to the vectorstore.
+
+ Args:
+            uris (List[str]): File paths to the images.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of IDs.
+
+ Returns:
+ List[str]: List of IDs of the added images.
+ """
+ # Map from uris to b64 encoded strings
+ b64_texts = [self.encode_image(uri=uri) for uri in uris]
+ # Populate IDs
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in uris]
+ embeddings = None
+ # Set embeddings
+ if self._embedding_function is not None and hasattr(
+ self._embedding_function, "embed_image"
+ ):
+ embeddings = self._embedding_function.embed_image(uris=uris)
+ if metadatas:
+ # fill metadatas with empty dicts if somebody
+ # did not specify metadata for all images
+ length_diff = len(uris) - len(metadatas)
+ if length_diff:
+ metadatas = metadatas + [{}] * length_diff
+ empty_ids = []
+ non_empty_ids = []
+ for idx, m in enumerate(metadatas):
+ if m:
+ non_empty_ids.append(idx)
+ else:
+ empty_ids.append(idx)
+ if non_empty_ids:
+ metadatas = [metadatas[idx] for idx in non_empty_ids]
+ images_with_metadatas = [uris[idx] for idx in non_empty_ids]
+ embeddings_with_metadatas = (
+ [embeddings[idx] for idx in non_empty_ids] if embeddings else None
+ )
+ ids_with_metadata = [ids[idx] for idx in non_empty_ids]
+ try:
+ self._collection.upsert(
+ metadatas=metadatas,
+ embeddings=embeddings_with_metadatas,
+ documents=images_with_metadatas,
+ ids=ids_with_metadata,
+ )
+ except ValueError as e:
+ if "Expected metadata value to be" in str(e):
+ msg = (
+ "Try filtering complex metadata using "
+ "langchain.vectorstores.utils.filter_complex_metadata."
+ )
+ raise ValueError(e.args[0] + "\n\n" + msg)
+ else:
+ raise e
+ if empty_ids:
+ images_without_metadatas = [uris[j] for j in empty_ids]
+ embeddings_without_metadatas = (
+ [embeddings[j] for j in empty_ids] if embeddings else None
+ )
+ ids_without_metadatas = [ids[j] for j in empty_ids]
+ self._collection.upsert(
+ embeddings=embeddings_without_metadatas,
+ documents=images_without_metadatas,
+ ids=ids_without_metadatas,
+ )
+ else:
+ self._collection.upsert(
+ embeddings=embeddings,
+ documents=b64_texts,
+ ids=ids,
+ )
+ return ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of IDs.
+
+ Returns:
+ List[str]: List of IDs of the added texts.
+ """
+ # TODO: Handle the case where the user doesn't provide ids on the Collection
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+ embeddings = None
+ texts = list(texts)
+ if self._embedding_function is not None:
+ embeddings = self._embedding_function.embed_documents(texts)
+ if metadatas:
+ # fill metadatas with empty dicts if somebody
+ # did not specify metadata for all texts
+ length_diff = len(texts) - len(metadatas)
+ if length_diff:
+ metadatas = metadatas + [{}] * length_diff
+ empty_ids = []
+ non_empty_ids = []
+ for idx, m in enumerate(metadatas):
+ if m:
+ non_empty_ids.append(idx)
+ else:
+ empty_ids.append(idx)
+ if non_empty_ids:
+ metadatas = [metadatas[idx] for idx in non_empty_ids]
+ texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
+ embeddings_with_metadatas = (
+ [embeddings[idx] for idx in non_empty_ids] if embeddings else None
+ )
+ ids_with_metadata = [ids[idx] for idx in non_empty_ids]
+ try:
+ self._collection.upsert(
+ metadatas=metadatas,
+ embeddings=embeddings_with_metadatas,
+ documents=texts_with_metadatas,
+ ids=ids_with_metadata,
+ )
+ except ValueError as e:
+ if "Expected metadata value to be" in str(e):
+ msg = (
+ "Try filtering complex metadata from the document using "
+ "langchain.vectorstores.utils.filter_complex_metadata."
+ )
+ raise ValueError(e.args[0] + "\n\n" + msg)
+ else:
+ raise e
+ if empty_ids:
+ texts_without_metadatas = [texts[j] for j in empty_ids]
+ embeddings_without_metadatas = (
+ [embeddings[j] for j in empty_ids] if embeddings else None
+ )
+ ids_without_metadatas = [ids[j] for j in empty_ids]
+ self._collection.upsert(
+ embeddings=embeddings_without_metadatas,
+ documents=texts_without_metadatas,
+ ids=ids_without_metadatas,
+ )
+ else:
+ self._collection.upsert(
+ embeddings=embeddings,
+ documents=texts,
+ ids=ids,
+ )
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with Chroma.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Document]: List of documents most similar to the query text.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_K,
+ filter: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+ Args:
+ embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ results = self.__query_collection(
+ query_embeddings=embedding,
+ n_results=k,
+ where=filter,
+ where_document=where_document,
+ )
+ return _results_to_docs(results)
+
+ def similarity_search_by_vector_with_relevance_scores(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_K,
+ filter: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Return docs most similar to embedding vector and similarity score.
+
+ Args:
+ embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to
+ the query text and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ results = self.__query_collection(
+ query_embeddings=embedding,
+ n_results=k,
+ where=filter,
+ where_document=where_document,
+ )
+ return _results_to_docs_and_scores(results)
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ filter: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Run similarity search with Chroma with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to
+ the query text and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ if self._embedding_function is None:
+ results = self.__query_collection(
+ query_texts=[query],
+ n_results=k,
+ where=filter,
+ where_document=where_document,
+ )
+ else:
+ query_embedding = self._embedding_function.embed_query(query)
+ results = self.__query_collection(
+ query_embeddings=[query_embedding],
+ n_results=k,
+ where=filter,
+ where_document=where_document,
+ )
+
+ return _results_to_docs_and_scores(results)
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn:
+ return self.override_relevance_score_fn
+
+ distance = "l2"
+ distance_key = "hnsw:space"
+ metadata = self._collection.metadata
+
+ if metadata and distance_key in metadata:
+ distance = metadata[distance_key]
+
+ if distance == "cosine":
+ return self._cosine_relevance_score_fn
+ elif distance == "l2":
+ return self._euclidean_relevance_score_fn
+ elif distance == "ip":
+ return self._max_inner_product_relevance_score_fn
+ else:
+ raise ValueError(
+ "No supported normalization function"
+ f" for distance metric of type: {distance}."
+ "Consider providing relevance_score_fn to Chroma constructor."
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_K,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+
+ results = self.__query_collection(
+ query_embeddings=embedding,
+ n_results=fetch_k,
+ where=filter,
+ where_document=where_document,
+ include=["metadatas", "documents", "distances", "embeddings"],
+ )
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding, dtype=np.float32),
+ results["embeddings"][0],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+
+ candidates = _results_to_docs(results)
+
+ selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
+ return selected_results
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ where_document: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if self._embedding_function is None:
+ raise ValueError(
+ "For MMR search, you must specify an embedding function on" "creation."
+ )
+
+ embedding = self._embedding_function.embed_query(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding,
+ k,
+ fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ where_document=where_document,
+ )
+ return docs
+
+ def delete_collection(self) -> None:
+ """Delete the collection."""
+ self._client.delete_collection(self._collection.name)
+
+ def get(
+ self,
+ ids: Optional[OneOrMany[ID]] = None,
+ where: Optional[Where] = None,
+ limit: Optional[int] = None,
+ offset: Optional[int] = None,
+ where_document: Optional[WhereDocument] = None,
+ include: Optional[List[str]] = None,
+ ) -> Dict[str, Any]:
+ """Gets the collection.
+
+ Args:
+ ids: The ids of the embeddings to get. Optional.
+ where: A Where type dict used to filter results by.
+ E.g. `{"color" : "red", "price": 4.20}`. Optional.
+ limit: The number of documents to return. Optional.
+ offset: The offset to start returning results from.
+ Useful for paging results with limit. Optional.
+ where_document: A WhereDocument type dict used to filter by the documents.
+                E.g. `{"$contains": "hello"}`. Optional.
+ include: A list of what to include in the results.
+ Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
+ Ids are always included.
+ Defaults to `["metadatas", "documents"]`. Optional.
+ """
+ kwargs = {
+ "ids": ids,
+ "where": where,
+ "limit": limit,
+ "offset": offset,
+ "where_document": where_document,
+ }
+
+ if include is not None:
+ kwargs["include"] = include
+
+ return self._collection.get(**kwargs)
+
+ def persist(self) -> None:
+ """Persist the collection.
+
+ This can be used to explicitly persist the data to disk.
+ It will also be called automatically when the object is destroyed.
+ """
+ if self._persist_directory is None:
+ raise ValueError(
+ "You must specify a persist_directory on"
+ "creation to persist the collection."
+ )
+ import chromadb
+
+ # Maintain backwards compatibility with chromadb < 0.4.0
+ major, minor, _ = chromadb.__version__.split(".")
+ if int(major) == 0 and int(minor) < 4:
+ self._client.persist()
+
+ def update_document(self, document_id: str, document: Document) -> None:
+ """Update a document in the collection.
+
+ Args:
+ document_id (str): ID of the document to update.
+ document (Document): Document to update.
+ """
+ return self.update_documents([document_id], [document])
+
+ def update_documents(self, ids: List[str], documents: List[Document]) -> None:
+ """Update a document in the collection.
+
+ Args:
+ ids (List[str]): List of ids of the document to update.
+ documents (List[Document]): List of documents to update.
+ """
+ text = [document.page_content for document in documents]
+ metadata = [document.metadata for document in documents]
+ if self._embedding_function is None:
+ raise ValueError(
+ "For update, you must specify an embedding function on creation."
+ )
+ embeddings = self._embedding_function.embed_documents(text)
+
+ if hasattr(
+ self._collection._client, "max_batch_size"
+ ): # for Chroma 0.4.10 and above
+ from chromadb.utils.batch_utils import create_batches
+
+ for batch in create_batches(
+ api=self._collection._client,
+ ids=ids,
+ metadatas=metadata,
+ documents=text,
+ embeddings=embeddings,
+ ):
+ self._collection.update(
+ ids=batch[0],
+ embeddings=batch[1],
+ documents=batch[3],
+ metadatas=batch[2],
+ )
+ else:
+ self._collection.update(
+ ids=ids,
+ embeddings=embeddings,
+ documents=text,
+ metadatas=metadata,
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[Chroma],
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ persist_directory: Optional[str] = None,
+ client_settings: Optional[chromadb.config.Settings] = None,
+ client: Optional[chromadb.Client] = None,
+ collection_metadata: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> Chroma:
+ """Create a Chroma vectorstore from a raw documents.
+
+ If a persist_directory is specified, the collection will be persisted there.
+ Otherwise, the data will be ephemeral in-memory.
+
+ Args:
+ texts (List[str]): List of texts to add to the collection.
+ collection_name (str): Name of the collection to create.
+ persist_directory (Optional[str]): Directory to persist the collection.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
+ ids (Optional[List[str]]): List of document IDs. Defaults to None.
+ client_settings (Optional[chromadb.config.Settings]): Chroma client settings
+ collection_metadata (Optional[Dict]): Collection configurations.
+ Defaults to None.
+
+ Returns:
+ Chroma: Chroma vectorstore.
+ """
+ chroma_collection = cls(
+ collection_name=collection_name,
+ embedding_function=embedding,
+ persist_directory=persist_directory,
+ client_settings=client_settings,
+ client=client,
+ collection_metadata=collection_metadata,
+ **kwargs,
+ )
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+ if hasattr(
+ chroma_collection._client, "max_batch_size"
+ ): # for Chroma 0.4.10 and above
+ from chromadb.utils.batch_utils import create_batches
+
+ for batch in create_batches(
+ api=chroma_collection._client,
+ ids=ids,
+ metadatas=metadatas,
+ documents=texts,
+ ):
+ chroma_collection.add_texts(
+ texts=batch[3] if batch[3] else [],
+ metadatas=batch[2] if batch[2] else None,
+ ids=batch[0],
+ )
+ else:
+ chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+ return chroma_collection
+
+ @classmethod
+ def from_documents(
+ cls: Type[Chroma],
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ persist_directory: Optional[str] = None,
+ client_settings: Optional[chromadb.config.Settings] = None,
+        client: Optional[chromadb.Client] = None,
+ collection_metadata: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> Chroma:
+ """Create a Chroma vectorstore from a list of documents.
+
+ If a persist_directory is specified, the collection will be persisted there.
+ Otherwise, the data will be ephemeral in-memory.
+
+ Args:
+ collection_name (str): Name of the collection to create.
+ persist_directory (Optional[str]): Directory to persist the collection.
+ ids (Optional[List[str]]): List of document IDs. Defaults to None.
+ documents (List[Document]): List of documents to add to the vectorstore.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ client_settings (Optional[chromadb.config.Settings]): Chroma client settings
+ collection_metadata (Optional[Dict]): Collection configurations.
+ Defaults to None.
+
+ Returns:
+ Chroma: Chroma vectorstore.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ persist_directory=persist_directory,
+ client_settings=client_settings,
+ client=client,
+ collection_metadata=collection_metadata,
+ **kwargs,
+ )
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+ self._collection.delete(ids=ids)
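A minimal usage sketch for the Chroma store above, assuming `chromadb` is installed; the collection name and persist directory are placeholders. Setting `hnsw:space` in `collection_metadata` is what `_select_relevance_score_fn` inspects when choosing the cosine normalization instead of the default `l2`.

```python
# Hypothetical usage sketch (not part of the patch).
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

store = Chroma.from_texts(
    texts=["alpha", "beta", "gamma"],
    embedding=OpenAIEmbeddings(),
    collection_name="demo",                        # placeholder
    persist_directory="./chroma_db",               # placeholder
    collection_metadata={"hnsw:space": "cosine"},  # read by _select_relevance_score_fn
)

docs_and_scores = store.similarity_search_with_score("alpha", k=2)
store.persist()  # only does real work on chromadb < 0.4; newer versions persist automatically
```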
diff --git a/libs/community/langchain_community/vectorstores/clarifai.py b/libs/community/langchain_community/vectorstores/clarifai.py
new file mode 100644
index 00000000000..4bded7652b6
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/clarifai.py
@@ -0,0 +1,296 @@
+from __future__ import annotations
+
+import logging
+import os
+import traceback
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Iterable, List, Optional, Tuple
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+class Clarifai(VectorStore):
+ """`Clarifai AI` vector store.
+
+ To use, you should have the ``clarifai`` python SDK package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Clarifai
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Clarifai("langchain_store", embeddings.embed_query)
+ """
+
+ def __init__(
+ self,
+ user_id: Optional[str] = None,
+ app_id: Optional[str] = None,
+ number_of_docs: Optional[int] = None,
+ pat: Optional[str] = None,
+ ) -> None:
+ """Initialize with Clarifai client.
+
+ Args:
+ user_id (Optional[str], optional): User ID. Defaults to None.
+ app_id (Optional[str], optional): App ID. Defaults to None.
+ pat (Optional[str], optional): Personal access token. Defaults to None.
+ number_of_docs (Optional[int], optional): Number of documents to return
+ during vector search. Defaults to None.
+ api_base (Optional[str], optional): API base. Defaults to None.
+
+ Raises:
+ ValueError: If user ID, app ID or personal access token is not provided.
+ """
+ self._user_id = user_id or os.environ.get("CLARIFAI_USER_ID")
+ self._app_id = app_id or os.environ.get("CLARIFAI_APP_ID")
+ if pat:
+ os.environ["CLARIFAI_PAT"] = pat
+ self._pat = os.environ.get("CLARIFAI_PAT")
+ if self._user_id is None or self._app_id is None or self._pat is None:
+ raise ValueError(
+ "Could not find CLARIFAI_USER_ID, CLARIFAI_APP_ID or\
+ CLARIFAI_PAT in your environment. "
+ "Please set those env variables with a valid user ID, \
+ app ID and personal access token \
+ from https://clarifai.com/settings/security."
+ )
+ self._number_of_docs = number_of_docs
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts to the Clarifai vectorstore. This will push the text
+ to a Clarifai application.
+ Application use a base workflow that create and store embedding for each text.
+ Make sure you are using a base workflow that is compatible with text
+ (such as Language Understanding).
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of IDs.
+
+ """
+ try:
+ from clarifai.client.input import Inputs
+ from google.protobuf.struct_pb2 import Struct
+ except ImportError as e:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ ) from e
+
+ ltexts = list(texts)
+ length = len(ltexts)
+ assert length > 0, "No texts provided to add to the vectorstore."
+
+ if metadatas is not None:
+ assert length == len(
+ metadatas
+ ), "Number of texts and metadatas should be the same."
+
+ if ids is not None:
+ assert len(ltexts) == len(
+ ids
+ ), "Number of text inputs and input ids should be the same."
+
+ input_obj = Inputs(app_id=self._app_id, user_id=self._user_id)
+ batch_size = 32
+ input_job_ids = []
+ for idx in range(0, length, batch_size):
+ try:
+ batch_texts = ltexts[idx : idx + batch_size]
+ batch_metadatas = (
+ metadatas[idx : idx + batch_size] if metadatas else None
+ )
+ if batch_metadatas is not None:
+ meta_list = []
+ for meta in batch_metadatas:
+ meta_struct = Struct()
+ meta_struct.update(meta)
+ meta_list.append(meta_struct)
+                batch_ids = (
+                    ids[idx : idx + batch_size]
+                    if ids is not None
+                    else [uuid.uuid4().hex for _ in batch_texts]
+                )
+                input_batch = [
+                    input_obj.get_text_input(
+                        input_id=batch_ids[i],
+                        raw_text=inp,
+                        metadata=meta_list[i] if batch_metadatas else None,
+                    )
+                    for i, inp in enumerate(batch_texts)
+                ]
+ result_id = input_obj.upload_inputs(inputs=input_batch)
+ input_job_ids.extend(result_id)
+ logger.debug("Input posted successfully.")
+
+ except Exception as error:
+ logger.warning(f"Post inputs failed: {error}")
+ traceback.print_exc()
+
+ return input_job_ids
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filters: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Run similarity search with score using Clarifai.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+            filters (Optional[dict]): Filter by metadata.
+                Defaults to None.
+
+        Returns:
+            List of (Document, score) tuples most similar to the query text.
+ """
+ try:
+ from clarifai.client.search import Search
+ from clarifai_grpc.grpc.api import resources_pb2
+ from google.protobuf import json_format # type: ignore
+ except ImportError as e:
+ raise ImportError(
+ "Could not import clarifai python package. "
+ "Please install it with `pip install clarifai`."
+ ) from e
+
+ # Get number of docs to return
+ if self._number_of_docs is not None:
+ k = self._number_of_docs
+
+ search_obj = Search(user_id=self._user_id, app_id=self._app_id, top_k=k)
+ rank = [{"text_raw": query}]
+ # Add filter by metadata if provided.
+ if filters is not None:
+ search_metadata = {"metadata": filters}
+ search_response = search_obj.query(ranks=rank, filters=[search_metadata])
+ else:
+ search_response = search_obj.query(ranks=rank)
+
+ # Retrieve hits
+ hits = [hit for data in search_response for hit in data.hits]
+ executor = ThreadPoolExecutor(max_workers=10)
+
+ def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
+ metadata = json_format.MessageToDict(hit.input.data.metadata)
+ h = {"Authorization": f"Key {self._pat}"}
+ request = requests.get(hit.input.data.text.url, headers=h)
+
+            # Override the declared encoding with the encoding detected by chardet.
+ request.encoding = request.apparent_encoding
+ requested_text = request.text
+
+ logger.debug(
+ f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\
+ off input: {hit.input.id}, text: {requested_text[:125]}"
+ )
+ return (Document(page_content=requested_text, metadata=metadata), hit.score)
+
+ # Iterate over hits and retrieve metadata and text
+ futures = [executor.submit(hit_to_document, hit) for hit in hits]
+ docs_and_scores = [future.result() for future in futures]
+
+ return docs_and_scores
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search using Clarifai.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+            List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, **kwargs)
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ user_id: Optional[str] = None,
+ app_id: Optional[str] = None,
+ number_of_docs: Optional[int] = None,
+ pat: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Clarifai:
+ """Create a Clarifai vectorstore from a list of texts.
+
+ Args:
+ user_id (str): User ID.
+ app_id (str): App ID.
+ texts (List[str]): List of texts to add.
+ number_of_docs (Optional[int]): Number of documents to return
+ during vector search. Defaults to None.
+ metadatas (Optional[List[dict]]): Optional list of metadatas.
+ Defaults to None.
+
+ Returns:
+ Clarifai: Clarifai vectorstore.
+ """
+ clarifai_vector_db = cls(
+ user_id=user_id,
+ app_id=app_id,
+ number_of_docs=number_of_docs,
+ pat=pat,
+ )
+ clarifai_vector_db.add_texts(texts=texts, metadatas=metadatas)
+ return clarifai_vector_db
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ user_id: Optional[str] = None,
+ app_id: Optional[str] = None,
+ number_of_docs: Optional[int] = None,
+ pat: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Clarifai:
+ """Create a Clarifai vectorstore from a list of documents.
+
+ Args:
+ user_id (str): User ID.
+ app_id (str): App ID.
+ documents (List[Document]): List of documents to add.
+ number_of_docs (Optional[int]): Number of documents to return
+ during vector search. Defaults to None.
+
+ Returns:
+ Clarifai: Clarifai vectorstore.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+ return cls.from_texts(
+ user_id=user_id,
+ app_id=app_id,
+ texts=texts,
+ number_of_docs=number_of_docs,
+ pat=pat,
+ metadatas=metadatas,
+ )
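A minimal usage sketch for the Clarifai store above. It assumes `CLARIFAI_USER_ID`, `CLARIFAI_APP_ID` and `CLARIFAI_PAT` are available (via the environment or passed explicitly) and that the target app uses a text-compatible base workflow; the ids below are placeholders. The `embedding` argument is accepted for interface compatibility but embeddings are computed server-side by the Clarifai app.

```python
# Hypothetical usage sketch (not part of the patch).
from langchain_community.vectorstores import Clarifai

store = Clarifai.from_texts(
    texts=["first note", "second note"],
    user_id="my-user-id",            # placeholder
    app_id="my-app-id",              # placeholder
    pat="my-personal-access-token",  # placeholder
    number_of_docs=2,
)

docs = store.similarity_search("note")
```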
diff --git a/libs/community/langchain_community/vectorstores/clickhouse.py b/libs/community/langchain_community/vectorstores/clickhouse.py
new file mode 100644
index 00000000000..4b8628daf7a
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/clickhouse.py
@@ -0,0 +1,475 @@
+from __future__ import annotations
+
+import json
+import logging
+from hashlib import sha1
+from threading import Thread
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseSettings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger()
+
+
+def has_mul_sub_str(s: str, *args: Any) -> bool:
+ """
+ Check if a string contains multiple substrings.
+ Args:
+ s: string to check.
+ *args: substrings to check.
+
+ Returns:
+ True if all substrings are in the string, False otherwise.
+ """
+ for a in args:
+ if a not in s:
+ return False
+ return True
+
+
+class ClickhouseSettings(BaseSettings):
+ """`ClickHouse` client configuration.
+
+ Attribute:
+        host (str) : A URL to connect to the ClickHouse backend.
+            Defaults to 'localhost'.
+        port (int) : URL port to connect with HTTP. Defaults to 8123.
+ username (str) : Username to login. Defaults to None.
+ password (str) : Password to login. Defaults to None.
+ index_type (str): index type string.
+ index_param (list): index build parameter.
+ index_query_params(dict): index query parameters.
+ database (str) : Database name to find the table. Defaults to 'default'.
+ table (str) : Table name to operate on.
+ Defaults to 'vector_table'.
+ metric (str) : Metric to compute distance,
+ supported are ('angular', 'euclidean', 'manhattan', 'hamming',
+ 'dot'). Defaults to 'angular'.
+ https://github.com/spotify/annoy/blob/main/src/annoymodule.cc#L149-L169
+
+        column_map (Dict) : Column type map to project column names onto langchain
+            semantics. Must have keys: `id`, `document`, `embedding`,
+            `metadata` and `uuid`. For example:
+ .. code-block:: python
+
+ {
+ 'id': 'text_id',
+                    'uuid': 'global_unique_id',
+ 'embedding': 'text_embedding',
+ 'document': 'text_plain',
+ 'metadata': 'metadata_dictionary_in_json',
+ }
+
+ Defaults to identity map.
+ """
+
+ host: str = "localhost"
+ port: int = 8123
+
+ username: Optional[str] = None
+ password: Optional[str] = None
+
+ index_type: str = "annoy"
+ # Annoy supports L2Distance and cosineDistance.
+ index_param: Optional[Union[List, Dict]] = ["'L2Distance'", 100]
+ index_query_params: Dict[str, str] = {}
+
+ column_map: Dict[str, str] = {
+ "id": "id",
+ "uuid": "uuid",
+ "document": "document",
+ "embedding": "embedding",
+ "metadata": "metadata",
+ }
+
+ database: str = "default"
+ table: str = "langchain"
+ metric: str = "angular"
+
+ def __getitem__(self, item: str) -> Any:
+ return getattr(self, item)
+
+ class Config:
+ env_file = ".env"
+ env_prefix = "clickhouse_"
+ env_file_encoding = "utf-8"
+
+
+class Clickhouse(VectorStore):
+ """`ClickHouse VectorSearch` vector store.
+
+    You need the `clickhouse-connect` python package and a valid account
+    to connect to ClickHouse.
+
+    ClickHouse not only supports search with simple vector indexes;
+    it also supports complex queries with multiple conditions,
+    constraints, and even sub-queries.
+
+ For more information, please visit
+ [ClickHouse official site](https://clickhouse.com/clickhouse)
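+
+    Example (a minimal sketch; ``my_embeddings`` stands for any ``Embeddings``
+    implementation, and a reachable ClickHouse server is assumed):
+
+    .. code-block:: python
+
+        from langchain_community.vectorstores import Clickhouse, ClickhouseSettings
+
+        settings = ClickhouseSettings(table="langchain_docs")
+        vectorstore = Clickhouse(embedding=my_embeddings, config=settings)
+        vectorstore.add_texts(["hello clickhouse"])
+        docs = vectorstore.similarity_search("hello", k=1)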
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ config: Optional[ClickhouseSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+        """ClickHouse wrapper for LangChain.
+
+        Args:
+            embedding (Embeddings): Embedding model used for documents and queries.
+            config (ClickhouseSettings): Configuration for the ClickHouse client.
+            **kwargs: Other keyword arguments are passed to
+                [clickhouse-connect](https://docs.clickhouse.com/).
+        """
+ try:
+ from clickhouse_connect import get_client
+ except ImportError:
+ raise ImportError(
+ "Could not import clickhouse connect python package. "
+ "Please install it with `pip install clickhouse-connect`."
+ )
+ try:
+ from tqdm import tqdm
+
+ self.pgbar = tqdm
+ except ImportError:
+            # Fall back to a no-op progress bar if tqdm is not installed
+ self.pgbar = lambda x, **kwargs: x
+ super().__init__()
+ if config is not None:
+ self.config = config
+ else:
+ self.config = ClickhouseSettings()
+ assert self.config
+ assert self.config.host and self.config.port
+ assert (
+ self.config.column_map
+ and self.config.database
+ and self.config.table
+ and self.config.metric
+ )
+ for k in ["id", "embedding", "document", "metadata", "uuid"]:
+ assert k in self.config.column_map
+ assert self.config.metric in [
+ "angular",
+ "euclidean",
+ "manhattan",
+ "hamming",
+ "dot",
+ ]
+
+ # initialize the schema
+ dim = len(embedding.embed_query("test"))
+
+ index_params = (
+ (
+ ",".join([f"'{k}={v}'" for k, v in self.config.index_param.items()])
+ if self.config.index_param
+ else ""
+ )
+ if isinstance(self.config.index_param, Dict)
+ else ",".join([str(p) for p in self.config.index_param])
+ if isinstance(self.config.index_param, List)
+ else self.config.index_param
+ )
+
+ self.schema = f"""\
+CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
+ {self.config.column_map['id']} Nullable(String),
+ {self.config.column_map['document']} Nullable(String),
+ {self.config.column_map['embedding']} Array(Float32),
+ {self.config.column_map['metadata']} JSON,
+ {self.config.column_map['uuid']} UUID DEFAULT generateUUIDv4(),
+ CONSTRAINT cons_vec_len CHECK length({self.config.column_map['embedding']}) = {dim},
+ INDEX vec_idx {self.config.column_map['embedding']} TYPE \
+{self.config.index_type}({index_params}) GRANULARITY 1000
+) ENGINE = MergeTree ORDER BY uuid SETTINGS index_granularity = 8192\
+"""
+ self.dim = dim
+ self.BS = "\\"
+ self.must_escape = ("\\", "'")
+ self.embedding_function = embedding
+        self.dist_order = "ASC"  # Only support CosineDistance and L2Distance
+
+ # Create a connection to clickhouse
+ self.client = get_client(
+ host=self.config.host,
+ port=self.config.port,
+ username=self.config.username,
+ password=self.config.password,
+ **kwargs,
+ )
+ # Enable JSON type
+ self.client.command("SET allow_experimental_object_type=1")
+ # Enable Annoy index
+ self.client.command("SET allow_experimental_annoy_index=1")
+ self.client.command(self.schema)
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def escape_str(self, value: str) -> str:
+ return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
+
+ def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
+ ks = ",".join(column_names)
+ _data = []
+ for n in transac:
+ n = ",".join([f"'{self.escape_str(str(_n))}'" for _n in n])
+ _data.append(f"({n})")
+ i_str = f"""
+ INSERT INTO TABLE
+ {self.config.database}.{self.config.table}({ks})
+ VALUES
+ {','.join(_data)}
+ """
+ return i_str
+
+ def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
+ _insert_query = self._build_insert_sql(transac, column_names)
+ self.client.command(_insert_query)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ batch_size: int = 32,
+ ids: Optional[Iterable[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Insert more texts through the embeddings and add to the VectorStore.
+
+ Args:
+ texts: Iterable of strings to add to the VectorStore.
+ ids: Optional list of ids to associate with the texts.
+            batch_size: Batch size of insertion.
+            metadatas: Optional list of metadata dicts for the texts.
+
+ Returns:
+ List of ids from adding the texts into the VectorStore.
+
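+        Example (illustrative; assumes an existing ``vectorstore`` instance):
+
+        .. code-block:: python
+
+            ids = vectorstore.add_texts(
+                ["doc one", "doc two"],
+                metadatas=[{"source": "a"}, {"source": "b"}],
+            )
+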
+ """
+ # Embed and create the documents
+ ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
+ colmap_ = self.config.column_map
+ transac = []
+ column_names = {
+ colmap_["id"]: ids,
+ colmap_["document"]: texts,
+ colmap_["embedding"]: self.embedding_function.embed_documents(list(texts)),
+ }
+ metadatas = metadatas or [{} for _ in texts]
+ column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
+ assert len(set(colmap_) - set(column_names)) >= 0
+ keys, values = zip(*column_names.items())
+ try:
+ t = None
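+            # Double-buffered insert: stage rows for the next batch while the
+            # previous batch is written by a background thread, joining it
+            # before starting the next insert.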
+ for v in self.pgbar(
+ zip(*values), desc="Inserting data...", total=len(metadatas)
+ ):
+ assert (
+ len(v[keys.index(self.config.column_map["embedding"])]) == self.dim
+ )
+ transac.append(v)
+ if len(transac) == batch_size:
+ if t:
+ t.join()
+ t = Thread(target=self._insert, args=[transac, keys])
+ t.start()
+ transac = []
+ if len(transac) > 0:
+ if t:
+ t.join()
+ self._insert(transac, keys)
+ return [i for i in ids]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ config: Optional[ClickhouseSettings] = None,
+ text_ids: Optional[Iterable[str]] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> Clickhouse:
+ """Create ClickHouse wrapper with existing texts
+
+ Args:
+ embedding_function (Embeddings): Function to extract text embedding
+ texts (Iterable[str]): List or tuple of strings to be added
+ config (ClickHouseSettings, Optional): ClickHouse configuration
+ text_ids (Optional[Iterable], optional): IDs for the texts.
+ Defaults to None.
+            batch_size (int, optional): Batch size when transmitting data to
+                ClickHouse. Defaults to 32.
+            metadatas (List[dict], optional): Metadata dicts for the texts.
+                Defaults to None.
+            Other keyword arguments will be passed to
+                [clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api)
+ Returns:
+ ClickHouse Index
+ """
+ ctx = cls(embedding, config, **kwargs)
+ ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
+ return ctx
+
+ def __repr__(self) -> str:
+        """Text representation of the ClickHouse vector store: prints the backend,
+        username, and table schema. Easy to use with `str(clickhouse_store)`.
+
+ Returns:
+ repr: string to show connection info and data schema
+ """
+ _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
+ _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
+ _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
+ _repr += "-" * 51 + "\n"
+ for r in self.client.query(
+ f"DESC {self.config.database}.{self.config.table}"
+ ).named_results():
+ _repr += (
+ f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n"
+ )
+ _repr += "-" * 51 + "\n"
+ return _repr
+
+ def _build_query_sql(
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
+ ) -> str:
+ q_emb_str = ",".join(map(str, q_emb))
+ if where_str:
+ where_str = f"PREWHERE {where_str}"
+ else:
+ where_str = ""
+
+ settings_strs = []
+ if self.config.index_query_params:
+ for k in self.config.index_query_params:
+ settings_strs.append(f"SETTING {k}={self.config.index_query_params[k]}")
+ q_str = f"""
+ SELECT {self.config.column_map['document']},
+ {self.config.column_map['metadata']}, dist
+ FROM {self.config.database}.{self.config.table}
+ {where_str}
+ ORDER BY L2Distance({self.config.column_map['embedding']}, [{q_emb_str}])
+ AS dist {self.dist_order}
+ LIMIT {topk} {' '.join(settings_strs)}
+ """
+ return q_str
+
+ def similarity_search(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Perform a similarity search with ClickHouse
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+            NOTE: Please do not let end users fill this in directly, and always
+                be aware of SQL injection. When dealing with metadata, remember
+                to use `{self.metadata_column}.attribute` instead of `attribute`
+                alone. The default name for the metadata column is `metadata`.
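+
+            For example (an illustrative sketch; the metadata column and its
+            attributes depend on your ``column_map`` and inserted metadata):
+
+            .. code-block:: python
+
+                docs = vectorstore.similarity_search(
+                    "some query",
+                    k=4,
+                    where_str="metadata.source = 'a'",
+                )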
+
+ Returns:
+ List[Document]: List of Documents
+ """
+ return self.similarity_search_by_vector(
+ self.embedding_function.embed_query(query), k, where_str, **kwargs
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search with ClickHouse by vectors
+
+ Args:
+            embedding (List[float]): query embedding vector.
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+            NOTE: Please do not let end users fill this in directly, and always
+                be aware of SQL injection. When dealing with metadata, remember
+                to use `{self.metadata_column}.attribute` instead of `attribute`
+                alone. The default name for the metadata column is `metadata`.
+
+ Returns:
+ List[Document]: List of documents
+ """
+ q_str = self._build_query_sql(embedding, k, where_str)
+ try:
+ return [
+ Document(
+ page_content=r[self.config.column_map["document"]],
+ metadata=r[self.config.column_map["metadata"]],
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def similarity_search_with_relevance_scores(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with ClickHouse
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+            NOTE: Please do not let end users fill this in directly, and always
+                be aware of SQL injection. When dealing with metadata, remember
+                to use `{self.metadata_column}.attribute` instead of `attribute`
+                alone. The default name for the metadata column is `metadata`.
+
+ Returns:
+            List[Tuple[Document, float]]: List of (Document, similarity score) tuples.
+ """
+ q_str = self._build_query_sql(
+ self.embedding_function.embed_query(query), k, where_str
+ )
+ try:
+ return [
+ (
+ Document(
+ page_content=r[self.config.column_map["document"]],
+ metadata=r[self.config.column_map["metadata"]],
+ ),
+ r["dist"],
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def drop(self) -> None:
+ """
+ Helper function: Drop data
+ """
+ self.client.command(
+ f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
+ )
+
+ @property
+ def metadata_column(self) -> str:
+ return self.config.column_map["metadata"]
diff --git a/libs/community/langchain_community/vectorstores/dashvector.py b/libs/community/langchain_community/vectorstores/dashvector.py
new file mode 100644
index 00000000000..7e82b23c2d5
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/dashvector.py
@@ -0,0 +1,364 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import (
+ Any,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+
+class DashVector(VectorStore):
+ """`DashVector` vector store.
+
+ To use, you should have the ``dashvector`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DashVector
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ import dashvector
+
+ client = dashvector.Client(api_key="***")
+ client.create("langchain", dimension=1024)
+ collection = client.get("langchain")
+ embeddings = OpenAIEmbeddings()
+            vectorstore = DashVector(collection, embeddings, "text")
+ """
+
+ def __init__(
+ self,
+ collection: Any,
+ embedding: Embeddings,
+ text_field: str,
+ ):
+ """Initialize with DashVector collection."""
+
+ try:
+ import dashvector
+ except ImportError:
+ raise ValueError(
+ "Could not import dashvector python package. "
+ "Please install it with `pip install dashvector`."
+ )
+
+ if not isinstance(collection, dashvector.Collection):
+ raise ValueError(
+ f"collection should be an instance of dashvector.Collection, "
+                f"but got {type(collection)}"
+ )
+
+ self._collection = collection
+ self._embedding = embedding
+ self._text_field = text_field
+
+ def _similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[str] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query vector, along with scores"""
+
+ # query by vector
+ ret = self._collection.query(embedding, topk=k, filter=filter)
+ if not ret:
+ raise ValueError(
+                f"Failed to query docs by vector, error: {self._collection.message}"
+ )
+
+ docs = []
+ for doc in ret:
+ metadata = doc.fields
+ text = metadata.pop(self._text_field)
+ score = doc.score
+ docs.append((Document(page_content=text, metadata=metadata), score))
+ return docs
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 25,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids associated with the texts.
+ batch_size: Optional batch size to upsert docs.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ ids = ids or [str(uuid.uuid4().hex) for _ in texts]
+ text_list = list(texts)
+ for i in range(0, len(text_list), batch_size):
+ # batch end
+ end = min(i + batch_size, len(text_list))
+
+ batch_texts = text_list[i:end]
+ batch_ids = ids[i:end]
+ batch_embeddings = self._embedding.embed_documents(list(batch_texts))
+
+ # batch metadatas
+ if metadatas:
+ batch_metadatas = metadatas[i:end]
+ else:
+ batch_metadatas = [{} for _ in range(i, end)]
+ for metadata, text in zip(batch_metadatas, batch_texts):
+ metadata[self._text_field] = text
+
+ # batch upsert to collection
+ docs = list(zip(batch_ids, batch_embeddings, batch_metadatas))
+ ret = self._collection.upsert(docs)
+ if not ret:
+ raise ValueError(
+                    f"Failed to upsert docs to the DashVector database. "
+                    f"Error: {ret.message}"
+ )
+ return ids
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool:
+ """Delete by vector ID.
+
+ Args:
+ ids: List of ids to delete.
+
+ Returns:
+ True if deletion is successful,
+ False otherwise.
+ """
+ return bool(self._collection.delete(ids))
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to search documents similar to.
+ k: Number of documents to return. Default to 4.
+ filter: Doc fields filter conditions that meet the SQL where clause
+ specification.
+
+ Returns:
+ List of Documents most similar to the query text.
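+
+        Example (illustrative; the filter follows a SQL ``WHERE``-like syntax and
+        the field names depend on the metadata you stored):
+
+        .. code-block:: python
+
+            docs = vectorstore.similarity_search("query", k=4, filter="age > 18")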
+ """
+
+ docs_and_scores = self.similarity_search_with_relevance_scores(query, k, filter)
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query text, along with relevance scores.
+
+        Lower scores indicate greater similarity; higher scores indicate
+        greater dissimilarity.
+
+ Args:
+ query: input text
+ k: Number of Documents to return. Defaults to 4.
+ filter: Doc fields filter conditions that meet the SQL where clause
+ specification.
+
+ Returns:
+ List of Tuples of (doc, similarity_score)
+ """
+
+ embedding = self._embedding.embed_query(query)
+ return self._similarity_search_with_score_by_vector(
+ embedding, k=k, filter=filter
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Doc fields filter conditions that meet the SQL where clause
+ specification.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self._similarity_search_with_score_by_vector(
+ embedding, k, filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Doc fields filter conditions that meet the SQL where clause
+ specification.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding = self._embedding.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult, filter
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Doc fields filter conditions that meet the SQL where clause
+ specification.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+
+ # query by vector
+ ret = self._collection.query(
+ embedding, topk=fetch_k, filter=filter, include_vector=True
+ )
+ if not ret:
+ raise ValueError(
+                f"Failed to query docs by vector, error: {self._collection.message}"
+ )
+
+ candidate_embeddings = [doc.vector for doc in ret]
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding), candidate_embeddings, lambda_mult, k
+ )
+
+ metadatas = [ret.output[i].fields for i in mmr_selected]
+ return [
+ Document(page_content=metadata.pop(self._text_field), metadata=metadata)
+ for metadata in metadatas
+ ]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ dashvector_api_key: Optional[str] = None,
+ collection_name: str = "langchain",
+ text_field: str = "text",
+ batch_size: int = 25,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> DashVector:
+ """Return DashVector VectorStore initialized from texts and embeddings.
+
+ This is the quick way to get started with dashvector vector store.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DashVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+ import dashvector
+
+ embeddings = OpenAIEmbeddings()
+ dashvector = DashVector.from_documents(
+ docs,
+ embeddings,
+ dashvector_api_key="{DASHVECTOR_API_KEY}"
+ )
+ """
+ try:
+ import dashvector
+ except ImportError:
+ raise ValueError(
+ "Could not import dashvector python package. "
+ "Please install it with `pip install dashvector`."
+ )
+
+ dashvector_api_key = dashvector_api_key or get_from_env(
+ "dashvector_api_key", "DASHVECTOR_API_KEY"
+ )
+
+ dashvector_client = dashvector.Client(api_key=dashvector_api_key)
+ dashvector_client.delete(collection_name)
+ collection = dashvector_client.get(collection_name)
+ if not collection:
+ dim = len(embedding.embed_query(texts[0]))
+            # create the collection if it does not exist
+ resp = dashvector_client.create(collection_name, dimension=dim)
+ if resp:
+ collection = dashvector_client.get(collection_name)
+ else:
+ raise ValueError(
+                    f"Failed to create collection. Error: {resp.message}."
+ )
+
+ dashvector_vector_db = cls(collection, embedding, text_field)
+ dashvector_vector_db.add_texts(texts, metadatas, ids, batch_size)
+ return dashvector_vector_db
diff --git a/libs/community/langchain_community/vectorstores/databricks_vector_search.py b/libs/community/langchain_community/vectorstores/databricks_vector_search.py
new file mode 100644
index 00000000000..5f5882ae100
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/databricks_vector_search.py
@@ -0,0 +1,473 @@
+from __future__ import annotations
+
+import json
+import logging
+import uuid
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VST, VectorStore
+
+if TYPE_CHECKING:
+ from databricks.vector_search.client import VectorSearchIndex
+
+logger = logging.getLogger(__name__)
+
+
+class DatabricksVectorSearch(VectorStore):
+ """`Databricks Vector Search` vector store.
+
+ To use, you should have the ``databricks-vectorsearch`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DatabricksVectorSearch
+ from databricks.vector_search.client import VectorSearchClient
+
+ vs_client = VectorSearchClient()
+ vs_index = vs_client.get_index(
+ endpoint_name="vs_endpoint",
+ index_name="ml.llm.index"
+ )
+ vectorstore = DatabricksVectorSearch(vs_index)
+
+ Args:
+ index: A Databricks Vector Search index object.
+ embedding: The embedding model.
+ Required for direct-access index or delta-sync index
+ with self-managed embeddings.
+ text_column: The name of the text column to use for the embeddings.
+ Required for direct-access index or delta-sync index
+ with self-managed embeddings.
+ Make sure the text column specified is in the index.
+ columns: The list of column names to get when doing the search.
+ Defaults to ``[primary_key, text_column]``.
+
+    A delta-sync index with Databricks-managed embeddings manages ingestion, deletion,
+    and embedding for you.
+    Manual ingestion/deletion of documents/texts is not supported for a delta-sync
+    index.
+
+ If you want to use a delta-sync index with self-managed embeddings, you need to
+ provide the embedding model and text column name to use for the embeddings.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DatabricksVectorSearch
+ from databricks.vector_search.client import VectorSearchClient
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ vs_client = VectorSearchClient()
+ vs_index = vs_client.get_index(
+ endpoint_name="vs_endpoint",
+ index_name="ml.llm.index"
+ )
+ vectorstore = DatabricksVectorSearch(
+ index=vs_index,
+ embedding=OpenAIEmbeddings(),
+ text_column="document_content"
+ )
+
+ If you want to manage the documents ingestion/deletion yourself, you can use a
+ direct-access index.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DatabricksVectorSearch
+ from databricks.vector_search.client import VectorSearchClient
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ vs_client = VectorSearchClient()
+ vs_index = vs_client.get_index(
+ endpoint_name="vs_endpoint",
+ index_name="ml.llm.index"
+ )
+ vectorstore = DatabricksVectorSearch(
+ index=vs_index,
+ embedding=OpenAIEmbeddings(),
+ text_column="document_content"
+ )
+ vectorstore.add_texts(
+ texts=["text1", "text2"]
+ )
+
+    For more information on Databricks Vector Search, see the Databricks Vector
+    Search documentation.
+
+ """
+
+ def __init__(
+ self,
+ index: VectorSearchIndex,
+ *,
+ embedding: Optional[Embeddings] = None,
+ text_column: Optional[str] = None,
+ columns: Optional[List[str]] = None,
+ ):
+ try:
+ from databricks.vector_search.client import VectorSearchIndex
+ except ImportError as e:
+ raise ImportError(
+ "Could not import databricks-vectorsearch python package. "
+ "Please install it with `pip install databricks-vectorsearch`."
+ ) from e
+ # index
+ self.index = index
+ if not isinstance(index, VectorSearchIndex):
+ raise TypeError("index must be of type VectorSearchIndex.")
+
+ # index_details
+ index_details = self.index.describe()
+ self.primary_key = index_details["primary_key"]
+ self.index_type = index_details.get("index_type")
+ self._delta_sync_index_spec = index_details.get("delta_sync_index_spec", dict())
+ self._direct_access_index_spec = index_details.get(
+ "direct_access_index_spec", dict()
+ )
+
+ # text_column
+ if self._is_databricks_managed_embeddings():
+ index_source_column = self._embedding_source_column_name()
+ # check if input text column matches the source column of the index
+ if text_column is not None and text_column != index_source_column:
+ raise ValueError(
+ f"text_column '{text_column}' does not match with the "
+ f"source column of the index: '{index_source_column}'."
+ )
+ self.text_column = index_source_column
+ else:
+ self._require_arg(text_column, "text_column")
+ self.text_column = text_column
+
+ # columns
+ self.columns = columns or []
+ # add primary key column and source column if not in columns
+ if self.primary_key not in self.columns:
+ self.columns.append(self.primary_key)
+ if self.text_column and self.text_column not in self.columns:
+ self.columns.append(self.text_column)
+
+ # Validate specified columns are in the index
+ if self._is_direct_access_index():
+ index_schema = self._index_schema()
+ if index_schema:
+ for col in self.columns:
+ if col not in index_schema:
+ raise ValueError(
+ f"column '{col}' is not in the index's schema."
+ )
+
+ # embedding model
+ if not self._is_databricks_managed_embeddings():
+ # embedding model is required for direct-access index
+ # or delta-sync index with self-managed embedding
+ self._require_arg(embedding, "embedding")
+ self._embedding = embedding
+ # validate dimension matches
+ index_embedding_dimension = self._embedding_vector_column_dimension()
+ if index_embedding_dimension is not None:
+ inferred_embedding_dimension = self._infer_embedding_dimension()
+ if inferred_embedding_dimension != index_embedding_dimension:
+ raise ValueError(
+ f"embedding model's dimension '{inferred_embedding_dimension}' "
+ f"does not match with the index's dimension "
+ f"'{index_embedding_dimension}'."
+ )
+ else:
+ if embedding is not None:
+ logger.warning(
+ "embedding model is not used in delta-sync index with "
+ "Databricks-managed embeddings."
+ )
+ self._embedding = None
+
+ @classmethod
+ def from_texts(
+ cls: Type[VST],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> VST:
+ raise NotImplementedError(
+ "`from_texts` is not supported. "
+ "Use `add_texts` to add to existing direct-access index."
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[Any]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts to the index.
+
+ Only support direct-access index.
+
+ Args:
+ texts: List of texts to add.
+ metadatas: List of metadata for each text. Defaults to None.
+ ids: List of ids for each text. Defaults to None.
+ If not provided, a random uuid will be generated for each text.
+
+ Returns:
+ List of ids from adding the texts into the index.
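+
+        Example (illustrative; only valid for a direct-access index):
+
+        .. code-block:: python
+
+            ids = vectorstore.add_texts(
+                texts=["text1", "text2"],
+                metadatas=[{"source": "a"}, {"source": "b"}],
+            )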
+ """
+ self._op_require_direct_access_index("add_texts")
+ assert self.embeddings is not None, "embedding model is required."
+ # Wrap to list if input texts is a single string
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = list(texts)
+ vectors = self.embeddings.embed_documents(texts)
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ metadatas = metadatas or [{} for _ in texts]
+
+ updates = [
+ {
+ self.primary_key: id_,
+ self.text_column: text,
+ self._embedding_vector_column_name(): vector,
+ **metadata,
+ }
+ for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
+ ]
+
+ upsert_resp = self.index.upsert(updates)
+ if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
+ failed_ids = upsert_resp.get("result", dict()).get(
+ "failed_primary_keys", []
+ )
+ if upsert_resp.get("status") == "FAILURE":
+ logger.error("Failed to add texts to the index.")
+ else:
+ logger.warning("Some texts failed to be added to the index.")
+ return [id_ for id_ in ids if id_ not in failed_ids]
+
+ return ids
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ """Access the query embedding object if available."""
+ return self._embedding
+
+ def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete documents from the index.
+
+ Only support direct-access index.
+
+ Args:
+ ids: List of ids of documents to delete.
+
+ Returns:
+ True if successful.
+ """
+ self._op_require_direct_access_index("delete")
+ if ids is None:
+ raise ValueError("ids must be provided.")
+ self.index.delete(ids)
+ return True
+
+ def similarity_search(
+ self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filters: Filters to apply to the query. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the embedding.
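+
+        Example (a hedged sketch; the exact filter format is defined by the
+        Databricks Vector Search API and the field must exist in the index):
+
+        .. code-block:: python
+
+            docs = vectorstore.similarity_search(
+                "what is databricks?", k=4, filters={"source": "docs"}
+            )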
+ """
+ docs_with_score = self.similarity_search_with_score(
+ query=query, k=k, filters=filters, **kwargs
+ )
+ return [doc for doc, _ in docs_with_score]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filters: Filters to apply to the query. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the embedding and score for each.
+ """
+ if self._is_databricks_managed_embeddings():
+ query_text = query
+ query_vector = None
+ else:
+ assert self.embeddings is not None, "embedding model is required."
+ query_text = None
+ query_vector = self.embeddings.embed_query(query)
+
+ search_resp = self.index.similarity_search(
+ columns=self.columns,
+ query_text=query_text,
+ query_vector=query_vector,
+ filters=filters,
+ num_results=k,
+ )
+ return self._parse_search_response(search_resp)
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filters: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filters: Filters to apply to the query. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_with_score = self.similarity_search_by_vector_with_score(
+ embedding=embedding, k=k, filters=filters, **kwargs
+ )
+ return [doc for doc, _ in docs_with_score]
+
+ def similarity_search_by_vector_with_score(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filters: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to embedding vector, along with scores.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filters: Filters to apply to the query. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the embedding and score for each.
+ """
+ if self._is_databricks_managed_embeddings():
+ raise ValueError(
+ "`similarity_search_by_vector` is not supported for index with "
+ "Databricks-managed embeddings."
+ )
+ search_resp = self.index.similarity_search(
+ columns=self.columns,
+ query_vector=embedding,
+ filters=filters,
+ num_results=k,
+ )
+ return self._parse_search_response(search_resp)
+
+ def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]:
+ """Parse the search response into a list of Documents with score."""
+ columns = [
+ col["name"]
+ for col in search_resp.get("manifest", dict()).get("columns", [])
+ ]
+ docs_with_score = []
+ for result in search_resp.get("result", dict()).get("data_array", []):
+ doc_id = result[columns.index(self.primary_key)]
+ text_content = result[columns.index(self.text_column)]
+ metadata = {
+ col: value
+ for col, value in zip(columns[:-1], result[:-1])
+ if col not in [self.primary_key, self.text_column]
+ }
+ metadata[self.primary_key] = doc_id
+ score = result[-1]
+ doc = Document(page_content=text_content, metadata=metadata)
+ docs_with_score.append((doc, score))
+ return docs_with_score
+
+ def _index_schema(self) -> Optional[dict]:
+ """Return the index schema as a dictionary.
+ Return None if no schema found.
+ """
+ if self._is_direct_access_index():
+ schema_json = self._direct_access_index_spec.get("schema_json")
+ if schema_json is not None:
+ return json.loads(schema_json)
+ return None
+
+ def _embedding_vector_column_name(self) -> Optional[str]:
+ """Return the name of the embedding vector column.
+ None if the index is not a self-managed embedding index.
+ """
+ return self._embedding_vector_column().get("name")
+
+ def _embedding_vector_column_dimension(self) -> Optional[int]:
+ """Return the dimension of the embedding vector column.
+ None if the index is not a self-managed embedding index.
+ """
+ return self._embedding_vector_column().get("embedding_dimension")
+
+ def _embedding_vector_column(self) -> dict:
+ """Return the embedding vector column configs as a dictionary.
+ Empty if the index is not a self-managed embedding index.
+ """
+ index_spec = (
+ self._delta_sync_index_spec
+ if self._is_delta_sync_index()
+ else self._direct_access_index_spec
+ )
+ return next(iter(index_spec.get("embedding_vector_columns") or list()), dict())
+
+ def _embedding_source_column_name(self) -> Optional[str]:
+ """Return the name of the embedding source column.
+ None if the index is not a Databricks-managed embedding index.
+ """
+ return self._embedding_source_column().get("name")
+
+ def _embedding_source_column(self) -> dict:
+ """Return the embedding source column configs as a dictionary.
+ Empty if the index is not a Databricks-managed embedding index.
+ """
+ index_spec = self._delta_sync_index_spec
+ return next(iter(index_spec.get("embedding_source_columns") or list()), dict())
+
+ def _is_delta_sync_index(self) -> bool:
+ """Return True if the index is a delta-sync index."""
+ return self.index_type == "DELTA_SYNC"
+
+ def _is_direct_access_index(self) -> bool:
+ """Return True if the index is a direct-access index."""
+ return self.index_type == "DIRECT_ACCESS"
+
+ def _is_databricks_managed_embeddings(self) -> bool:
+ """Return True if the embeddings are managed by Databricks Vector Search."""
+ return (
+ self._is_delta_sync_index()
+ and self._embedding_source_column_name() is not None
+ )
+
+ def _infer_embedding_dimension(self) -> int:
+ """Infer the embedding dimension from the embedding function."""
+ assert self.embeddings is not None, "embedding model is required."
+ return len(self.embeddings.embed_query("test"))
+
+ def _op_require_direct_access_index(self, op_name: str) -> None:
+ """
+        Raise ValueError if the index is not a direct-access index,
+        since the operation is only supported for direct-access indexes."""
+ if not self._is_direct_access_index():
+ raise ValueError(f"`{op_name}` is only supported for direct-access index.")
+
+ @staticmethod
+ def _require_arg(arg: Any, arg_name: str) -> None:
+ """Raise ValueError if the required arg with name `arg_name` is None."""
+ if not arg:
+ raise ValueError(f"`{arg_name}` is required for this index.")
diff --git a/libs/community/langchain_community/vectorstores/deeplake.py b/libs/community/langchain_community/vectorstores/deeplake.py
new file mode 100644
index 00000000000..7051988b92b
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/deeplake.py
@@ -0,0 +1,901 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+try:
+ import deeplake
+ from deeplake import VectorStore as DeepLakeVectorStore
+ from deeplake.core.fast_forwarding import version_compare
+
+ _DEEPLAKE_INSTALLED = True
+except ImportError:
+ _DEEPLAKE_INSTALLED = False
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+
+class DeepLake(VectorStore):
+ """`Activeloop Deep Lake` vector store.
+
+ We integrated deeplake's similarity search and filtering for fast prototyping.
+ Now, it supports Tensor Query Language (TQL) for production use cases
+    over billions of rows.
+
+ Why Deep Lake?
+
+ - Not only stores embeddings, but also the original data with version control.
+ - Serverless, doesn't require another service and can be used with major
+ cloud providers (S3, GCS, etc.)
+ - More than just a multi-modal vector store. You can use the dataset
+ to fine-tune your own LLM models.
+
+ To use, you should have the ``deeplake`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import DeepLake
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+            vectorstore = DeepLake("langchain_store", embedding=embeddings)
+ """
+
+ _LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/"
+
+ def __init__(
+ self,
+ dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH,
+ token: Optional[str] = None,
+ embedding: Optional[Embeddings] = None,
+ embedding_function: Optional[Embeddings] = None,
+ read_only: bool = False,
+ ingestion_batch_size: int = 1000,
+ num_workers: int = 0,
+ verbose: bool = True,
+ exec_option: Optional[str] = None,
+ runtime: Optional[Dict] = None,
+ index_params: Optional[Dict[str, Union[int, str]]] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Creates an empty DeepLakeVectorStore or loads an existing one.
+
+        The DeepLakeVectorStore is located at the specified ``dataset_path``.
+
+ Examples:
+ >>> # Create a vector store with default tensors
+ >>> deeplake_vectorstore = DeepLake(
+ ... path = ,
+ ... )
+ >>>
+ >>> # Create a vector store in the Deep Lake Managed Tensor Database
+ >>> data = DeepLake(
+ ... path = "hub://org_id/dataset_name",
+ ... runtime = {"tensor_db": True},
+ ... )
+
+ Args:
+ dataset_path (str): Path to existing dataset or where to create
+ a new one. Defaults to _LANGCHAIN_DEFAULT_DEEPLAKE_PATH.
+ token (str, optional): Activeloop token, for fetching credentials
+ to the dataset at path if it is a Deep Lake dataset.
+ Tokens are normally autogenerated. Optional.
+ embedding (Embeddings, optional): Function to convert
+ either documents or query. Optional.
+ embedding_function (Embeddings, optional): Function to convert
+ either documents or query. Optional. Deprecated: keeping this
+ parameter for backwards compatibility.
+ read_only (bool): Open dataset in read-only mode. Default is False.
+ ingestion_batch_size (int): During data ingestion, data is divided
+ into batches. Batch size is the size of each batch.
+ Default is 1000.
+ num_workers (int): Number of workers to use during data ingestion.
+ Default is 0.
+ verbose (bool): Print dataset summary after each operation.
+ Default is True.
+            exec_option (str, optional): DeepLakeVectorStore supports 4 ways to perform
+                searching - "auto", "python", "compute_engine" and "tensor_db".
+                Default is None.
+ - ``auto``- Selects the best execution method based on the storage
+ location of the Vector Store. It is the default option.
+ - ``python`` - Pure-python implementation that runs on the client.
+ WARNING: using this with big datasets can lead to memory
+ issues. Data can be stored anywhere.
+ - ``compute_engine`` - C++ implementation of the Deep Lake Compute
+ Engine that runs on the client. Can be used for any data stored in
+ or connected to Deep Lake. Not for in-memory or local datasets.
+ - ``tensor_db`` - Hosted Managed Tensor Database that is
+ responsible for storage and query execution. Only for data stored in
+ the Deep Lake Managed Database. Use runtime = {"db_engine": True}
+ during dataset creation.
+ runtime (Dict, optional): Parameters for creating the Vector Store in
+ Deep Lake's Managed Tensor Database. Not applicable when loading an
+ existing Vector Store. To create a Vector Store in the Managed Tensor
+ Database, set `runtime = {"tensor_db": True}`.
+ index_params (Optional[Dict[str, Union[int, str]]], optional): Dictionary
+ containing information about vector index that will be created. Defaults
+ to None, which will utilize ``DEFAULT_VECTORSTORE_INDEX_PARAMS`` from
+ ``deeplake.constants``. The specified key-values override the default
+ ones.
+ - threshold: The threshold for the dataset size above which an index
+ will be created for the embedding tensor. When the threshold value
+ is set to -1, index creation is turned off. Defaults to -1, which
+ turns off the index.
+ - distance_metric: This key specifies the method of calculating the
+ distance between vectors when creating the vector database (VDB)
+ index. It can either be a string that corresponds to a member of
+ the DistanceType enumeration, or the string value itself.
+ - If no value is provided, it defaults to "L2".
+ - "L2" corresponds to DistanceType.L2_NORM.
+ - "COS" corresponds to DistanceType.COSINE_SIMILARITY.
+ - additional_params: Additional parameters for fine-tuning the index.
+ **kwargs: Other optional keyword arguments.
+
+ Raises:
+ ValueError: If some condition is not met.
+ """
+
+ self.ingestion_batch_size = ingestion_batch_size
+ self.num_workers = num_workers
+ self.verbose = verbose
+
+ if _DEEPLAKE_INSTALLED is False:
+ raise ImportError(
+ "Could not import deeplake python package. "
+ "Please install it with `pip install deeplake[enterprise]`."
+ )
+
+ if (
+ runtime == {"tensor_db": True}
+ and version_compare(deeplake.__version__, "3.6.7") == -1
+ ):
+ raise ImportError(
+ "To use tensor_db option you need to update deeplake to `3.6.7` or "
+ "higher. "
+ f"Currently installed deeplake version is {deeplake.__version__}. "
+ )
+
+ self.dataset_path = dataset_path
+
+ if embedding_function:
+ logger.warning(
+ "Using embedding function is deprecated and will be removed "
+ "in the future. Please use embedding instead."
+ )
+
+ self.vectorstore = DeepLakeVectorStore(
+ path=self.dataset_path,
+ embedding_function=embedding_function or embedding,
+ read_only=read_only,
+ token=token,
+ exec_option=exec_option,
+ verbose=verbose,
+ runtime=runtime,
+ index_params=index_params,
+ **kwargs,
+ )
+
+ self._embedding_function = embedding_function or embedding
+ self._id_tensor_name = "ids" if "ids" in self.vectorstore.tensors() else "id"
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding_function
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Examples:
+ >>> ids = deeplake_vectorstore.add_texts(
+ ... texts = ,
+ ... metadatas = ,
+ ... ids = ,
+ ... )
+
+ Args:
+ texts (Iterable[str]): Texts to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ ids (Optional[List[str]], optional): Optional list of IDs.
+ embedding_function (Optional[Embeddings], optional): Embedding function
+ to use to convert the text into embeddings.
+            **kwargs (Any): Additional keyword arguments are not supported
+                by this method and will raise a TypeError if provided.
+
+ Returns:
+ List[str]: List of IDs of the added texts.
+ """
+ if kwargs:
+ unsupported_items = "`, `".join(set(kwargs.keys()))
+ raise TypeError(
+ f"`{unsupported_items}` is/are not a valid argument to add_text method"
+ )
+
+ kwargs = {}
+ if ids:
+ if self._id_tensor_name == "ids": # for backwards compatibility
+ kwargs["ids"] = ids
+ else:
+ kwargs["id"] = ids
+
+ if metadatas is None:
+ metadatas = [{}] * len(list(texts))
+
+ if not isinstance(texts, list):
+ texts = list(texts)
+
+ if texts is None:
+ raise ValueError("`texts` parameter shouldn't be None.")
+ elif len(texts) == 0:
+ raise ValueError("`texts` parameter shouldn't be empty.")
+
+ return self.vectorstore.add(
+ text=texts,
+ metadata=metadatas,
+ embedding_data=texts,
+ embedding_tensor="embedding",
+ embedding_function=self._embedding_function.embed_documents, # type: ignore
+ return_ids=True,
+ **kwargs,
+ )
+
+ def _search_tql(
+ self,
+ tql: Optional[str],
+ exec_option: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Function for performing tql_search.
+
+ Args:
+ tql (str): TQL Query string for direct evaluation.
+ Available only for `compute_engine` and `tensor_db`.
+ exec_option (str, optional): Supports 3 ways to search.
+ Could be "python", "compute_engine" or "tensor_db". Default is "python".
+ - ``python`` - Pure-python implementation for the client.
+ WARNING: not recommended for big datasets due to potential memory
+ issues.
+ - ``compute_engine`` - C++ implementation of Deep Lake Compute
+ Engine for the client. Not for in-memory or local datasets.
+ - ``tensor_db`` - Hosted Managed Tensor Database for storage
+ and query execution. Only for data in Deep Lake Managed Database.
+ Use runtime = {"db_engine": True} during dataset creation.
+            return_score (bool): Not supported for TQL search; must be False.
+
+        Returns:
+            List[Document]: Documents matching the TQL query.
+
+ Raises:
+            ValueError: If an unsupported keyword argument is passed with a truthy value.
+ """
+ result = self.vectorstore.search(
+ query=tql,
+ exec_option=exec_option,
+ )
+ metadatas = result["metadata"]
+ texts = result["text"]
+
+ docs = [
+ Document(
+ page_content=text,
+ metadata=metadata,
+ )
+ for text, metadata in zip(texts, metadatas)
+ ]
+
+ if kwargs:
+ unsupported_argument = next(iter(kwargs))
+ if kwargs[unsupported_argument] is not False:
+ raise ValueError(
+ f"specifying {unsupported_argument} is "
+ "not supported with tql search."
+ )
+
+ return docs
+
+ def _search(
+ self,
+ query: Optional[str] = None,
+ embedding: Optional[Union[List[float], np.ndarray]] = None,
+ embedding_function: Optional[Callable] = None,
+ k: int = 4,
+ distance_metric: Optional[str] = None,
+ use_maximal_marginal_relevance: bool = False,
+ fetch_k: Optional[int] = 20,
+ filter: Optional[Union[Dict, Callable]] = None,
+ return_score: bool = False,
+ exec_option: Optional[str] = None,
+ deep_memory: bool = False,
+ **kwargs: Any,
+    ) -> Union[List[Document], List[Tuple[Document, float]]]:
+ """
+ Return docs similar to query.
+
+ Args:
+ query (str, optional): Text to look up similar docs.
+ embedding (Union[List[float], np.ndarray], optional): Query's embedding.
+ embedding_function (Callable, optional): Function to convert `query`
+ into embedding.
+ k (int): Number of Documents to return.
+ distance_metric (Optional[str], optional): `L2` for Euclidean, `L1` for
+ Nuclear, `max` for L-infinity distance, `cos` for cosine similarity,
+ 'dot' for dot product.
+ filter (Union[Dict, Callable], optional): Additional filter prior
+ to the embedding search.
+ - ``Dict`` - Key-value search on tensors of htype json, on an
+ AND basis (a sample must satisfy all key-value filters to be True)
+ Dict = {"tensor_name_1": {"key": value},
+ "tensor_name_2": {"key": value}}
+ - ``Function`` - Any function compatible with `deeplake.filter`.
+ use_maximal_marginal_relevance (bool): Use maximal marginal relevance.
+ fetch_k (int): Number of Documents for MMR algorithm.
+ return_score (bool): Return the score.
+ exec_option (str, optional): Supports 3 ways to perform searching.
+ Could be "python", "compute_engine" or "tensor_db".
+ - ``python`` - Pure-python implementation for the client.
+ WARNING: not recommended for big datasets.
+ - ``compute_engine`` - C++ implementation of Deep Lake Compute
+ Engine for the client. Not for in-memory or local datasets.
+ - ``tensor_db`` - Hosted Managed Tensor Database for storage
+ and query execution. Only for data in Deep Lake Managed Database.
+ Use runtime = {"db_engine": True} during dataset creation.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified in
+ the Vector Store initialization. If True, the distance metric is set
+ to "deepmemory_distance", which represents the metric with which the
+ model was trained. The search is performed using the Deep Memory model.
+ If False, the distance metric is set to "COS" or whatever distance
+                metric the user specifies.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ List of Documents by the specified distance metric,
+ if return_score True, return a tuple of (Document, score)
+
+ Raises:
+ ValueError: if both `embedding` and `embedding_function` are not specified.
+ """
+
+ if kwargs.get("tql"):
+ return self._search_tql(
+ tql=kwargs["tql"],
+ exec_option=exec_option,
+ return_score=return_score,
+ embedding=embedding,
+ embedding_function=embedding_function,
+ distance_metric=distance_metric,
+ use_maximal_marginal_relevance=use_maximal_marginal_relevance,
+ filter=filter,
+ )
+
+ if embedding_function:
+ if isinstance(embedding_function, Embeddings):
+ _embedding_function = embedding_function.embed_query
+ else:
+ _embedding_function = embedding_function
+ elif self._embedding_function:
+ _embedding_function = self._embedding_function.embed_query
+ else:
+ _embedding_function = None
+
+ if embedding is None:
+ if _embedding_function is None:
+ raise ValueError(
+ "Either `embedding` or `embedding_function` needs to be"
+ " specified."
+ )
+
+ embedding = _embedding_function(query) if query else None
+
+ if isinstance(embedding, list):
+ embedding = np.array(embedding, dtype=np.float32)
+ if len(embedding.shape) > 1:
+ embedding = embedding[0]
+
+ result = self.vectorstore.search(
+ embedding=embedding,
+ k=fetch_k if use_maximal_marginal_relevance else k,
+ distance_metric=distance_metric,
+ filter=filter,
+ exec_option=exec_option,
+ return_tensors=["embedding", "metadata", "text", self._id_tensor_name],
+ deep_memory=deep_memory,
+ )
+
+ scores = result["score"]
+ embeddings = result["embedding"]
+ metadatas = result["metadata"]
+ texts = result["text"]
+
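+        # When MMR is requested, ``fetch_k`` candidates were retrieved above;
+        # re-rank them for diversity and keep only the top ``k``.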
+ if use_maximal_marginal_relevance:
+ lambda_mult = kwargs.get("lambda_mult", 0.5)
+ indices = maximal_marginal_relevance( # type: ignore
+ embedding, # type: ignore
+ embeddings,
+ k=min(k, len(texts)),
+ lambda_mult=lambda_mult,
+ )
+
+ scores = [scores[i] for i in indices]
+ texts = [texts[i] for i in indices]
+ metadatas = [metadatas[i] for i in indices]
+
+ docs = [
+ Document(
+ page_content=text,
+ metadata=metadata,
+ )
+ for text, metadata in zip(texts, metadatas)
+ ]
+
+ if return_score:
+ return [(doc, score) for doc, score in zip(docs, scores)]
+
+ return docs
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Return docs most similar to query.
+
+ Examples:
+ >>> # Search using an embedding
+ >>> data = vector_store.similarity_search(
+ ... query=,
+ ... k=,
+ ... exec_option=,
+ ... )
+ >>> # Run tql search:
+ >>> data = vector_store.similarity_search(
+ ... query=None,
+ ... tql="SELECT * WHERE id == ",
+ ... exec_option="compute_engine",
+ ... )
+
+ Args:
+ k (int): Number of Documents to return. Defaults to 4.
+ query (str): Text to look up similar documents.
+ **kwargs: Additional keyword arguments include:
+ embedding (Callable): Embedding function to use. Defaults to None.
+ distance_metric (str): 'L2' for Euclidean, 'L1' for Nuclear, 'max'
+ for L-infinity, 'cos' for cosine, 'dot' for dot product.
+ Defaults to 'L2'.
+ filter (Union[Dict, Callable], optional): Additional filter
+ before embedding search.
+ - Dict: Key-value search on tensors of htype json,
+ (sample must satisfy all key-value filters)
+ Dict = {"tensor_1": {"key": value}, "tensor_2": {"key": value}}
+ - Function: Compatible with `deeplake.filter`.
+ Defaults to None.
+ exec_option (str): Supports 3 ways to perform searching.
+ 'python', 'compute_engine', or 'tensor_db'. Defaults to 'python'.
+ - 'python': Pure-python implementation for the client.
+ WARNING: not recommended for big datasets.
+ - 'compute_engine': C++ implementation of the Compute Engine for
+ the client. Not for in-memory or local datasets.
+ - 'tensor_db': Managed Tensor Database for storage and query.
+ Only for data in Deep Lake Managed Database.
+ Use `runtime = {"db_engine": True}` during dataset creation.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified
+ in the Vector Store initialization. If True, the distance metric
+ is set to "deepmemory_distance", which represents the metric with
+ which the model was trained. The search is performed using the Deep
+ Memory model. If False, the distance metric is set to "COS" or
+                    whatever distance metric the user specifies.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query vector.
+ """
+
+ return self._search(
+ query=query,
+ k=k,
+ use_maximal_marginal_relevance=False,
+ return_score=False,
+ **kwargs,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: Union[List[float], np.ndarray],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Return docs most similar to embedding vector.
+
+ Examples:
+ >>> # Search using an embedding
+ >>> data = vector_store.similarity_search_by_vector(
+ ... embedding=,
+ ... k=,
+ ... exec_option=,
+ ... )
+
+ Args:
+ embedding (Union[List[float], np.ndarray]):
+ Embedding to find similar docs.
+ k (int): Number of Documents to return. Defaults to 4.
+ **kwargs: Additional keyword arguments including:
+ filter (Union[Dict, Callable], optional):
+ Additional filter before embedding search.
+ - ``Dict`` - Key-value search on tensors of htype json. True
+ if all key-value filters are satisfied.
+ Dict = {"tensor_name_1": {"key": value},
+ "tensor_name_2": {"key": value}}
+ - ``Function`` - Any function compatible with
+ `deeplake.filter`.
+ Defaults to None.
+ exec_option (str): Options for search execution include
+ "python", "compute_engine", or "tensor_db". Defaults to
+ "python".
+ - "python" - Pure-python implementation running on the client.
+ Can be used for data stored anywhere. WARNING: using this
+ option with big datasets is discouraged due to potential
+ memory issues.
+ - "compute_engine" - Performant C++ implementation of the Deep
+ Lake Compute Engine. Runs on the client and can be used for
+ any data stored in or connected to Deep Lake. It cannot be
+ used with in-memory or local datasets.
+ - "tensor_db" - Performant, fully-hosted Managed Tensor Database.
+ Responsible for storage and query execution. Only available
+ for data stored in the Deep Lake Managed Database.
+ To store datasets in this database, specify
+ `runtime = {"db_engine": True}` during dataset creation.
+ distance_metric (str): `L2` for Euclidean, `L1` for Nuclear,
+ `max` for L-infinity distance, `cos` for cosine similarity,
+ 'dot' for dot product. Defaults to `L2`.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified
+ in the Vector Store initialization. If True, the distance metric
+ is set to "deepmemory_distance", which represents the metric with
+ which the model was trained. The search is performed using the Deep
+ Memory model. If False, the distance metric is set to "COS" or
+ whatever distance metric the user specifies.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query vector.
+ """
+
+ return self._search(
+ embedding=embedding,
+ k=k,
+ use_maximal_marginal_relevance=False,
+ return_score=False,
+ **kwargs,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Run similarity search with Deep Lake with distance returned.
+
+ Examples:
+ >>> data = vector_store.similarity_search_with_score(
+ ... query=,
+ ... embedding=
+ ... k=,
+ ... exec_option=,
+ ... )
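+
+ For example (illustrative sketch; assumes the store was created with an
+ embedding function):
+
+ >>> docs_with_scores = vector_store.similarity_search_with_score(
+ ...     query="What does Deep Lake do?",
+ ...     k=4,
+ ...     distance_metric="cos",
+ ...     exec_option="python",
+ ... )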
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ **kwargs: Additional keyword arguments. Some of these arguments are:
+ distance_metric: `L2` for Euclidean, `L1` for Nuclear, `max` L-infinity
+ distance, `cos` for cosine similarity, 'dot' for dot product.
+ Defaults to `L2`.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ embedding_function (Callable): Embedding function to use. Defaults
+ to None.
+ exec_option (str): DeepLakeVectorStore supports 3 ways to perform
+ searching. It could be either "python", "compute_engine" or
+ "tensor_db". Defaults to "python".
+ - "python" - Pure-python implementation running on the client.
+ Can be used for data stored anywhere. WARNING: using this
+ option with big datasets is discouraged due to potential
+ memory issues.
+ - "compute_engine" - Performant C++ implementation of the Deep
+ Lake Compute Engine. Runs on the client and can be used for
+ any data stored in or connected to Deep Lake. It cannot be used
+ with in-memory or local datasets.
+ - "tensor_db" - Performant, fully-hosted Managed Tensor Database.
+ Responsible for storage and query execution. Only available for
+ data stored in the Deep Lake Managed Database. To store datasets
+ in this database, specify `runtime = {"db_engine": True}`
+ during dataset creation.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified
+ in the Vector Store initialization. If True, the distance metric
+ is set to "deepmemory_distance", which represents the metric with
+ which the model was trained. The search is performed using the Deep
+ Memory model. If False, the distance metric is set to "COS" or
+ whatever distance metric the user specifies.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to the query
+ text with distance in float."""
+
+ return self._search(
+ query=query,
+ k=k,
+ return_score=True,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ exec_option: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Return docs selected using the maximal marginal relevance. Maximal marginal
+ relevance optimizes for similarity to query AND diversity among selected docs.
+
+ Examples:
+ >>> data = vector_store.max_marginal_relevance_search_by_vector(
+ ... embedding=,
+ ... fetch_k=,
+ ... k=,
+ ... exec_option=,
+ ... )
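+
+ For instance (illustrative; the short vector stands in for a real embedding):
+
+ >>> data = vector_store.max_marginal_relevance_search_by_vector(
+ ...     embedding=[0.1, 0.2, 0.3],
+ ...     k=4,
+ ...     fetch_k=20,
+ ...     lambda_mult=0.5,
+ ...     exec_option="python",
+ ... )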
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch for MMR algorithm.
+ lambda_mult: Number between 0 and 1 determining the degree of diversity.
+ 0 corresponds to max diversity and 1 to min diversity. Defaults to 0.5.
+ exec_option (str): DeepLakeVectorStore supports 3 ways for searching.
+ Could be "python", "compute_engine" or "tensor_db". Defaults to
+ "python".
+ - "python" - Pure-python implementation running on the client.
+ Can be used for data stored anywhere. WARNING: using this
+ option with big datasets is discouraged due to potential
+ memory issues.
+ - "compute_engine" - Performant C++ implementation of the Deep
+ Lake Compute Engine. Runs on the client and can be used for
+ any data stored in or connected to Deep Lake. It cannot be used
+ with in-memory or local datasets.
+ - "tensor_db" - Performant, fully-hosted Managed Tensor Database.
+ Responsible for storage and query execution. Only available for
+ data stored in the Deep Lake Managed Database. To store datasets
+ in this database, specify `runtime = {"db_engine": True}`
+ during dataset creation.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified
+ in the Vector Store initialization. If True, the distance metric
+ is set to "deepmemory_distance", which represents the metric with
+ which the model was trained. The search is performed using the Deep
+ Memory model. If False, the distance metric is set to "COS" or
+ whatever distance metric the user specifies.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ List[Document]: List of Documents selected by maximal marginal relevance.
+ """
+
+ return self._search(
+ embedding=embedding,
+ k=k,
+ fetch_k=fetch_k,
+ use_maximal_marginal_relevance=True,
+ lambda_mult=lambda_mult,
+ exec_option=exec_option,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ exec_option: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Examples:
+ >>> # Search using an embedding
+ >>> data = vector_store.max_marginal_relevance_search(
+ ... query = ,
+ ... embedding_function = ,
+ ... k = ,
+ ... exec_option = ,
+ ... )
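+
+ A concrete sketch (illustrative; assumes an embedding function was provided
+ at creation or is passed via the ``embedding`` keyword argument):
+
+ >>> data = vector_store.max_marginal_relevance_search(
+ ...     query="What does Deep Lake do?",
+ ...     k=4,
+ ...     fetch_k=20,
+ ...     lambda_mult=0.5,
+ ...     exec_option="python",
+ ... )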
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents for MMR algorithm.
+ lambda_mult: Value between 0 and 1. 0 corresponds
+ to maximum diversity and 1 to minimum.
+ Defaults to 0.5.
+ exec_option (str): Supports 3 ways to perform searching.
+ - "python" - Pure-python implementation running on the client.
+ Can be used for data stored anywhere. WARNING: using this
+ option with big datasets is discouraged due to potential
+ memory issues.
+ - "compute_engine" - Performant C++ implementation of the Deep
+ Lake Compute Engine. Runs on the client and can be used for
+ any data stored in or connected to Deep Lake. It cannot be
+ used with in-memory or local datasets.
+ - "tensor_db" - Performant, fully-hosted Managed Tensor Database.
+ Responsible for storage and query execution. Only available
+ for data stored in the Deep Lake Managed Database. To store
+ datasets in this database, specify
+ `runtime = {"db_engine": True}` during dataset creation.
+ deep_memory (bool): Whether to use the Deep Memory model for improving
+ search results. Defaults to False if deep_memory is not specified
+ in the Vector Store initialization. If True, the distance metric
+ is set to "deepmemory_distance", which represents the metric with
+ which the model was trained. The search is performed using the Deep
+ Memory model. If False, the distance metric is set to "COS" or
+ whatever distance metric the user specifies.
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+
+ Raises:
+ ValueError: when MMR search is enabled but no embedding function
+ is specified.
+ """
+ embedding_function = kwargs.get("embedding") or self._embedding_function
+ if embedding_function is None:
+ raise ValueError(
+ "For MMR search, you must specify an embedding function on"
+ " `creation` or during add call."
+ )
+ return self._search(
+ query=query,
+ k=k,
+ fetch_k=fetch_k,
+ use_maximal_marginal_relevance=True,
+ lambda_mult=lambda_mult,
+ exec_option=exec_option,
+ embedding_function=embedding_function, # type: ignore
+ **kwargs,
+ )
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH,
+ **kwargs: Any,
+ ) -> DeepLake:
+ """Create a Deep Lake dataset from a raw documents.
+
+ If a dataset_path is specified, the dataset will be persisted in that location,
+ otherwise by default at `./deeplake`
+
+ Examples:
+ >>> # Search using an embedding
+ >>> vector_store = DeepLake.from_texts(
+ ... texts = ,
+ ... embedding_function = ,
+ ... k = ,
+ ... exec_option = ,
+ ... )
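+
+ A concrete sketch (illustrative; the texts, dataset path, and embedding
+ object are assumptions):
+
+ >>> from langchain_community.embeddings.openai import OpenAIEmbeddings
+ >>> vector_store = DeepLake.from_texts(
+ ...     texts=["harrison worked at kensho"],
+ ...     embedding=OpenAIEmbeddings(),
+ ...     dataset_path="./my_deeplake",
+ ... )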
+
+ Args:
+ dataset_path (str): The full path to the dataset. Can be:
+ - Deep Lake cloud path of the form ``hub://username/dataset_name``.
+ To write to Deep Lake cloud datasets,
+ ensure that you are logged in to Deep Lake
+ (use 'activeloop login' from command line)
+ - AWS S3 path of the form ``s3://bucketname/path/to/dataset``.
+ Credentials are required, either in the environment or passed
+ explicitly to Deep Lake.
+ - Google Cloud Storage path of the form
+ ``gcs://bucketname/path/to/dataset``. Credentials are required,
+ either in the environment or passed explicitly to Deep Lake.
+ - Local file system path of the form ``./path/to/dataset`` or
+ ``~/path/to/dataset`` or ``path/to/dataset``.
+ - In-memory path of the form ``mem://path/to/dataset`` which doesn't
+ save the dataset, but keeps it in memory instead.
+ Should be used only for testing as it does not persist.
+ texts (List[str]): List of texts to add.
+ embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+ Note, in other places, it is called embedding_function.
+ metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
+ ids (Optional[List[str]]): List of document IDs. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ DeepLake: Deep Lake dataset.
+ """
+ deeplake_dataset = cls(dataset_path=dataset_path, embedding=embedding, **kwargs)
+ deeplake_dataset.add_texts(
+ texts=texts,
+ metadatas=metadatas,
+ ids=ids,
+ )
+ return deeplake_dataset
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool:
+ """Delete the entities in the dataset.
+
+ Args:
+ ids (Optional[List[str]], optional): The document_ids to delete.
+ Defaults to None.
+ **kwargs: Other keyword arguments that subclasses might use.
+ - filter (Optional[Dict[str, str]], optional): The filter to delete by.
+ - delete_all (Optional[bool], optional): Whether to drop the dataset.
+
+ Returns:
+ bool: Whether the delete operation was successful.
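+
+ Example (illustrative; the document ids are hypothetical):
+
+ >>> vector_store.delete(ids=["doc_id_1", "doc_id_2"])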
+ """
+ filter = kwargs.get("filter")
+ delete_all = kwargs.get("delete_all")
+
+ self.vectorstore.delete(ids=ids, filter=filter, delete_all=delete_all)
+
+ return True
+
+ @classmethod
+ def force_delete_by_path(cls, path: str) -> None:
+ """Force delete dataset by path.
+
+ Args:
+ path (str): path of the dataset to delete.
+
+ Raises:
+ ValueError: if deeplake is not installed.
+ """
+
+ try:
+ import deeplake
+ except ImportError:
+ raise ValueError(
+ "Could not import deeplake python package. "
+ "Please install it with `pip install deeplake`."
+ )
+ deeplake.delete(path, large_ok=True, force=True)
+
+ def delete_dataset(self) -> None:
+ """Delete the collection."""
+ self.delete(delete_all=True)
+
+ def ds(self) -> Any:
+ logger.warning(
+ "this method is deprecated and will be removed, "
+ "better to use `db.vectorstore.dataset` instead."
+ )
+ return self.vectorstore.dataset
diff --git a/libs/community/langchain_community/vectorstores/dingo.py b/libs/community/langchain_community/vectorstores/dingo.py
new file mode 100644
index 00000000000..57a61982222
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/dingo.py
@@ -0,0 +1,376 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+
+class Dingo(VectorStore):
+ """`Dingo` vector store.
+
+ To use, you should have the ``dingodb`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Dingo
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ dingo = Dingo(embeddings, "text")
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ text_key: str,
+ *,
+ client: Any = None,
+ index_name: Optional[str] = None,
+ dimension: int = 1024,
+ host: Optional[List[str]] = None,
+ user: str = "root",
+ password: str = "123123",
+ self_id: bool = False,
+ ):
+ """Initialize with Dingo client."""
+ try:
+ import dingodb
+ except ImportError:
+ raise ImportError(
+ "Could not import dingo python package. "
+ "Please install it with `pip install dingodb."
+ )
+
+ host = host if host is not None else ["172.20.31.10:13000"]
+
+ # collection
+ if client is not None:
+ dingo_client = client
+ else:
+ try:
+ # connect to dingo db
+ dingo_client = dingodb.DingoDB(user, password, host)
+ except ValueError as e:
+ raise ValueError(f"Dingo failed to connect: {e}")
+
+ self._text_key = text_key
+ self._client = dingo_client
+
+ if (
+ index_name is not None
+ and index_name not in dingo_client.get_index()
+ and index_name.upper() not in dingo_client.get_index()
+ ):
+ if self_id is True:
+ dingo_client.create_index(
+ index_name, dimension=dimension, auto_id=False
+ )
+ else:
+ dingo_client.create_index(index_name, dimension=dimension)
+
+ self._index_name = index_name
+ self._embedding = embedding
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ text_key: str = "text",
+ batch_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
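+ Example (illustrative sketch; assumes ``dingo`` was created as in the
+ class docstring):
+
+ .. code-block:: python
+
+ ids = dingo.add_texts(
+ ["foo", "bar"],
+ metadatas=[{"source": "a"}, {"source": "b"}],
+ )
+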
+ """
+
+ # Embed and create the documents
+ ids = ids or [str(uuid.uuid1().int)[:13] for _ in texts]
+ metadatas_list = []
+ texts = list(texts)
+ embeds = self._embedding.embed_documents(texts)
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ metadata[self._text_key] = text
+ metadatas_list.append(metadata)
+ # upsert to Dingo
+ for i in range(0, len(list(texts)), batch_size):
+ j = i + batch_size
+ add_res = self._client.vector_add(
+ self._index_name, metadatas_list[i:j], embeds[i:j], ids[i:j]
+ )
+ if not add_res:
+ raise Exception("vector add fail")
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ search_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return Dingo documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_params: Dictionary of argument(s) to filter on metadata
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k=k, search_params=search_params
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ search_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return Dingo documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ search_params: Dictionary of argument(s) to filter on metadata
+
+ Returns:
+ List of Documents most similar to the query and score for each
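+
+ Example (illustrative; assumes ``dingo`` was created as in the class
+ docstring and texts were added with ``add_texts``):
+
+ .. code-block:: python
+
+ docs_and_scores = dingo.similarity_search_with_score("foo", k=4)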
+ """
+ docs = []
+ query_obj = self._embedding.embed_query(query)
+ results = self._client.vector_search(
+ self._index_name, xq=query_obj, top_k=k, search_params=search_params
+ )
+
+ if not results:
+ return []
+
+ for res in results[0]["vectorWithDistances"]:
+ metadatas = res["scalarData"]
+ id = res["id"]
+ score = res["distance"]
+ text = metadatas[self._text_key]["fields"][0]["data"]
+ metadata = {"id": id, "text": text, "score": score}
+ for meta_key in metadatas.keys():
+ metadata[meta_key] = metadatas[meta_key]["fields"][0]["data"]
+ docs.append((Document(page_content=text, metadata=metadata), score))
+
+ return docs
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ search_params: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ results = self._client.vector_search(
+ self._index_name, [embedding], search_params=search_params, top_k=k
+ )
+
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ [
+ item["vector"]["floatValues"]
+ for item in results[0]["vectorWithDistances"]
+ ],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ selected = []
+ for i in mmr_selected:
+ meta_data = {}
+ for key, value in results[0]["vectorWithDistances"][i]["scalarData"].items():
+ meta_data.update({str(key): value["fields"][0]["data"]})
+ selected.append(meta_data)
+ return [
+ Document(page_content=metadata.pop(self._text_key), metadata=metadata)
+ for metadata in selected
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ search_params: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
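+
+ Example (illustrative; assumes ``dingo`` was created as in the class
+ docstring):
+
+ .. code-block:: python
+
+ docs = dingo.max_marginal_relevance_search("foo", k=4, fetch_k=20)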
+ """
+ embedding = self._embedding.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult, search_params
+ )
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ text_key: str = "text",
+ index_name: Optional[str] = None,
+ dimension: int = 1024,
+ client: Any = None,
+ host: List[str] = ["172.20.31.10:13000"],
+ user: str = "root",
+ password: str = "123123",
+ batch_size: int = 500,
+ **kwargs: Any,
+ ) -> Dingo:
+ """Construct Dingo wrapper from raw documents.
+
+ This is a user friendly interface that:
+ 1. Embeds documents.
+ 2. Adds the documents to a provided Dingo index
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Dingo
+ from langchain_community.embeddings import OpenAIEmbeddings
+ import dingodb
+ embeddings = OpenAIEmbeddings()
+ dingo = Dingo.from_texts(
+ texts,
+ embeddings,
+ index_name="langchain-demo"
+ )
+ """
+ try:
+ import dingodb
+ except ImportError:
+ raise ImportError(
+ "Could not import dingo python package. "
+ "Please install it with `pip install dingodb`."
+ )
+
+ if client is not None:
+ dingo_client = client
+ else:
+ try:
+ # connect to dingo db
+ dingo_client = dingodb.DingoDB(user, password, host)
+ except ValueError as e:
+ raise ValueError(f"Dingo failed to connect: {e}")
+ if kwargs is not None and kwargs.get("self_id") is True:
+ if (
+ index_name is not None
+ and index_name not in dingo_client.get_index()
+ and index_name.upper() not in dingo_client.get_index()
+ ):
+ dingo_client.create_index(
+ index_name, dimension=dimension, auto_id=False
+ )
+ else:
+ if (
+ index_name is not None
+ and index_name not in dingo_client.get_index()
+ and index_name.upper() not in dingo_client.get_index()
+ ):
+ dingo_client.create_index(index_name, dimension=dimension)
+
+ # Embed and create the documents
+
+ ids = ids or [str(uuid.uuid1().int)[:13] for _ in texts]
+ metadatas_list = []
+ texts = list(texts)
+ embeds = embedding.embed_documents(texts)
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ metadata[text_key] = text
+ metadatas_list.append(metadata)
+
+ # upsert to Dingo
+ for i in range(0, len(list(texts)), batch_size):
+ j = i + batch_size
+ add_res = dingo_client.vector_add(
+ index_name, metadatas_list[i:j], embeds[i:j], ids[i:j]
+ )
+ if not add_res:
+ raise Exception("vector add fail")
+ return cls(embedding, text_key, client=dingo_client, index_name=index_name)
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Any:
+ """Delete by vector IDs or filter.
+ Args:
+ ids: List of ids to delete.
+ """
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ return self._client.vector_delete(self._index_name, ids=ids)
diff --git a/libs/community/langchain_community/vectorstores/docarray/__init__.py b/libs/community/langchain_community/vectorstores/docarray/__init__.py
new file mode 100644
index 00000000000..b5877fec88b
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/docarray/__init__.py
@@ -0,0 +1,7 @@
+from langchain_community.vectorstores.docarray.hnsw import DocArrayHnswSearch
+from langchain_community.vectorstores.docarray.in_memory import DocArrayInMemorySearch
+
+__all__ = [
+ "DocArrayHnswSearch",
+ "DocArrayInMemorySearch",
+]
diff --git a/libs/community/langchain_community/vectorstores/docarray/base.py b/libs/community/langchain_community/vectorstores/docarray/base.py
new file mode 100644
index 00000000000..bfb00b7af0f
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/docarray/base.py
@@ -0,0 +1,203 @@
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import Field
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ from docarray import BaseDoc
+ from docarray.index.abstract import BaseDocIndex
+
+
+def _check_docarray_import() -> None:
+ try:
+ import docarray
+
+ da_version = docarray.__version__.split(".")
+ if int(da_version[0]) == 0 and int(da_version[1]) <= 31:
+ raise ImportError(
+ f"To use the DocArrayHnswSearch VectorStore the docarray "
+ f"version >=0.32.0 is expected, received: {docarray.__version__}."
+ f"To upgrade, please run: `pip install -U docarray`."
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import docarray python package. "
+ 'Please install it with `pip install "langchain[docarray]"`.'
+ )
+
+
+class DocArrayIndex(VectorStore, ABC):
+ """Base class for `DocArray` based vector stores."""
+
+ def __init__(
+ self,
+ doc_index: "BaseDocIndex",
+ embedding: Embeddings,
+ ):
+ """Initialize a vector store from DocArray's DocIndex."""
+ self.doc_index = doc_index
+ self.embedding = embedding
+
+ @staticmethod
+ def _get_doc_cls(**embeddings_params: Any) -> Type["BaseDoc"]:
+ """Get docarray Document class describing the schema of DocIndex."""
+ from docarray import BaseDoc
+ from docarray.typing import NdArray
+
+ class DocArrayDoc(BaseDoc):
+ text: Optional[str]
+ embedding: Optional[NdArray] = Field(**embeddings_params)
+ metadata: Optional[dict]
+
+ return DocArrayDoc
+
+ @property
+ def doc_cls(self) -> Type["BaseDoc"]:
+ if self.doc_index._schema is None:
+ raise ValueError("doc_index expected to have non-null _schema attribute.")
+ return self.doc_index._schema
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Embed texts and add to the vector store.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ ids: List[str] = []
+ embeddings = self.embedding.embed_documents(list(texts))
+ for i, (t, e) in enumerate(zip(texts, embeddings)):
+ m = metadatas[i] if metadatas else {}
+ doc = self.doc_cls(text=t, embedding=e, metadata=m)
+ self.doc_index.index([doc])
+ ids.append(str(doc.id))
+
+ return ids
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of documents most similar to the query text and
+ cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ query_embedding = self.embedding.embed_query(query)
+ query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
+ docs, scores = self.doc_index.find(query_doc, search_field="embedding", limit=k)
+
+ result = [
+ (Document(page_content=doc.text, metadata=doc.metadata), score)
+ for doc, score in zip(docs, scores)
+ ]
+ return result
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ results = self.similarity_search_with_score(query, k=k, **kwargs)
+ return [doc for doc, _ in results]
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and relevance scores, normalized on a scale from 0 to 1.
+
+ 0 is dissimilar, 1 is most similar.
+ """
+ raise NotImplementedError()
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+
+ query_doc = self.doc_cls(embedding=embedding) # type: ignore
+ docs = self.doc_index.find(
+ query_doc, search_field="embedding", limit=k
+ ).documents
+
+ result = [
+ Document(page_content=doc.text, metadata=doc.metadata) for doc in docs
+ ]
+ return result
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ query_embedding = self.embedding.embed_query(query)
+ query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
+
+ docs = self.doc_index.find(
+ query_doc, search_field="embedding", limit=fetch_k
+ ).documents
+
+ mmr_selected = maximal_marginal_relevance(
+ np.array(query_embedding), docs.embedding, k=k
+ )
+ results = [
+ Document(page_content=docs[idx].text, metadata=docs[idx].metadata)
+ for idx in mmr_selected
+ ]
+ return results
diff --git a/libs/community/langchain_community/vectorstores/docarray/hnsw.py b/libs/community/langchain_community/vectorstores/docarray/hnsw.py
new file mode 100644
index 00000000000..43948471849
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/docarray/hnsw.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+from typing import Any, List, Literal, Optional
+
+from langchain_core.embeddings import Embeddings
+
+from langchain_community.vectorstores.docarray.base import (
+ DocArrayIndex,
+ _check_docarray_import,
+)
+
+
+class DocArrayHnswSearch(DocArrayIndex):
+ """`HnswLib` storage using `DocArray` package.
+
+ To use it, you should have the ``docarray`` package with version >=0.32.0 installed.
+ You can install it with `pip install "langchain[docarray]"`.
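+
+ Example (illustrative sketch; the texts and embedding object are assumptions,
+ and ``n_dim`` must match the embedding dimensionality):
+
+ .. code-block:: python
+
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores.docarray import DocArrayHnswSearch
+
+ store = DocArrayHnswSearch.from_texts(
+ texts=["foo", "bar"],
+ embedding=OpenAIEmbeddings(),
+ work_dir="./hnswlib_index",
+ n_dim=1536,
+ )
+ docs = store.similarity_search("foo", k=1)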
+ """
+
+ @classmethod
+ def from_params(
+ cls,
+ embedding: Embeddings,
+ work_dir: str,
+ n_dim: int,
+ dist_metric: Literal["cosine", "ip", "l2"] = "cosine",
+ max_elements: int = 1024,
+ index: bool = True,
+ ef_construction: int = 200,
+ ef: int = 10,
+ M: int = 16,
+ allow_replace_deleted: bool = True,
+ num_threads: int = 1,
+ **kwargs: Any,
+ ) -> DocArrayHnswSearch:
+ """Initialize DocArrayHnswSearch store.
+
+ Args:
+ embedding (Embeddings): Embedding function.
+ work_dir (str): path to the location where all the data will be stored.
+ n_dim (int): dimension of an embedding.
+ dist_metric (str): Distance metric for DocArrayHnswSearch can be one of:
+ "cosine", "ip", and "l2". Defaults to "cosine".
+ max_elements (int): Maximum number of vectors that can be stored.
+ Defaults to 1024.
+ index (bool): Whether an index should be built for this field.
+ Defaults to True.
+ ef_construction (int): defines a construction time/accuracy trade-off.
+ Defaults to 200.
+ ef (int): parameter controlling query time/accuracy trade-off.
+ Defaults to 10.
+ M (int): parameter that defines the maximum number of outgoing
+ connections in the graph. Defaults to 16.
+ allow_replace_deleted (bool): Enables replacing of deleted elements
+ with new added ones. Defaults to True.
+ num_threads (int): Sets the number of cpu threads to use. Defaults to 1.
+ **kwargs: Other keyword arguments to be passed to the get_doc_cls method.
+ """
+ _check_docarray_import()
+ from docarray.index import HnswDocumentIndex
+
+ doc_cls = cls._get_doc_cls(
+ dim=n_dim,
+ space=dist_metric,
+ max_elements=max_elements,
+ index=index,
+ ef_construction=ef_construction,
+ ef=ef,
+ M=M,
+ allow_replace_deleted=allow_replace_deleted,
+ num_threads=num_threads,
+ **kwargs,
+ )
+ doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir) # type: ignore
+ return cls(doc_index, embedding)
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ work_dir: Optional[str] = None,
+ n_dim: Optional[int] = None,
+ **kwargs: Any,
+ ) -> DocArrayHnswSearch:
+ """Create an DocArrayHnswSearch store and insert data.
+
+
+ Args:
+ texts (List[str]): Text data.
+ embedding (Embeddings): Embedding function.
+ metadatas (Optional[List[dict]]): Metadata for each text if it exists.
+ Defaults to None.
+ work_dir (str): path to the location where all the data will be stored.
+ n_dim (int): dimension of an embedding.
+ **kwargs: Other keyword arguments to be passed to the __init__ method.
+
+ Returns:
+ DocArrayHnswSearch Vector Store
+ """
+ if work_dir is None:
+ raise ValueError("`work_dir` parameter has not been set.")
+ if n_dim is None:
+ raise ValueError("`n_dim` parameter has not been set.")
+
+ store = cls.from_params(embedding, work_dir, n_dim, **kwargs)
+ store.add_texts(texts=texts, metadatas=metadatas)
+ return store
diff --git a/libs/community/langchain_community/vectorstores/docarray/in_memory.py b/libs/community/langchain_community/vectorstores/docarray/in_memory.py
new file mode 100644
index 00000000000..468e5be20f5
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/docarray/in_memory.py
@@ -0,0 +1,70 @@
+"""Wrapper around in-memory storage."""
+from __future__ import annotations
+
+from typing import Any, Dict, List, Literal, Optional
+
+from langchain_core.embeddings import Embeddings
+
+from langchain_community.vectorstores.docarray.base import (
+ DocArrayIndex,
+ _check_docarray_import,
+)
+
+
+class DocArrayInMemorySearch(DocArrayIndex):
+ """In-memory `DocArray` storage for exact search.
+
+ To use it, you should have the ``docarray`` package with version >=0.32.0 installed.
+ You can install it with `pip install "langchain[docarray]"`.
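+
+ Example (illustrative sketch; the texts and embedding object are assumptions):
+
+ .. code-block:: python
+
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores.docarray import DocArrayInMemorySearch
+
+ store = DocArrayInMemorySearch.from_texts(
+ ["foo", "bar"], OpenAIEmbeddings(), metric="cosine_sim"
+ )
+ docs = store.similarity_search("foo", k=1)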
+ """
+
+ @classmethod
+ def from_params(
+ cls,
+ embedding: Embeddings,
+ metric: Literal[
+ "cosine_sim", "euclidian_dist", "sgeuclidean_dist"
+ ] = "cosine_sim",
+ **kwargs: Any,
+ ) -> DocArrayInMemorySearch:
+ """Initialize DocArrayInMemorySearch store.
+
+ Args:
+ embedding (Embeddings): Embedding function.
+ metric (str): metric for exact nearest-neighbor search.
+ Can be one of: "cosine_sim", "euclidean_dist" and "sqeuclidean_dist".
+ Defaults to "cosine_sim".
+ **kwargs: Other keyword arguments to be passed to the get_doc_cls method.
+ """
+ _check_docarray_import()
+ from docarray.index import InMemoryExactNNIndex
+
+ doc_cls = cls._get_doc_cls(space=metric, **kwargs)
+ doc_index = InMemoryExactNNIndex[doc_cls]() # type: ignore
+ return cls(doc_index, embedding)
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ **kwargs: Any,
+ ) -> DocArrayInMemorySearch:
+ """Create an DocArrayInMemorySearch store and insert data.
+
+ Args:
+ texts (List[str]): Text data.
+ embedding (Embeddings): Embedding function.
+ metadatas (Optional[List[Dict[Any, Any]]]): Metadata for each text
+ if it exists. Defaults to None.
+ metric (str): metric for exact nearest-neighbor search.
+ Can be one of: "cosine_sim", "euclidean_dist" and "sqeuclidean_dist".
+ Defaults to "cosine_sim".
+
+ Returns:
+ DocArrayInMemorySearch Vector Store
+ """
+ store = cls.from_params(embedding, **kwargs)
+ store.add_texts(texts=texts, metadatas=metadatas)
+ return store
diff --git a/libs/community/langchain_community/vectorstores/elastic_vector_search.py b/libs/community/langchain_community/vectorstores/elastic_vector_search.py
new file mode 100644
index 00000000000..058bde90d8b
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/elastic_vector_search.py
@@ -0,0 +1,798 @@
+from __future__ import annotations
+
+import uuid
+import warnings
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
+
+from langchain_core._api import deprecated
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from elasticsearch import Elasticsearch
+
+
+def _default_text_mapping(dim: int) -> Dict:
+ return {
+ "properties": {
+ "text": {"type": "text"},
+ "vector": {"type": "dense_vector", "dims": dim},
+ }
+ }
+
+
+def _default_script_query(query_vector: List[float], filter: Optional[dict]) -> Dict:
+ if filter:
+ ((key, value),) = filter.items()
+ filter = {"match": {f"metadata.{key}.keyword": f"{value}"}}
+ else:
+ filter = {"match_all": {}}
+ return {
+ "script_score": {
+ "query": filter,
+ "script": {
+ "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
+ "params": {"query_vector": query_vector},
+ },
+ }
+ }
+
+
+class ElasticVectorSearch(VectorStore):
+ """
+
+ ElasticVectorSearch uses the brute force method of searching on vectors.
+
+ It is recommended to use ElasticsearchStore instead, which gives you the
+ option to use the approximate HNSW algorithm, which performs better on
+ large datasets.
+
+ ElasticsearchStore also supports metadata filtering, customising the
+ query retriever and much more!
+
+ You can read more on ElasticsearchStore:
+ https://python.langchain.com/docs/integrations/vectorstores/elasticsearch
+
+ To connect to an `Elasticsearch` instance that does not require
+ login credentials, pass the Elasticsearch URL and index name along with the
+ embedding object to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embedding = OpenAIEmbeddings()
+ elastic_vector_search = ElasticVectorSearch(
+ elasticsearch_url="http://localhost:9200",
+ index_name="test_index",
+ embedding=embedding
+ )
+
+
+ To connect to an Elasticsearch instance that requires login credentials,
+ including Elastic Cloud, use the Elasticsearch URL format
+ https://username:password@es_host:9243. For example, to connect to Elastic
+ Cloud, create the Elasticsearch URL with the required authentication details and
+ pass it to the ElasticVectorSearch constructor as the named parameter
+ elasticsearch_url.
+
+ You can obtain your Elastic Cloud URL and login credentials by logging in to the
+ Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
+ navigating to the "Deployments" page.
+
+ To obtain your Elastic Cloud password for the default "elastic" user:
+
+ 1. Log in to the Elastic Cloud console at https://cloud.elastic.co
+ 2. Go to "Security" > "Users"
+ 3. Locate the "elastic" user and click "Edit"
+ 4. Click "Reset password"
+ 5. Follow the prompts to reset the password
+
+ The format for Elastic Cloud URLs is
+ https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embedding = OpenAIEmbeddings()
+
+ elastic_host = "cluster_id.region_id.gcp.cloud.es.io"
+ elasticsearch_url = f"https://username:password@{elastic_host}:9243"
+ elastic_vector_search = ElasticVectorSearch(
+ elasticsearch_url=elasticsearch_url,
+ index_name="test_index",
+ embedding=embedding
+ )
+
+ Args:
+ elasticsearch_url (str): The URL for the Elasticsearch instance.
+ index_name (str): The name of the Elasticsearch index for the embeddings.
+ embedding (Embeddings): An object that provides the ability to embed text.
+ It should be an instance of a class that subclasses the Embeddings
+ abstract base class, such as OpenAIEmbeddings()
+
+ Raises:
+ ValueError: If the elasticsearch python package is not installed.
+ """
+
+ def __init__(
+ self,
+ elasticsearch_url: str,
+ index_name: str,
+ embedding: Embeddings,
+ *,
+ ssl_verify: Optional[Dict[str, Any]] = None,
+ ):
+ """Initialize with necessary components."""
+ warnings.warn(
+ "ElasticVectorSearch will be removed in a future release. See"
+ "Elasticsearch integration docs on how to upgrade."
+ )
+
+ try:
+ import elasticsearch
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+ self.embedding = embedding
+ self.index_name = index_name
+ _ssl_verify = ssl_verify or {}
+ try:
+ self.client = elasticsearch.Elasticsearch(
+ elasticsearch_url,
+ **_ssl_verify,
+ headers={"user-agent": self.get_user_agent()},
+ )
+ except ValueError as e:
+ raise ValueError(
+ f"Your elasticsearch client string is mis-formatted. Got error: {e} "
+ )
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain-py-dvs/{__version__}"
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ refresh_indices: bool = True,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+ refresh_indices: Whether to refresh the Elasticsearch indices. Defaults to True.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
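+
+ Example (illustrative; assumes the ``elastic_vector_search`` instance from
+ the class docstring):
+
+ .. code-block:: python
+
+ ids = elastic_vector_search.add_texts(
+ ["Hello world!"], metadatas=[{"source": "greeting"}]
+ )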
+ """
+ try:
+ from elasticsearch.exceptions import NotFoundError
+ from elasticsearch.helpers import bulk
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+ requests = []
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ embeddings = self.embedding.embed_documents(list(texts))
+ dim = len(embeddings[0])
+ mapping = _default_text_mapping(dim)
+
+ # check to see if the index already exists
+ try:
+ self.client.indices.get(index=self.index_name)
+ except NotFoundError:
+ # TODO would be nice to create index before embedding,
+ # just to save expensive steps for last
+ self.create_index(self.client, self.index_name, mapping)
+
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ request = {
+ "_op_type": "index",
+ "_index": self.index_name,
+ "vector": embeddings[i],
+ "text": text,
+ "metadata": metadata,
+ "_id": ids[i],
+ }
+ requests.append(request)
+ bulk(self.client, requests)
+
+ if refresh_indices:
+ self.client.indices.refresh(index=self.index_name)
+ return ids
+
+ def similarity_search(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
+ documents = [d[0] for d in docs_and_scores]
+ return documents
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ Returns:
+ List of Documents most similar to the query.
+ """
+ embedding = self.embedding.embed_query(query)
+ script_query = _default_script_query(embedding, filter)
+ response = self.client_search(
+ self.client, self.index_name, script_query, size=k
+ )
+ hits = [hit for hit in response["hits"]["hits"]]
+ docs_and_scores = [
+ (
+ Document(
+ page_content=hit["_source"]["text"],
+ metadata=hit["_source"]["metadata"],
+ ),
+ hit["_score"],
+ )
+ for hit in hits
+ ]
+ return docs_and_scores
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ index_name: Optional[str] = None,
+ refresh_indices: bool = True,
+ **kwargs: Any,
+ ) -> ElasticVectorSearch:
+ """Construct ElasticVectorSearch wrapper from raw documents.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Creates a new index for the embeddings in the Elasticsearch instance.
+ 3. Adds the documents to the newly created Elasticsearch index.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ elastic_vector_search = ElasticVectorSearch.from_texts(
+ texts,
+ embeddings,
+ elasticsearch_url="http://localhost:9200"
+ )
+ """
+ elasticsearch_url = get_from_dict_or_env(
+ kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
+ )
+ if "elasticsearch_url" in kwargs:
+ del kwargs["elasticsearch_url"]
+ index_name = index_name or uuid.uuid4().hex
+ vectorsearch = cls(elasticsearch_url, index_name, embedding, **kwargs)
+ vectorsearch.add_texts(
+ texts, metadatas=metadatas, ids=ids, refresh_indices=refresh_indices
+ )
+ return vectorsearch
+
+ def create_index(self, client: Any, index_name: str, mapping: Dict) -> None:
+ version_num = client.info()["version"]["number"][0]
+ version_num = int(version_num)
+ if version_num >= 8:
+ client.indices.create(index=index_name, mappings=mapping)
+ else:
+ client.indices.create(index=index_name, body={"mappings": mapping})
+
+ def client_search(
+ self, client: Any, index_name: str, script_query: Dict, size: int
+ ) -> Any:
+ version_num = client.info()["version"]["number"][0]
+ version_num = int(version_num)
+ if version_num >= 8:
+ response = client.search(index=index_name, query=script_query, size=size)
+ else:
+ response = client.search(
+ index=index_name, body={"query": script_query, "size": size}
+ )
+ return response
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ # TODO: Check if this can be done in bulk
+ for id in ids:
+ self.client.delete(index=self.index_name, id=id)
+
+
+@deprecated("0.0.265", alternative="ElasticsearchStore class.", pending=True)
+class ElasticKnnSearch(VectorStore):
+ """[DEPRECATED] `Elasticsearch` with k-nearest neighbor search
+ (`k-NN`) vector store.
+
+ Recommended to use ElasticsearchStore instead, which supports
+ metadata filtering, customising the query retriever and much more!
+
+ You can read more on ElasticsearchStore:
+ https://python.langchain.com/docs/integrations/vectorstores/elasticsearch
+
+ It creates an Elasticsearch index of text data that
+ can be searched using k-NN search. The text data is transformed into
+ vector embeddings using a provided embedding model, and these embeddings
+ are stored in the Elasticsearch index.
+
+ Attributes:
+ index_name (str): The name of the Elasticsearch index.
+ embedding (Embeddings): The embedding model to use for transforming text data
+ into vector embeddings.
+ es_connection (Elasticsearch, optional): An existing Elasticsearch connection.
+ es_cloud_id (str, optional): The Cloud ID of your Elasticsearch Service
+ deployment.
+ es_user (str, optional): The username for your Elasticsearch Service deployment.
+ es_password (str, optional): The password for your Elasticsearch Service
+ deployment.
+ vector_query_field (str, optional): The name of the field in the Elasticsearch
+ index that contains the vector embeddings.
+ query_field (str, optional): The name of the field in the Elasticsearch index
+ that contains the original text data.
+
+ Usage:
+ >>> from embeddings import Embeddings
+ >>> embedding = Embeddings.load('glove')
+ >>> es_search = ElasticKnnSearch('my_index', embedding)
+ >>> es_search.add_texts(['Hello world!', 'Another text'])
+ >>> results = es_search.knn_search('Hello')
+ [(Document(page_content='Hello world!', metadata={}), 0.9)]
+ """
+
+ def __init__(
+ self,
+ index_name: str,
+ embedding: Embeddings,
+ es_connection: Optional["Elasticsearch"] = None,
+ es_cloud_id: Optional[str] = None,
+ es_user: Optional[str] = None,
+ es_password: Optional[str] = None,
+ vector_query_field: Optional[str] = "vector",
+ query_field: Optional[str] = "text",
+ ):
+ try:
+ import elasticsearch
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ warnings.warn(
+ "ElasticKnnSearch will be removed in a future release."
+ "Use ElasticsearchStore instead. See Elasticsearch "
+ "integration docs on how to upgrade."
+ )
+ self.embedding = embedding
+ self.index_name = index_name
+ self.query_field = query_field
+ self.vector_query_field = vector_query_field
+
+ # If a pre-existing Elasticsearch connection is provided, use it.
+ if es_connection is not None:
+ self.client = es_connection
+ else:
+ # If credentials for a new Elasticsearch connection are provided,
+ # create a new connection.
+ if es_cloud_id and es_user and es_password:
+ self.client = elasticsearch.Elasticsearch(
+ cloud_id=es_cloud_id, basic_auth=(es_user, es_password)
+ )
+ else:
+ raise ValueError(
+ """Either provide a pre-existing Elasticsearch connection, \
+ or valid credentials for creating a new connection."""
+ )
+
+ @staticmethod
+ def _default_knn_mapping(
+ dims: int, similarity: Optional[str] = "dot_product"
+ ) -> Dict:
+ return {
+ "properties": {
+ "text": {"type": "text"},
+ "vector": {
+ "type": "dense_vector",
+ "dims": dims,
+ "index": True,
+ "similarity": similarity,
+ },
+ }
+ }
+
+ def _default_knn_query(
+ self,
+ query_vector: Optional[List[float]] = None,
+ query: Optional[str] = None,
+ model_id: Optional[str] = None,
+ k: Optional[int] = 10,
+ num_candidates: Optional[int] = 10,
+ ) -> Dict:
+ knn: Dict = {
+ "field": self.vector_query_field,
+ "k": k,
+ "num_candidates": num_candidates,
+ }
+
+ # Case 1: `query_vector` is provided, but not `model_id` -> use query_vector
+ if query_vector and not model_id:
+ knn["query_vector"] = query_vector
+
+ # Case 2: `query` and `model_id` are provided, -> use query_vector_builder
+ elif query and model_id:
+ knn["query_vector_builder"] = {
+ "text_embedding": {
+ "model_id": model_id, # use 'model_id' argument
+ "model_text": query, # use 'query' argument
+ }
+ }
+
+ else:
+ raise ValueError(
+ "Either `query_vector` or `model_id` must be provided, but not both."
+ )
+
+ return knn
+
+ def similarity_search(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Document]:
+ """
+ Pass through to `knn_search`
+ """
+ results = self.knn_search(query=query, k=k, **kwargs)
+ return [doc for doc, score in results]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 10, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Pass through to `knn_search including score`"""
+ return self.knn_search(query=query, k=k, **kwargs)
+
+ def knn_search(
+ self,
+ query: Optional[str] = None,
+ k: Optional[int] = 10,
+ query_vector: Optional[List[float]] = None,
+ model_id: Optional[str] = None,
+ size: Optional[int] = 10,
+ source: Optional[bool] = True,
+ fields: Optional[
+ Union[List[Mapping[str, Any]], Tuple[Mapping[str, Any], ...], None]
+ ] = None,
+ page_content: Optional[str] = "text",
+ ) -> List[Tuple[Document, float]]:
+ """
+ Perform a k-NN search on the Elasticsearch index.
+
+ Args:
+ query (str, optional): The query text to search for.
+ k (int, optional): The number of nearest neighbors to return.
+ query_vector (List[float], optional): The query vector to search for.
+ model_id (str, optional): The ID of the model to use for transforming the
+ query text into a vector.
+ size (int, optional): The number of search results to return.
+ source (bool, optional): Whether to return the source of the search results.
+ fields (List[Mapping[str, Any]], optional): The fields to return in the
+ search results.
+ page_content (str, optional): The name of the field that contains the page
+ content.
+
+ Returns:
+ A list of tuples, where each tuple contains a Document object and a score.
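+
+ Example (illustrative; reuses the ``embedding`` object that was passed to
+ the store to build the query vector):
+
+ >>> query_vector = embedding.embed_query("Hello")
+ >>> results = es_search.knn_search(query_vector=query_vector, k=2)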
+ """
+
+ if not source and (
+ fields is None or not any(page_content in field for field in fields)
+ ):
+ raise ValueError("If source=False `page_content` field must be in `fields`")
+
+ knn_query_body = self._default_knn_query(
+ query_vector=query_vector, query=query, model_id=model_id, k=k
+ )
+
+ # Perform the kNN search on the Elasticsearch index and return the results.
+ response = self.client.search(
+ index=self.index_name,
+ knn=knn_query_body,
+ size=size,
+ source=source,
+ fields=fields,
+ )
+
+ hits = [hit for hit in response["hits"]["hits"]]
+ docs_and_scores = [
+ (
+ Document(
+ page_content=hit["_source"][page_content]
+ if source
+ else hit["fields"][page_content][0],
+ metadata=hit["fields"] if fields else {},
+ ),
+ hit["_score"],
+ )
+ for hit in hits
+ ]
+
+ return docs_and_scores
+
+ def knn_hybrid_search(
+ self,
+ query: Optional[str] = None,
+ k: Optional[int] = 10,
+ query_vector: Optional[List[float]] = None,
+ model_id: Optional[str] = None,
+ size: Optional[int] = 10,
+ source: Optional[bool] = True,
+ knn_boost: Optional[float] = 0.9,
+ query_boost: Optional[float] = 0.1,
+ fields: Optional[
+ Union[List[Mapping[str, Any]], Tuple[Mapping[str, Any], ...], None]
+ ] = None,
+ page_content: Optional[str] = "text",
+ ) -> List[Tuple[Document, float]]:
+ """
+ Perform a hybrid k-NN and text search on the Elasticsearch index.
+
+ Args:
+ query (str, optional): The query text to search for.
+ k (int, optional): The number of nearest neighbors to return.
+ query_vector (List[float], optional): The query vector to search for.
+ model_id (str, optional): The ID of the model to use for transforming the
+ query text into a vector.
+ size (int, optional): The number of search results to return.
+ source (bool, optional): Whether to return the source of the search results.
+ knn_boost (float, optional): The boost value to apply to the k-NN search
+ results.
+ query_boost (float, optional): The boost value to apply to the text search
+ results.
+ fields (List[Mapping[str, Any]], optional): The fields to return in the
+ search results.
+ page_content (str, optional): The name of the field that contains the page
+ content.
+
+ Returns:
+ A list of tuples, where each tuple contains a Document object and a score.
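+
+ Example (illustrative; ``"my-model-id"`` is a hypothetical Elasticsearch
+ text-embedding model id deployed in the cluster):
+
+ >>> results = es_search.knn_hybrid_search(
+ ...     query="Hello",
+ ...     model_id="my-model-id",
+ ...     knn_boost=0.9,
+ ...     query_boost=0.1,
+ ... )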
+ """
+
+ if not source and (
+ fields is None or not any(page_content in field for field in fields)
+ ):
+ raise ValueError("If source=False `page_content` field must be in `fields`")
+
+ knn_query_body = self._default_knn_query(
+ query_vector=query_vector, query=query, model_id=model_id, k=k
+ )
+
+ # Modify the knn_query_body to add a "boost" parameter
+ knn_query_body["boost"] = knn_boost
+
+ # Generate the body of the standard Elasticsearch query
+ match_query_body = {
+ "match": {self.query_field: {"query": query, "boost": query_boost}}
+ }
+
+ # Perform the hybrid search on the Elasticsearch index and return the results.
+ response = self.client.search(
+ index=self.index_name,
+ query=match_query_body,
+ knn=knn_query_body,
+ fields=fields,
+ size=size,
+ source=source,
+ )
+
+ hits = [hit for hit in response["hits"]["hits"]]
+ docs_and_scores = [
+ (
+ Document(
+ page_content=hit["_source"][page_content]
+ if source
+ else hit["fields"][page_content][0],
+ metadata=hit["fields"] if fields else {},
+ ),
+ hit["_score"],
+ )
+ for hit in hits
+ ]
+
+ return docs_and_scores
+
+ def create_knn_index(self, mapping: Dict) -> None:
+ """
+ Create a new k-NN index in Elasticsearch.
+
+ Args:
+ mapping (Dict): The mapping to use for the new index.
+
+ Returns:
+ None
+ """
+
+ self.client.indices.create(index=self.index_name, mappings=mapping)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ model_id: Optional[str] = None,
+ refresh_indices: bool = False,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Add a list of texts to the Elasticsearch index.
+
+ Args:
+ texts (Iterable[str]): The texts to add to the index.
+ metadatas (List[Dict[Any, Any]], optional): A list of metadata dictionaries
+ to associate with the texts.
+ model_id (str, optional): The ID of the model to use for transforming the
+ texts into vectors.
+ refresh_indices (bool, optional): Whether to refresh the Elasticsearch
+ indices after adding the texts.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ A list of IDs for the added texts.
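+
+        Example:
+            An illustrative sketch; ``dims`` must match the dimensionality of
+            your embedding model (e.g. 1536 for OpenAI's text-embedding-ada-002).
+
+            .. code-block:: python
+
+                ids = knn_search.add_texts(
+                    ["hello world"], dims=1536, refresh_indices=True
+                )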
+ """
+
+ # Check if the index exists.
+ if not self.client.indices.exists(index=self.index_name):
+ dims = kwargs.get("dims")
+
+ if dims is None:
+ raise ValueError("ElasticKnnSearch requires 'dims' parameter")
+
+ similarity = kwargs.get("similarity")
+ optional_args = {}
+
+ if similarity is not None:
+ optional_args["similarity"] = similarity
+
+ mapping = self._default_knn_mapping(dims=dims, **optional_args)
+ self.create_knn_index(mapping)
+
+ embeddings = self.embedding.embed_documents(list(texts))
+
+        body: List[Mapping[str, Any]] = []
+ for text, vector in zip(texts, embeddings):
+ body.extend(
+ [
+ {"index": {"_index": self.index_name}},
+ {"text": text, "vector": vector},
+ ]
+ )
+
+ responses = self.client.bulk(operations=body)
+
+ ids = [
+ item["index"]["_id"]
+ for item in responses["items"]
+ if item["index"]["result"] == "created"
+ ]
+
+ if refresh_indices:
+ self.client.indices.refresh(index=self.index_name)
+
+ return ids
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ **kwargs: Any,
+ ) -> ElasticKnnSearch:
+ """
+ Create a new ElasticKnnSearch instance and add a list of texts to the
+ Elasticsearch index.
+
+ Args:
+ texts (List[str]): The texts to add to the index.
+ embedding (Embeddings): The embedding model to use for transforming the
+ texts into vectors.
+ metadatas (List[Dict[Any, Any]], optional): A list of metadata dictionaries
+ to associate with the texts.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ A new ElasticKnnSearch instance.
+ """
+
+ index_name = kwargs.get("index_name", str(uuid.uuid4()))
+ es_connection = kwargs.get("es_connection")
+ es_cloud_id = kwargs.get("es_cloud_id")
+ es_user = kwargs.get("es_user")
+ es_password = kwargs.get("es_password")
+ vector_query_field = kwargs.get("vector_query_field", "vector")
+ query_field = kwargs.get("query_field", "text")
+ model_id = kwargs.get("model_id")
+ dims = kwargs.get("dims")
+
+ if dims is None:
+ raise ValueError("ElasticKnnSearch requires 'dims' parameter")
+
+ optional_args = {}
+
+ if vector_query_field is not None:
+ optional_args["vector_query_field"] = vector_query_field
+
+ if query_field is not None:
+ optional_args["query_field"] = query_field
+
+ knnvectorsearch = cls(
+ index_name=index_name,
+ embedding=embedding,
+ es_connection=es_connection,
+ es_cloud_id=es_cloud_id,
+ es_user=es_user,
+ es_password=es_password,
+ **optional_args,
+ )
+ # Encode the provided texts and add them to the newly created index.
+ knnvectorsearch.add_texts(texts, model_id=model_id, dims=dims, **optional_args)
+
+ return knnvectorsearch
diff --git a/libs/community/langchain_community/vectorstores/elasticsearch.py b/libs/community/langchain_community/vectorstores/elasticsearch.py
new file mode 100644
index 00000000000..8eb3eab0ab0
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/elasticsearch.py
@@ -0,0 +1,1275 @@
+import logging
+import uuid
+from abc import ABC, abstractmethod
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import (
+ DistanceStrategy,
+ maximal_marginal_relevance,
+)
+
+if TYPE_CHECKING:
+ from elasticsearch import Elasticsearch
+
+logger = logging.getLogger(__name__)
+
+
+class BaseRetrievalStrategy(ABC):
+ """Base class for `Elasticsearch` retrieval strategies."""
+
+ @abstractmethod
+ def query(
+ self,
+ query_vector: Union[List[float], None],
+ query: Union[str, None],
+ *,
+ k: int,
+ fetch_k: int,
+ vector_query_field: str,
+ text_field: str,
+ filter: List[dict],
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ """
+ Executes when a search is performed on the store.
+
+ Args:
+ query_vector: The query vector,
+ or None if not using vector-based query.
+ query: The text query, or None if not using text-based query.
+ k: The total number of results to retrieve.
+ fetch_k: The number of results to fetch initially.
+ vector_query_field: The field containing the vector
+ representations in the index.
+ text_field: The field containing the text data in the index.
+ filter: List of filter clauses to apply to the query.
+ similarity: The similarity strategy to use, or None if not using one.
+
+ Returns:
+ Dict: The Elasticsearch query body.
+ """
+
+ @abstractmethod
+ def index(
+ self,
+ dims_length: Union[int, None],
+ vector_query_field: str,
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ """
+ Executes when the index is created.
+
+ Args:
+ dims_length: Numeric length of the embedding vectors,
+ or None if not using vector-based query.
+ vector_query_field: The field containing the vector
+ representations in the index.
+ similarity: The similarity strategy to use,
+ or None if not using one.
+
+ Returns:
+ Dict: The Elasticsearch settings and mappings for the strategy.
+ """
+
+ def before_index_setup(
+ self, client: "Elasticsearch", text_field: str, vector_query_field: str
+ ) -> None:
+ """
+ Executes before the index is created. Used for setting up
+ any required Elasticsearch resources like a pipeline.
+
+ Args:
+ client: The Elasticsearch client.
+ text_field: The field containing the text data in the index.
+ vector_query_field: The field containing the vector
+ representations in the index.
+ """
+
+ def require_inference(self) -> bool:
+ """
+ Returns whether or not the strategy requires inference
+ to be performed on the text before it is added to the index.
+
+ Returns:
+ bool: Whether or not the strategy requires inference
+ to be performed on the text before it is added to the index.
+ """
+ return True
+
+
+class ApproxRetrievalStrategy(BaseRetrievalStrategy):
+ """Approximate retrieval strategy using the `HNSW` algorithm."""
+
+ def __init__(
+ self,
+ query_model_id: Optional[str] = None,
+ hybrid: Optional[bool] = False,
+ rrf: Optional[Union[dict, bool]] = True,
+ ):
+ self.query_model_id = query_model_id
+ self.hybrid = hybrid
+
+ # RRF has two optional parameters
+ # 'rank_constant', 'window_size'
+ # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
+ self.rrf = rrf
+
+ def query(
+ self,
+ query_vector: Union[List[float], None],
+ query: Union[str, None],
+ k: int,
+ fetch_k: int,
+ vector_query_field: str,
+ text_field: str,
+ filter: List[dict],
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ knn = {
+ "filter": filter,
+ "field": vector_query_field,
+ "k": k,
+ "num_candidates": fetch_k,
+ }
+
+        # Case 1: the query vector was computed client-side via the
+        # provided embedding function
+        if query_vector and not self.query_model_id:
+            knn["query_vector"] = query_vector
+
+        # Case 2: a model deployed in Elasticsearch infers the query
+        # vector from the query text
+        elif query and self.query_model_id:
+ knn["query_vector_builder"] = {
+ "text_embedding": {
+ "model_id": self.query_model_id, # use 'model_id' argument
+ "model_text": query, # use 'query' argument
+ }
+ }
+
+ else:
+ raise ValueError(
+ "You must provide an embedding function or a"
+ " query_model_id to perform a similarity search."
+ )
+
+ # If hybrid, add a query to the knn query
+ # RRF is used to even the score from the knn query and text query
+ # RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
+ # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
+ if self.hybrid:
+ query_body = {
+ "knn": knn,
+ "query": {
+ "bool": {
+ "must": [
+ {
+ "match": {
+ text_field: {
+ "query": query,
+ }
+ }
+ }
+ ],
+ "filter": filter,
+ }
+ },
+ }
+
+ if isinstance(self.rrf, dict):
+ query_body["rank"] = {"rrf": self.rrf}
+ elif isinstance(self.rrf, bool) and self.rrf is True:
+ query_body["rank"] = {"rrf": {}}
+
+ return query_body
+ else:
+ return {"knn": knn}
+
+ def index(
+ self,
+ dims_length: Union[int, None],
+ vector_query_field: str,
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ """Create the mapping for the Elasticsearch index."""
+
+ if similarity is DistanceStrategy.COSINE:
+ similarityAlgo = "cosine"
+ elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE:
+ similarityAlgo = "l2_norm"
+ elif similarity is DistanceStrategy.DOT_PRODUCT:
+ similarityAlgo = "dot_product"
+ else:
+ raise ValueError(f"Similarity {similarity} not supported.")
+
+ return {
+ "mappings": {
+ "properties": {
+ vector_query_field: {
+ "type": "dense_vector",
+ "dims": dims_length,
+ "index": True,
+ "similarity": similarityAlgo,
+ },
+ }
+ }
+ }
+
+
+class ExactRetrievalStrategy(BaseRetrievalStrategy):
+ """Exact retrieval strategy using the `script_score` query."""
+
+ def query(
+ self,
+ query_vector: Union[List[float], None],
+ query: Union[str, None],
+ k: int,
+ fetch_k: int,
+ vector_query_field: str,
+ text_field: str,
+ filter: Union[List[dict], None],
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ if similarity is DistanceStrategy.COSINE:
+ similarityAlgo = (
+ f"cosineSimilarity(params.query_vector, '{vector_query_field}') + 1.0"
+ )
+ elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE:
+ similarityAlgo = (
+ f"1 / (1 + l2norm(params.query_vector, '{vector_query_field}'))"
+ )
+ elif similarity is DistanceStrategy.DOT_PRODUCT:
+ similarityAlgo = f"""
+ double value = dotProduct(params.query_vector, '{vector_query_field}');
+ return sigmoid(1, Math.E, -value);
+ """
+ else:
+ raise ValueError(f"Similarity {similarity} not supported.")
+
+ queryBool: Dict = {"match_all": {}}
+ if filter:
+ queryBool = {"bool": {"filter": filter}}
+
+ return {
+ "query": {
+ "script_score": {
+ "query": queryBool,
+ "script": {
+ "source": similarityAlgo,
+ "params": {"query_vector": query_vector},
+ },
+ },
+ }
+ }
+
+ def index(
+ self,
+ dims_length: Union[int, None],
+ vector_query_field: str,
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ """Create the mapping for the Elasticsearch index."""
+
+ return {
+ "mappings": {
+ "properties": {
+ vector_query_field: {
+ "type": "dense_vector",
+ "dims": dims_length,
+ "index": False,
+ },
+ }
+ }
+ }
+
+
+class SparseRetrievalStrategy(BaseRetrievalStrategy):
+ """Sparse retrieval strategy using the `text_expansion` processor."""
+
+ def __init__(self, model_id: Optional[str] = None):
+ self.model_id = model_id or ".elser_model_1"
+
+ def query(
+ self,
+ query_vector: Union[List[float], None],
+ query: Union[str, None],
+ k: int,
+ fetch_k: int,
+ vector_query_field: str,
+ text_field: str,
+ filter: List[dict],
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ return {
+ "query": {
+ "bool": {
+ "must": [
+ {
+ "text_expansion": {
+ f"{vector_query_field}.tokens": {
+ "model_id": self.model_id,
+ "model_text": query,
+ }
+ }
+ }
+ ],
+ "filter": filter,
+ }
+ }
+ }
+
+ def _get_pipeline_name(self) -> str:
+ return f"{self.model_id}_sparse_embedding"
+
+ def before_index_setup(
+ self, client: "Elasticsearch", text_field: str, vector_query_field: str
+ ) -> None:
+ # If model_id is provided, create a pipeline for the model
+ if self.model_id:
+ client.ingest.put_pipeline(
+ id=self._get_pipeline_name(),
+ description="Embedding pipeline for langchain vectorstore",
+ processors=[
+ {
+ "inference": {
+ "model_id": self.model_id,
+ "target_field": vector_query_field,
+ "field_map": {text_field: "text_field"},
+ "inference_config": {
+ "text_expansion": {"results_field": "tokens"}
+ },
+ }
+ }
+ ],
+ )
+
+ def index(
+ self,
+ dims_length: Union[int, None],
+ vector_query_field: str,
+ similarity: Union[DistanceStrategy, None],
+ ) -> Dict:
+ return {
+ "mappings": {
+ "properties": {
+ vector_query_field: {
+ "properties": {"tokens": {"type": "rank_features"}}
+ }
+ }
+ },
+ "settings": {"default_pipeline": self._get_pipeline_name()},
+ }
+
+ def require_inference(self) -> bool:
+ return False
+
+
+class ElasticsearchStore(VectorStore):
+ """`Elasticsearch` vector store.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ vectorstore = ElasticsearchStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+ es_url="http://localhost:9200"
+ )
+
+ Args:
+ index_name: Name of the Elasticsearch index to create.
+ es_url: URL of the Elasticsearch instance to connect to.
+ cloud_id: Cloud ID of the Elasticsearch instance to connect to.
+ es_user: Username to use when connecting to Elasticsearch.
+ es_password: Password to use when connecting to Elasticsearch.
+ es_api_key: API key to use when connecting to Elasticsearch.
+ es_connection: Optional pre-existing Elasticsearch connection.
+ vector_query_field: Optional. Name of the field to store
+ the embedding vectors in.
+ query_field: Optional. Name of the field to store the texts in.
+ strategy: Optional. Retrieval strategy to use when searching the index.
+ Defaults to ApproxRetrievalStrategy. Can be one of
+ ExactRetrievalStrategy, ApproxRetrievalStrategy,
+ or SparseRetrievalStrategy.
+ distance_strategy: Optional. Distance strategy to use when
+ searching the index.
+ Defaults to COSINE. Can be one of COSINE,
+ EUCLIDEAN_DISTANCE, or DOT_PRODUCT.
+
+ If you want to use a cloud hosted Elasticsearch instance, you can pass in the
+ cloud_id argument instead of the es_url argument.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ vectorstore = ElasticsearchStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+                es_cloud_id="",
+ es_user="elastic",
+ es_password=""
+ )
+
+ You can also connect to an existing Elasticsearch instance by passing in a
+ pre-existing Elasticsearch connection via the es_connection argument.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ from elasticsearch import Elasticsearch
+
+ es_connection = Elasticsearch("http://localhost:9200")
+
+ vectorstore = ElasticsearchStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+ es_connection=es_connection
+ )
+
+ ElasticsearchStore by default uses the ApproxRetrievalStrategy, which uses the
+ HNSW algorithm to perform approximate nearest neighbor search. This is the
+ fastest and most memory efficient algorithm.
+
+ If you want to use the Brute force / Exact strategy for searching vectors, you
+ can pass in the ExactRetrievalStrategy to the ElasticsearchStore constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ vectorstore = ElasticsearchStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+ es_url="http://localhost:9200",
+ strategy=ElasticsearchStore.ExactRetrievalStrategy()
+ )
+
+ Both strategies require that you know the similarity metric you want to use
+ when creating the index. The default is cosine similarity, but you can also
+ use dot product or euclidean distance.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores.utils import DistanceStrategy
+
+ vectorstore = ElasticsearchStore(
+ embedding=OpenAIEmbeddings(),
+ index_name="langchain-demo",
+ es_url="http://localhost:9200",
+ distance_strategy="DOT_PRODUCT"
+ )
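+
+    You can also use the ELSER model for sparse vector retrieval by passing in
+    the SparseVectorRetrievalStrategy. This is a minimal sketch; it assumes the
+    ELSER model has already been deployed to your Elasticsearch cluster, and no
+    embedding function is needed in that case.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.vectorstores import ElasticsearchStore
+
+            vectorstore = ElasticsearchStore(
+                index_name="langchain-demo",
+                es_url="http://localhost:9200",
+                strategy=ElasticsearchStore.SparseVectorRetrievalStrategy()
+            )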
+
+ """
+
+ def __init__(
+ self,
+ index_name: str,
+ *,
+ embedding: Optional[Embeddings] = None,
+ es_connection: Optional["Elasticsearch"] = None,
+ es_url: Optional[str] = None,
+ es_cloud_id: Optional[str] = None,
+ es_user: Optional[str] = None,
+ es_api_key: Optional[str] = None,
+ es_password: Optional[str] = None,
+ vector_query_field: str = "vector",
+ query_field: str = "text",
+ distance_strategy: Optional[
+ Literal[
+ DistanceStrategy.COSINE,
+ DistanceStrategy.DOT_PRODUCT,
+ DistanceStrategy.EUCLIDEAN_DISTANCE,
+ ]
+ ] = None,
+ strategy: BaseRetrievalStrategy = ApproxRetrievalStrategy(),
+ ):
+ self.embedding = embedding
+ self.index_name = index_name
+ self.query_field = query_field
+ self.vector_query_field = vector_query_field
+ self.distance_strategy = (
+ DistanceStrategy.COSINE
+ if distance_strategy is None
+ else DistanceStrategy[distance_strategy]
+ )
+ self.strategy = strategy
+
+ if es_connection is not None:
+ self.client = es_connection.options(
+ headers={"user-agent": self.get_user_agent()}
+ )
+ elif es_url is not None or es_cloud_id is not None:
+ self.client = ElasticsearchStore.connect_to_elasticsearch(
+ es_url=es_url,
+ username=es_user,
+ password=es_password,
+ cloud_id=es_cloud_id,
+ api_key=es_api_key,
+ )
+ else:
+            raise ValueError(
+                "Either provide a pre-existing Elasticsearch connection, "
+                "or valid credentials for creating a new connection."
+            )
+
+ @staticmethod
+ def get_user_agent() -> str:
+ from langchain_community import __version__
+
+ return f"langchain-py-vs/{__version__}"
+
+ @staticmethod
+ def connect_to_elasticsearch(
+ *,
+ es_url: Optional[str] = None,
+ cloud_id: Optional[str] = None,
+ api_key: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ ) -> "Elasticsearch":
+ try:
+ import elasticsearch
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ if es_url and cloud_id:
+ raise ValueError(
+ "Both es_url and cloud_id are defined. Please provide only one."
+ )
+
+ connection_params: Dict[str, Any] = {}
+
+ if es_url:
+ connection_params["hosts"] = [es_url]
+ elif cloud_id:
+ connection_params["cloud_id"] = cloud_id
+ else:
+ raise ValueError("Please provide either elasticsearch_url or cloud_id.")
+
+ if api_key:
+ connection_params["api_key"] = api_key
+ elif username and password:
+ connection_params["basic_auth"] = (username, password)
+
+ es_client = elasticsearch.Elasticsearch(
+ **connection_params,
+ headers={"user-agent": ElasticsearchStore.get_user_agent()},
+ )
+ try:
+ es_client.info()
+ except Exception as e:
+ logger.error(f"Error connecting to Elasticsearch: {e}")
+ raise e
+
+ return es_client
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self.embedding
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 50,
+ filter: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return Elasticsearch documents most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to knn num_candidates.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+
+ Returns:
+ List of Documents most similar to the query,
+ in descending order of similarity.
+ """
+
+ results = self._search(
+ query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
+ )
+ return [doc for doc, _ in results]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ fields: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query (str): Text to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+            fields: Other fields to get from the Elasticsearch source. These fields
+ will be added to the document metadata.
+
+ Returns:
+ List[Document]: A list of Documents selected by maximal marginal relevance.
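+
+        Example:
+            An illustrative sketch; assumes an existing ``ElasticsearchStore``
+            named ``vectorstore`` configured with an embedding function.
+
+            .. code-block:: python
+
+                docs = vectorstore.max_marginal_relevance_search(
+                    "what is a vector store?", k=4, fetch_k=20, lambda_mult=0.5
+                )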
+ """
+ if self.embedding is None:
+ raise ValueError("You must provide an embedding function to perform MMR")
+ remove_vector_query_field_from_metadata = True
+ if fields is None:
+ fields = [self.vector_query_field]
+ elif self.vector_query_field not in fields:
+ fields.append(self.vector_query_field)
+ else:
+ remove_vector_query_field_from_metadata = False
+
+ # Embed the query
+ query_embedding = self.embedding.embed_query(query)
+
+ # Fetch the initial documents
+ got_docs = self._search(
+ query_vector=query_embedding, k=fetch_k, fields=fields, **kwargs
+ )
+
+ # Get the embeddings for the fetched documents
+ got_embeddings = [doc.metadata[self.vector_query_field] for doc, _ in got_docs]
+
+ # Select documents using maximal marginal relevance
+ selected_indices = maximal_marginal_relevance(
+ np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
+ )
+ selected_docs = [got_docs[i][0] for i in selected_indices]
+
+ if remove_vector_query_field_from_metadata:
+ for doc in selected_docs:
+ del doc.metadata[self.vector_query_field]
+
+ return selected_docs
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, filter: Optional[List[dict]] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return Elasticsearch documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ return self._search(query=query, k=k, filter=filter, **kwargs)
+
+ def similarity_search_by_vector_with_relevance_scores(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[List[Dict]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return Elasticsearch documents most similar to query, along with scores.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+
+ Returns:
+ List of Documents most similar to the embedding and score for each
+ """
+ return self._search(query_vector=embedding, k=k, filter=filter, **kwargs)
+
+ def _search(
+ self,
+ query: Optional[str] = None,
+ k: int = 4,
+ query_vector: Union[List[float], None] = None,
+ fetch_k: int = 50,
+ fields: Optional[List[str]] = None,
+ filter: Optional[List[dict]] = None,
+ custom_query: Optional[Callable[[Dict, Union[str, None]], Dict]] = None,
+ doc_builder: Optional[Callable[[Dict], Document]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return Elasticsearch documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ query_vector: Embedding to look up documents similar to.
+ fetch_k: Number of candidates to fetch from each shard.
+ Defaults to 50.
+ fields: List of fields to return from Elasticsearch.
+ Defaults to only returning the text field.
+ filter: Array of Elasticsearch filter clauses to apply to the query.
+ custom_query: Function to modify the Elasticsearch
+ query body before it is sent to Elasticsearch.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ if fields is None:
+ fields = []
+
+ if "metadata" not in fields:
+ fields.append("metadata")
+
+ if self.query_field not in fields:
+ fields.append(self.query_field)
+
+ if self.embedding and query is not None:
+ query_vector = self.embedding.embed_query(query)
+
+ query_body = self.strategy.query(
+ query_vector=query_vector,
+ query=query,
+ k=k,
+ fetch_k=fetch_k,
+ vector_query_field=self.vector_query_field,
+ text_field=self.query_field,
+ filter=filter or [],
+ similarity=self.distance_strategy,
+ )
+
+ logger.debug(f"Query body: {query_body}")
+
+ if custom_query is not None:
+ query_body = custom_query(query_body, query)
+ logger.debug(f"Calling custom_query, Query body now: {query_body}")
+ # Perform the kNN search on the Elasticsearch index and return the results.
+ response = self.client.search(
+ index=self.index_name,
+ **query_body,
+ size=k,
+ source=fields,
+ )
+
+ def default_doc_builder(hit: Dict) -> Document:
+ return Document(
+ page_content=hit["_source"].get(self.query_field, ""),
+ metadata=hit["_source"]["metadata"],
+ )
+
+ doc_builder = doc_builder or default_doc_builder
+
+ docs_and_scores = []
+ for hit in response["hits"]["hits"]:
+ for field in fields:
+ if field in hit["_source"] and field not in [
+ "metadata",
+ self.query_field,
+ ]:
+ if "metadata" not in hit["_source"]:
+ hit["_source"]["metadata"] = {}
+ hit["_source"]["metadata"][field] = hit["_source"][field]
+
+ docs_and_scores.append(
+ (
+ doc_builder(hit),
+ hit["_score"],
+ )
+ )
+ return docs_and_scores
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ refresh_indices: Optional[bool] = True,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete documents from the Elasticsearch index.
+
+ Args:
+ ids: List of ids of documents to delete.
+ refresh_indices: Whether to refresh the index
+ after deleting documents. Defaults to True.
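+
+        Example:
+            An illustrative sketch that reuses the ids returned by ``add_texts``:
+
+            .. code-block:: python
+
+                ids = vectorstore.add_texts(["foo", "bar"])
+                vectorstore.delete(ids=ids)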
+ """
+ try:
+ from elasticsearch.helpers import BulkIndexError, bulk
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+
+ body = []
+
+ if ids is None:
+ raise ValueError("ids must be provided.")
+
+ for _id in ids:
+ body.append({"_op_type": "delete", "_index": self.index_name, "_id": _id})
+
+ if len(body) > 0:
+ try:
+ bulk(self.client, body, refresh=refresh_indices, ignore_status=404)
+ logger.debug(f"Deleted {len(body)} texts from index")
+
+ return True
+ except BulkIndexError as e:
+ logger.error(f"Error deleting texts: {e}")
+ firstError = e.errors[0].get("index", {}).get("error", {})
+ logger.error(f"First error reason: {firstError.get('reason')}")
+ raise e
+
+ else:
+ logger.debug("No texts to delete from index")
+ return False
+
+ def _create_index_if_not_exists(
+ self, index_name: str, dims_length: Optional[int] = None
+ ) -> None:
+ """Create the Elasticsearch index if it doesn't already exist.
+
+ Args:
+ index_name: Name of the Elasticsearch index to create.
+ dims_length: Length of the embedding vectors.
+ """
+
+ if self.client.indices.exists(index=index_name):
+ logger.debug(f"Index {index_name} already exists. Skipping creation.")
+
+ else:
+ if dims_length is None and self.strategy.require_inference():
+ raise ValueError(
+ "Cannot create index without specifying dims_length "
+ "when the index doesn't already exist. We infer "
+ "dims_length from the first embedding. Check that "
+ "you have provided an embedding function."
+ )
+
+ self.strategy.before_index_setup(
+ client=self.client,
+ text_field=self.query_field,
+ vector_query_field=self.vector_query_field,
+ )
+
+ indexSettings = self.strategy.index(
+ vector_query_field=self.vector_query_field,
+ dims_length=dims_length,
+ similarity=self.distance_strategy,
+ )
+ logger.debug(
+ f"Creating index {index_name} with mappings {indexSettings['mappings']}"
+ )
+ self.client.indices.create(index=index_name, **indexSettings)
+
+ def __add(
+ self,
+ texts: Iterable[str],
+ embeddings: Optional[List[List[float]]],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ ids: Optional[List[str]] = None,
+ refresh_indices: bool = True,
+ create_index_if_not_exists: bool = True,
+ bulk_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ try:
+ from elasticsearch.helpers import BulkIndexError, bulk
+ except ImportError:
+ raise ImportError(
+ "Could not import elasticsearch python package. "
+ "Please install it with `pip install elasticsearch`."
+ )
+ bulk_kwargs = bulk_kwargs or {}
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ requests = []
+
+ if create_index_if_not_exists:
+ if embeddings:
+ dims_length = len(embeddings[0])
+ else:
+ dims_length = None
+
+ self._create_index_if_not_exists(
+ index_name=self.index_name, dims_length=dims_length
+ )
+
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+
+ request = {
+ "_op_type": "index",
+ "_index": self.index_name,
+ self.query_field: text,
+ "metadata": metadata,
+ "_id": ids[i],
+ }
+ if embeddings:
+ request[self.vector_query_field] = embeddings[i]
+
+ requests.append(request)
+
+ if len(requests) > 0:
+ try:
+ success, failed = bulk(
+ self.client,
+ requests,
+ stats_only=True,
+ refresh=refresh_indices,
+ **bulk_kwargs,
+ )
+ logger.debug(
+ f"Added {success} and failed to add {failed} texts to index"
+ )
+
+ logger.debug(f"added texts {ids} to index")
+ return ids
+ except BulkIndexError as e:
+ logger.error(f"Error adding texts: {e}")
+ firstError = e.errors[0].get("index", {}).get("error", {})
+ logger.error(f"First error reason: {firstError.get('reason')}")
+ raise e
+
+ else:
+ logger.debug("No texts to add to index")
+ return []
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ ids: Optional[List[str]] = None,
+ refresh_indices: bool = True,
+ create_index_if_not_exists: bool = True,
+ bulk_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+ refresh_indices: Whether to refresh the Elasticsearch indices
+ after adding the texts.
+ create_index_if_not_exists: Whether to create the Elasticsearch
+ index if it doesn't already exist.
+            bulk_kwargs: Additional arguments to pass to Elasticsearch bulk.
+ - chunk_size: Optional. Number of texts to add to the
+ index at a time. Defaults to 500.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+        if self.embedding is not None:
+            # An embedding function was provided, so embed the texts
+            # client-side before indexing.
+            embeddings = self.embedding.embed_documents(list(texts))
+        else:
+            # No embedding function: the chosen retrieval strategy performs
+            # inference inside Elasticsearch (e.g. via an ingest pipeline),
+            # so no client-side embeddings are needed.
+            embeddings = None
+
+ return self.__add(
+ texts,
+ embeddings,
+ metadatas=metadatas,
+ ids=ids,
+ refresh_indices=refresh_indices,
+ create_index_if_not_exists=create_index_if_not_exists,
+ bulk_kwargs=bulk_kwargs,
+ kwargs=kwargs,
+ )
+
+ def add_embeddings(
+ self,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ refresh_indices: bool = True,
+ create_index_if_not_exists: bool = True,
+ bulk_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add the given texts and embeddings to the vectorstore.
+
+ Args:
+ text_embeddings: Iterable pairs of string and embedding to
+ add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+ refresh_indices: Whether to refresh the Elasticsearch indices
+ after adding the texts.
+ create_index_if_not_exists: Whether to create the Elasticsearch
+ index if it doesn't already exist.
+            bulk_kwargs: Additional arguments to pass to Elasticsearch bulk.
+ - chunk_size: Optional. Number of texts to add to the
+ index at a time. Defaults to 500.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
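+
+        Example:
+            A minimal sketch; the embedding values below are made up and must
+            match the dimensionality of the index in practice.
+
+            .. code-block:: python
+
+                ids = vectorstore.add_embeddings(
+                    [("hello world", [0.1, 0.2, 0.3])],
+                    metadatas=[{"source": "greeting"}],
+                )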
+ """
+ texts, embeddings = zip(*text_embeddings)
+ return self.__add(
+ list(texts),
+ list(embeddings),
+ metadatas=metadatas,
+ ids=ids,
+ refresh_indices=refresh_indices,
+ create_index_if_not_exists=create_index_if_not_exists,
+ bulk_kwargs=bulk_kwargs,
+ kwargs=kwargs,
+ )
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ bulk_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> "ElasticsearchStore":
+        """Construct ElasticsearchStore wrapper from raw texts.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ db = ElasticsearchStore.from_texts(
+ texts,
+                    # embeddings optional if using
+                    # a strategy that doesn't require inference
+ embeddings,
+ index_name="langchain-demo",
+ es_url="http://localhost:9200"
+ )
+
+ Args:
+ texts: List of texts to add to the Elasticsearch index.
+ embedding: Embedding function to use to embed the texts.
+ metadatas: Optional list of metadatas associated with the texts.
+ index_name: Name of the Elasticsearch index to create.
+ es_url: URL of the Elasticsearch instance to connect to.
+ cloud_id: Cloud ID of the Elasticsearch instance to connect to.
+ es_user: Username to use when connecting to Elasticsearch.
+ es_password: Password to use when connecting to Elasticsearch.
+ es_api_key: API key to use when connecting to Elasticsearch.
+ es_connection: Optional pre-existing Elasticsearch connection.
+ vector_query_field: Optional. Name of the field to
+ store the embedding vectors in.
+ query_field: Optional. Name of the field to store the texts in.
+ distance_strategy: Optional. Name of the distance
+ strategy to use. Defaults to "COSINE".
+ can be one of "COSINE",
+ "EUCLIDEAN_DISTANCE", "DOT_PRODUCT".
+ bulk_kwargs: Optional. Additional arguments to pass to
+ Elasticsearch bulk.
+ """
+
+ elasticsearchStore = ElasticsearchStore._create_cls_from_kwargs(
+ embedding=embedding, **kwargs
+ )
+
+ # Encode the provided texts and add them to the newly created index.
+ elasticsearchStore.add_texts(
+ texts, metadatas=metadatas, bulk_kwargs=bulk_kwargs
+ )
+
+ return elasticsearchStore
+
+ @staticmethod
+ def _create_cls_from_kwargs(
+ embedding: Optional[Embeddings] = None, **kwargs: Any
+ ) -> "ElasticsearchStore":
+ index_name = kwargs.get("index_name")
+
+ if index_name is None:
+ raise ValueError("Please provide an index_name.")
+
+ es_connection = kwargs.get("es_connection")
+ es_cloud_id = kwargs.get("es_cloud_id")
+ es_url = kwargs.get("es_url")
+ es_user = kwargs.get("es_user")
+ es_password = kwargs.get("es_password")
+ es_api_key = kwargs.get("es_api_key")
+ vector_query_field = kwargs.get("vector_query_field")
+ query_field = kwargs.get("query_field")
+ distance_strategy = kwargs.get("distance_strategy")
+ strategy = kwargs.get("strategy", ElasticsearchStore.ApproxRetrievalStrategy())
+
+ optional_args = {}
+
+ if vector_query_field is not None:
+ optional_args["vector_query_field"] = vector_query_field
+
+ if query_field is not None:
+ optional_args["query_field"] = query_field
+
+ return ElasticsearchStore(
+ index_name=index_name,
+ embedding=embedding,
+ es_url=es_url,
+ es_connection=es_connection,
+ es_cloud_id=es_cloud_id,
+ es_user=es_user,
+ es_password=es_password,
+ es_api_key=es_api_key,
+ strategy=strategy,
+ distance_strategy=distance_strategy,
+ **optional_args,
+ )
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Optional[Embeddings] = None,
+ bulk_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> "ElasticsearchStore":
+ """Construct ElasticsearchStore wrapper from documents.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ db = ElasticsearchStore.from_documents(
+                documents,
+ embeddings,
+ index_name="langchain-demo",
+ es_url="http://localhost:9200"
+ )
+
+ Args:
+            documents: List of Documents to add to the Elasticsearch index.
+            embedding: Embedding function to use to embed the texts.
+                Do not provide if using a strategy
+                that doesn't require inference.
+ index_name: Name of the Elasticsearch index to create.
+ es_url: URL of the Elasticsearch instance to connect to.
+ cloud_id: Cloud ID of the Elasticsearch instance to connect to.
+ es_user: Username to use when connecting to Elasticsearch.
+ es_password: Password to use when connecting to Elasticsearch.
+ es_api_key: API key to use when connecting to Elasticsearch.
+ es_connection: Optional pre-existing Elasticsearch connection.
+ vector_query_field: Optional. Name of the field
+ to store the embedding vectors in.
+ query_field: Optional. Name of the field to store the texts in.
+ bulk_kwargs: Optional. Additional arguments to pass to
+ Elasticsearch bulk.
+ """
+
+ elasticsearchStore = ElasticsearchStore._create_cls_from_kwargs(
+ embedding=embedding, **kwargs
+ )
+ # Encode the provided texts and add them to the newly created index.
+ elasticsearchStore.add_documents(documents, bulk_kwargs=bulk_kwargs)
+
+ return elasticsearchStore
+
+ @staticmethod
+ def ExactRetrievalStrategy() -> "ExactRetrievalStrategy":
+ """Used to perform brute force / exact
+ nearest neighbor search via script_score."""
+ return ExactRetrievalStrategy()
+
+ @staticmethod
+ def ApproxRetrievalStrategy(
+ query_model_id: Optional[str] = None,
+ hybrid: Optional[bool] = False,
+ rrf: Optional[Union[dict, bool]] = True,
+ ) -> "ApproxRetrievalStrategy":
+ """Used to perform approximate nearest neighbor search
+ using the HNSW algorithm.
+
+ At build index time, this strategy will create a
+ dense vector field in the index and store the
+ embedding vectors in the index.
+
+ At query time, the text will either be embedded using the
+ provided embedding function or the query_model_id
+ will be used to embed the text using the model
+ deployed to Elasticsearch.
+
+        If query_model_id is used, do not provide an embedding function.
+
+ Args:
+ query_model_id: Optional. ID of the model to use to
+ embed the query text within the stack. Requires
+ embedding model to be deployed to Elasticsearch.
+ hybrid: Optional. If True, will perform a hybrid search
+ using both the knn query and a text query.
+ Defaults to False.
+            rrf: Optional. rrf is Reciprocal Rank Fusion.
+                Only used when `hybrid` is True:
+                if `rrf` is True, the default RRF parameters are used (rrf: {});
+                if `rrf` is False, rrf is omitted from the query;
+                if `rrf` is a dict, it is passed through as-is, which allows
+                adjusting 'rank_constant' and 'window_size'.
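+
+        Example:
+            A minimal sketch; the 'rank_constant' and 'window_size' values
+            below are illustrative only.
+
+            .. code-block:: python
+
+                # Hybrid search with the default RRF settings
+                strategy = ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True)
+
+                # Hybrid search with custom RRF parameters
+                strategy = ElasticsearchStore.ApproxRetrievalStrategy(
+                    hybrid=True,
+                    rrf={"rank_constant": 60, "window_size": 100},
+                )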
+ """
+ return ApproxRetrievalStrategy(
+ query_model_id=query_model_id, hybrid=hybrid, rrf=rrf
+ )
+
+ @staticmethod
+ def SparseVectorRetrievalStrategy(
+ model_id: Optional[str] = None,
+ ) -> "SparseRetrievalStrategy":
+ """Used to perform sparse vector search via text_expansion.
+ Used for when you want to use ELSER model to perform document search.
+
+ At build index time, this strategy will create a pipeline that
+ will embed the text using the ELSER model and store the
+ resulting tokens in the index.
+
+ At query time, the text will be embedded using the ELSER
+ model and the resulting tokens will be used to
+ perform a text_expansion query.
+
+ Args:
+ model_id: Optional. Default is ".elser_model_1".
+ ID of the model to use to embed the query text
+ within the stack. Requires embedding model to be
+ deployed to Elasticsearch.
+ """
+ return SparseRetrievalStrategy(model_id=model_id)
diff --git a/libs/community/langchain_community/vectorstores/epsilla.py b/libs/community/langchain_community/vectorstores/epsilla.py
new file mode 100644
index 00000000000..b41a07d8025
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/epsilla.py
@@ -0,0 +1,375 @@
+"""Wrapper around Epsilla vector database."""
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from pyepsilla import vectordb
+
+logger = logging.getLogger()
+
+
+class Epsilla(VectorStore):
+ """
+ Wrapper around Epsilla vector database.
+
+    As a prerequisite, you need to install the ``pyepsilla`` package
+    and have a running Epsilla vector database (for example, through our docker image).
+    See the following documentation for how to run an Epsilla vector database:
+    https://epsilla-inc.gitbook.io/epsilladb/quick-start
+
+ Args:
+ client (Any): Epsilla client to connect to.
+ embeddings (Embeddings): Function used to embed the texts.
+ db_path (Optional[str]): The path where the database will be persisted.
+ Defaults to "/tmp/langchain-epsilla".
+ db_name (Optional[str]): Give a name to the loaded database.
+ Defaults to "langchain_store".
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Epsilla
+ from pyepsilla import vectordb
+
+ client = vectordb.Client()
+ embeddings = OpenAIEmbeddings()
+ db_path = "/tmp/vectorstore"
+ db_name = "langchain_store"
+ epsilla = Epsilla(client, embeddings, db_path, db_name)
+ """
+
+ _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
+ _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
+ _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
+
+ def __init__(
+ self,
+ client: Any,
+ embeddings: Embeddings,
+ db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
+ db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
+ ):
+ """Initialize with necessary components."""
+ try:
+ import pyepsilla
+ except ImportError as e:
+ raise ImportError(
+ "Could not import pyepsilla python package. "
+ "Please install pyepsilla package with `pip install pyepsilla`."
+ ) from e
+
+ if not isinstance(client, pyepsilla.vectordb.Client):
+ raise TypeError(
+ f"client should be an instance of pyepsilla.vectordb.Client, "
+ f"got {type(client)}"
+ )
+
+ self._client: vectordb.Client = client
+ self._db_name = db_name
+ self._embeddings = embeddings
+ self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
+ self._client.load_db(db_name=db_name, db_path=db_path)
+ self._client.use_db(db_name=db_name)
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embeddings
+
+ def use_collection(self, collection_name: str) -> None:
+ """
+ Set default collection to use.
+
+ Args:
+ collection_name (str): The name of the collection.
+ """
+ self._collection_name = collection_name
+
+ def clear_data(self, collection_name: str = "") -> None:
+ """
+ Clear data in a collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the collection.
+ If not provided, the default collection will be used.
+ """
+ if not collection_name:
+ collection_name = self._collection_name
+ self._client.drop_table(collection_name)
+
+ def get(
+ self, collection_name: str = "", response_fields: Optional[List[str]] = None
+ ) -> List[dict]:
+ """Get the collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the collection
+ to retrieve data from.
+ If not provided, the default collection will be used.
+ response_fields (Optional[List[str]]): List of field names in the result.
+                If not specified, all available fields will be returned.
+
+ Returns:
+ A list of the retrieved data.
+ """
+ if not collection_name:
+ collection_name = self._collection_name
+ status_code, response = self._client.get(
+ table_name=collection_name, response_fields=response_fields
+ )
+ if status_code != 200:
+ logger.error(f"Failed to get records: {response['message']}")
+ raise Exception("Error: {}.".format(response["message"]))
+ return response["result"]
+
+ def _create_collection(
+ self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
+ ) -> None:
+ if not embeddings:
+ raise ValueError("Embeddings list is empty.")
+
+ dim = len(embeddings[0])
+ fields: List[dict] = [
+ {"name": "id", "dataType": "INT"},
+ {"name": "text", "dataType": "STRING"},
+ {"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
+ ]
+ if metadatas is not None:
+ field_names = [field["name"] for field in fields]
+ for metadata in metadatas:
+ for key, value in metadata.items():
+ if key in field_names:
+ continue
+ d_type: str
+ if isinstance(value, str):
+ d_type = "STRING"
+ elif isinstance(value, int):
+ d_type = "INT"
+ elif isinstance(value, float):
+ d_type = "FLOAT"
+ elif isinstance(value, bool):
+ d_type = "BOOL"
+ else:
+ raise ValueError(f"Unsupported data type for {key}.")
+ fields.append({"name": key, "dataType": d_type})
+ field_names.append(key)
+
+ status_code, response = self._client.create_table(
+ table_name, table_fields=fields
+ )
+ if status_code != 200:
+ if status_code == 409:
+ logger.info(f"Continuing with the existing table {table_name}.")
+ else:
+ logger.error(
+ f"Failed to create collection {table_name}: {response['message']}"
+ )
+ raise Exception("Error: {}.".format(response["message"]))
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ collection_name: Optional[str] = "",
+ drop_old: Optional[bool] = False,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Embed texts and add them to the database.
+
+ Args:
+ texts (Iterable[str]): The texts to embed.
+ metadatas (Optional[List[dict]]): Metadata dicts
+ attached to each of the texts. Defaults to None.
+ collection_name (Optional[str]): Which collection to use.
+ Defaults to "langchain_collection".
+ If provided, default collection name will be set as well.
+ drop_old (Optional[bool]): Whether to drop the previous collection
+ and create a new one. Defaults to False.
+
+ Returns:
+ List of ids of the added texts.
+ """
+ if not collection_name:
+ collection_name = self._collection_name
+ else:
+ self._collection_name = collection_name
+
+ if drop_old:
+ self._client.drop_db(db_name=collection_name)
+
+ texts = list(texts)
+ try:
+ embeddings = self._embeddings.embed_documents(texts)
+ except NotImplementedError:
+ embeddings = [self._embeddings.embed_query(x) for x in texts]
+
+ if len(embeddings) == 0:
+ logger.debug("Nothing to insert, skipping.")
+ return []
+
+ self._create_collection(
+ table_name=collection_name, embeddings=embeddings, metadatas=metadatas
+ )
+
+ ids = [hash(uuid.uuid4()) for _ in texts]
+ records = []
+ for index, id in enumerate(ids):
+ record = {
+ "id": id,
+ "text": texts[index],
+ "embeddings": embeddings[index],
+ }
+ if metadatas is not None:
+ metadata = metadatas[index].items()
+ for key, value in metadata:
+ record[key] = value
+ records.append(record)
+
+ status_code, response = self._client.insert(
+ table_name=collection_name, records=records
+ )
+ if status_code != 200:
+ logger.error(
+ f"Failed to add records to {collection_name}: {response['message']}"
+ )
+ raise Exception("Error: {}.".format(response["message"]))
+ return [str(id) for id in ids]
+
+ def similarity_search(
+ self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
+ ) -> List[Document]:
+ """
+ Return the documents that are semantically most relevant to the query.
+
+ Args:
+ query (str): String to query the vectorstore with.
+ k (Optional[int]): Number of documents to return. Defaults to 4.
+ collection_name (Optional[str]): Collection to use.
+ Defaults to "langchain_store" or the one provided before.
+ Returns:
+ List of documents that are semantically most relevant to the query
+ """
+ if not collection_name:
+ collection_name = self._collection_name
+ query_vector = self._embeddings.embed_query(query)
+ status_code, response = self._client.query(
+ table_name=collection_name,
+ query_field="embeddings",
+ query_vector=query_vector,
+ limit=k,
+ )
+ if status_code != 200:
+ logger.error(f"Search failed: {response['message']}.")
+ raise Exception("Error: {}.".format(response["message"]))
+
+ exclude_keys = ["id", "text", "embeddings"]
+ return list(
+ map(
+ lambda item: Document(
+ page_content=item["text"],
+ metadata={
+ key: item[key] for key in item if key not in exclude_keys
+ },
+ ),
+ response["result"],
+ )
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[Epsilla],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ client: Any = None,
+ db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
+ db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
+ collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ drop_old: Optional[bool] = False,
+ **kwargs: Any,
+ ) -> Epsilla:
+ """Create an Epsilla vectorstore from raw documents.
+
+ Args:
+ texts (List[str]): List of text data to be inserted.
+            embedding (Embeddings): Embedding function.
+ client (pyepsilla.vectordb.Client): Epsilla client to connect to.
+ metadatas (Optional[List[dict]]): Metadata for each text.
+ Defaults to None.
+ db_path (Optional[str]): The path where the database will be persisted.
+ Defaults to "/tmp/langchain-epsilla".
+ db_name (Optional[str]): Give a name to the loaded database.
+ Defaults to "langchain_store".
+ collection_name (Optional[str]): Which collection to use.
+ Defaults to "langchain_collection".
+ If provided, default collection name will be set as well.
+ drop_old (Optional[bool]): Whether to drop the previous collection
+ and create a new one. Defaults to False.
+
+ Returns:
+ Epsilla: Epsilla vector store.
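+
+        Example:
+            A minimal sketch; ``embeddings`` is assumed to be any Embeddings
+            implementation, for example OpenAIEmbeddings.
+
+            .. code-block:: python
+
+                from pyepsilla import vectordb
+
+                client = vectordb.Client()
+                epsilla = Epsilla.from_texts(
+                    ["hello", "world"],
+                    embeddings,
+                    client=client,
+                )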
+ """
+ instance = Epsilla(client, embedding, db_path=db_path, db_name=db_name)
+ instance.add_texts(
+ texts,
+ metadatas=metadatas,
+ collection_name=collection_name,
+ drop_old=drop_old,
+ **kwargs,
+ )
+
+ return instance
+
+ @classmethod
+ def from_documents(
+ cls: Type[Epsilla],
+ documents: List[Document],
+ embedding: Embeddings,
+ client: Any = None,
+ db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
+ db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
+ collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ drop_old: Optional[bool] = False,
+ **kwargs: Any,
+ ) -> Epsilla:
+ """Create an Epsilla vectorstore from a list of documents.
+
+ Args:
+            documents (List[Document]): List of documents to be inserted.
+            embedding (Embeddings): Embedding function.
+            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
+ db_path (Optional[str]): The path where the database will be persisted.
+ Defaults to "/tmp/langchain-epsilla".
+ db_name (Optional[str]): Give a name to the loaded database.
+ Defaults to "langchain_store".
+ collection_name (Optional[str]): Which collection to use.
+ Defaults to "langchain_collection".
+ If provided, default collection name will be set as well.
+ drop_old (Optional[bool]): Whether to drop the previous collection
+ and create a new one. Defaults to False.
+
+ Returns:
+ Epsilla: Epsilla vector store.
+ """
+ texts = [doc.page_content for doc in documents]
+ metadatas = [doc.metadata for doc in documents]
+
+ return cls.from_texts(
+ texts,
+ embedding,
+ metadatas=metadatas,
+ client=client,
+ db_path=db_path,
+ db_name=db_name,
+ collection_name=collection_name,
+ drop_old=drop_old,
+ **kwargs,
+ )
diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py
new file mode 100644
index 00000000000..7473af4a7b5
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/faiss.py
@@ -0,0 +1,1167 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import operator
+import os
+import pickle
+import uuid
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sized,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.docstore.base import AddableMixin, Docstore
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.vectorstores.utils import (
+ DistanceStrategy,
+ maximal_marginal_relevance,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
+ """
+ Import faiss if available, otherwise raise error.
+ If FAISS_NO_AVX2 environment variable is set, it will be considered
+ to load FAISS with no AVX2 optimization.
+
+ Args:
+ no_avx2: Load FAISS strictly with no AVX2 optimization
+ so that the vectorstore is portable and compatible with other devices.
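+
+    Example:
+        A small sketch of forcing the non-AVX2 build to be loaded:
+
+        .. code-block:: python
+
+            faiss = dependable_faiss_import(no_avx2=True)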
+ """
+ if no_avx2 is None and "FAISS_NO_AVX2" in os.environ:
+ no_avx2 = bool(os.getenv("FAISS_NO_AVX2"))
+
+ try:
+ if no_avx2:
+ from faiss import swigfaiss as faiss
+ else:
+ import faiss
+ except ImportError:
+ raise ImportError(
+ "Could not import faiss python package. "
+            "Please install it with `pip install faiss-gpu` (for CUDA-enabled "
+            "GPUs) or `pip install faiss-cpu` (for CPU-only installations)."
+ )
+ return faiss
+
+
+def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
+ if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
+ raise ValueError(
+ f"{x_name} and {y_name} expected to be equal length but "
+ f"len({x_name})={len(x)} and len({y_name})={len(y)}"
+ )
+ return
+
+
+class FAISS(VectorStore):
+ """`Meta Faiss` vector store.
+
+ To use, you must have the ``faiss`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import FAISS
+
+ embeddings = OpenAIEmbeddings()
+ texts = ["FAISS is an important library", "LangChain supports FAISS"]
+ faiss = FAISS.from_texts(texts, embeddings)
+
+ """
+
+ def __init__(
+ self,
+ embedding_function: Union[
+ Callable[[str], List[float]],
+ Embeddings,
+ ],
+ index: Any,
+ docstore: Docstore,
+ index_to_docstore_id: Dict[int, str],
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ normalize_L2: bool = False,
+ distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+ ):
+ """Initialize with necessary components."""
+ if not isinstance(embedding_function, Embeddings):
+ logger.warning(
+ "`embedding_function` is expected to be an Embeddings object, support "
+ "for passing in a function will soon be removed."
+ )
+ self.embedding_function = embedding_function
+ self.index = index
+ self.docstore = docstore
+ self.index_to_docstore_id = index_to_docstore_id
+ self.distance_strategy = distance_strategy
+ self.override_relevance_score_fn = relevance_score_fn
+ self._normalize_L2 = normalize_L2
+ if (
+ self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE
+ and self._normalize_L2
+ ):
+ warnings.warn(
+ "Normalizing L2 is not applicable for metric type: {strategy}".format(
+ strategy=self.distance_strategy
+ )
+ )
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return (
+ self.embedding_function
+ if isinstance(self.embedding_function, Embeddings)
+ else None
+ )
+
+ def _embed_documents(self, texts: List[str]) -> List[List[float]]:
+ if isinstance(self.embedding_function, Embeddings):
+ return self.embedding_function.embed_documents(texts)
+ else:
+ return [self.embedding_function(text) for text in texts]
+
+ async def _aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ if isinstance(self.embedding_function, Embeddings):
+ return await self.embedding_function.aembed_documents(texts)
+ else:
+            # A plain callable has no async interface, so async embedding is
+            # only supported when an Embeddings object is provided.
+ raise Exception(
+ "`embedding_function` is expected to be an Embeddings object, support "
+ "for passing in a function will soon be removed."
+ )
+
+ def _embed_query(self, text: str) -> List[float]:
+ if isinstance(self.embedding_function, Embeddings):
+ return self.embedding_function.embed_query(text)
+ else:
+ return self.embedding_function(text)
+
+ async def _aembed_query(self, text: str) -> List[float]:
+ if isinstance(self.embedding_function, Embeddings):
+ return await self.embedding_function.aembed_query(text)
+ else:
+            # A plain callable has no async interface; see _aembed_documents.
+ raise Exception(
+ "`embedding_function` is expected to be an Embeddings object, support "
+ "for passing in a function will soon be removed."
+ )
+
+ def __add(
+ self,
+ texts: Iterable[str],
+ embeddings: Iterable[List[float]],
+ metadatas: Optional[Iterable[dict]] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[str]:
+ faiss = dependable_faiss_import()
+
+ if not isinstance(self.docstore, AddableMixin):
+ raise ValueError(
+ "If trying to add texts, the underlying docstore should support "
+ f"adding items, which {self.docstore} does not"
+ )
+
+ _len_check_if_sized(texts, metadatas, "texts", "metadatas")
+ _metadatas = metadatas or ({} for _ in texts)
+ documents = [
+ Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
+ ]
+
+ _len_check_if_sized(documents, embeddings, "documents", "embeddings")
+ _len_check_if_sized(documents, ids, "documents", "ids")
+
+ # Add to the index.
+ vector = np.array(embeddings, dtype=np.float32)
+ if self._normalize_L2:
+ faiss.normalize_L2(vector)
+ self.index.add(vector)
+
+ # Add information to docstore and index.
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
+ starting_len = len(self.index_to_docstore_id)
+ index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
+ self.index_to_docstore_id.update(index_to_id)
+ return ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
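+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store created with ``FAISS.from_texts``):
+            .. code-block:: python
+
+                new_ids = db.add_texts(
+                    ["FAISS indexes grow incrementally"],
+                    metadatas=[{"source": "notes"}],
+                )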
+ """
+ texts = list(texts)
+ embeddings = self._embed_documents(texts)
+ return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+
+ async def aadd_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore
+ asynchronously.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ texts = list(texts)
+ embeddings = await self._aembed_documents(texts)
+ return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+
+ def add_embeddings(
+ self,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add the given texts and embeddings to the vectorstore.
+
+ Args:
+ text_embeddings: Iterable pairs of string and embedding to
+ add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ # Embed and create the documents.
+ texts, embeddings = zip(*text_embeddings)
+ return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+ **kwargs: kwargs to be passed to similarity search. Can include:
+                score_threshold: Optional, a floating point value between 0 and 1 to
+                    filter the resulting set of retrieved docs.
+
+ Returns:
+ List of documents most similar to the query text and L2 distance
+ in float for each. Lower score represents more similarity.
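+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store and ``query_vector`` was produced by the same embedding model):
+            .. code-block:: python
+
+                results = db.similarity_search_with_score_by_vector(
+                    query_vector,
+                    k=2,
+                    filter={"source": "notes"},
+                    fetch_k=20,
+                    score_threshold=0.5,
+                )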
+ """
+ faiss = dependable_faiss_import()
+ vector = np.array([embedding], dtype=np.float32)
+ if self._normalize_L2:
+ faiss.normalize_L2(vector)
+ scores, indices = self.index.search(vector, k if filter is None else fetch_k)
+ docs = []
+ for j, i in enumerate(indices[0]):
+ if i == -1:
+ # This happens when not enough docs are returned.
+ continue
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ if filter is not None:
+ filter = {
+ key: [value] if not isinstance(value, list) else value
+ for key, value in filter.items()
+ }
+ if all(doc.metadata.get(key) in value for key, value in filter.items()):
+ docs.append((doc, scores[0][j]))
+ else:
+ docs.append((doc, scores[0][j]))
+
+ score_threshold = kwargs.get("score_threshold")
+ if score_threshold is not None:
+ cmp = (
+ operator.ge
+ if self.distance_strategy
+ in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
+ else operator.le
+ )
+ docs = [
+ (doc, similarity)
+ for doc, similarity in docs
+ if cmp(similarity, score_threshold)
+ ]
+ return docs[:k]
+
+ async def asimilarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query asynchronously.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+ **kwargs: kwargs to be passed to similarity search. Can include:
+                score_threshold: Optional, a floating point value between 0 and 1 to
+                    filter the resulting set of retrieved docs.
+
+ Returns:
+ List of documents most similar to the query text and L2 distance
+ in float for each. Lower score represents more similarity.
+ """
+
+ # This is a temporary workaround to make the similarity search asynchronous.
+ func = partial(
+ self.similarity_search_with_score_by_vector,
+ embedding,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of documents most similar to the query text with
+ L2 distance in float. Lower score represents more similarity.
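+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store):
+            .. code-block:: python
+
+                docs_and_scores = db.similarity_search_with_score(
+                    "What does LangChain support?", k=2, filter={"source": "notes"}
+                )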
+ """
+ embedding = self._embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return docs
+
+ async def asimilarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query asynchronously.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of documents most similar to the query text with
+ L2 distance in float. Lower score represents more similarity.
+ """
+ embedding = await self._aembed_query(query)
+ docs = await self.asimilarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ async def asimilarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector asynchronously.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = await self.asimilarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k, filter=filter, fetch_k=fetch_k, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ async def asimilarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query asynchronously.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = await self.asimilarity_search_with_score(
+ query, k, filter=filter, fetch_k=fetch_k, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores selected using the maximal marginal
+ relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents and similarity scores selected by maximal marginal
+ relevance and score for each.
+ """
+ scores, indices = self.index.search(
+ np.array([embedding], dtype=np.float32),
+ fetch_k if filter is None else fetch_k * 2,
+ )
+ if filter is not None:
+ filtered_indices = []
+ for i in indices[0]:
+ if i == -1:
+ # This happens when not enough docs are returned.
+ continue
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ if all(
+ doc.metadata.get(key) in value
+ if isinstance(value, list)
+ else doc.metadata.get(key) == value
+ for key, value in filter.items()
+ ):
+ filtered_indices.append(i)
+ indices = np.array([filtered_indices])
+ # -1 happens when not enough docs are returned.
+ embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ selected_indices = [indices[0][i] for i in mmr_selected]
+ selected_scores = [scores[0][i] for i in mmr_selected]
+ docs_and_scores = []
+ for i, score in zip(selected_indices, selected_scores):
+ if i == -1:
+ # This happens when not enough docs are returned.
+ continue
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ docs_and_scores.append((doc, score))
+ return docs_and_scores
+
+ async def amax_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores selected using the maximal marginal
+ relevance asynchronously.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents and similarity scores selected by maximal marginal
+ relevance and score for each.
+ """
+ # This is a temporary workaround to make the similarity search asynchronous.
+ func = partial(
+ self.max_marginal_relevance_search_with_score_by_vector,
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
+ embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ async def amax_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance asynchronously.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ docs_and_scores = (
+ await self.amax_marginal_relevance_search_with_score_by_vector(
+ embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
+ )
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering (if needed) to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
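+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store; a lower ``lambda_mult`` favours diversity over similarity):
+            .. code-block:: python
+
+                docs = db.max_marginal_relevance_search(
+                    "vector databases", k=4, fetch_k=20, lambda_mult=0.25
+                )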
+ """
+ embedding = self._embed_query(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return docs
+
+ async def amax_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance asynchronously.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering (if needed) to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding = await self._aembed_query(query)
+ docs = await self.amax_marginal_relevance_search_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return docs
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by ID. These are the IDs in the vectorstore.
+
+ Args:
+ ids: List of ids to delete.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
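+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store and ``ids`` was returned by a previous ``add_texts`` call):
+            .. code-block:: python
+
+                db.delete(ids=[ids[0]])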
+ """
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+ missing_ids = set(ids).difference(self.index_to_docstore_id.values())
+ if missing_ids:
+ raise ValueError(
+ f"Some specified ids do not exist in the current store. Ids not found: "
+ f"{missing_ids}"
+ )
+
+ reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
+ index_to_delete = [reversed_index[id_] for id_ in ids]
+
+ self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
+ self.docstore.delete(ids)
+
+ remaining_ids = [
+ id_
+ for i, id_ in sorted(self.index_to_docstore_id.items())
+ if i not in index_to_delete
+ ]
+ self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)}
+
+ return True
+
+ def merge_from(self, target: FAISS) -> None:
+ """Merge another FAISS object with the current one.
+
+ Add the target FAISS to the current one.
+
+ Args:
+ target: FAISS object you wish to merge into the current one
+
+ Returns:
+ None.
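+
+        Example (illustrative sketch; assumes ``embeddings`` is an ``Embeddings``
+        instance shared by both stores):
+            .. code-block:: python
+
+                db1 = FAISS.from_texts(["foo"], embeddings)
+                db2 = FAISS.from_texts(["bar"], embeddings)
+                db1.merge_from(db2)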
+ """
+ if not isinstance(self.docstore, AddableMixin):
+ raise ValueError("Cannot merge with this type of docstore")
+        # Numerical indices for target docs are incremented from the existing ones.
+ starting_len = len(self.index_to_docstore_id)
+
+ # Merge two IndexFlatL2
+ self.index.merge_from(target.index)
+
+ # Get id and docs from target FAISS object
+ full_info = []
+ for i, target_id in target.index_to_docstore_id.items():
+ doc = target.docstore.search(target_id)
+ if not isinstance(doc, Document):
+ raise ValueError("Document should be returned")
+ full_info.append((starting_len + i, target_id, doc))
+
+ # Add information to docstore and index_to_docstore_id.
+ self.docstore.add({_id: doc for _, _id, doc in full_info})
+ index_to_id = {index: _id for index, _id, _ in full_info}
+ self.index_to_docstore_id.update(index_to_id)
+
+ @classmethod
+ def __from(
+ cls,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[Iterable[dict]] = None,
+ ids: Optional[List[str]] = None,
+ normalize_L2: bool = False,
+ distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+ **kwargs: Any,
+ ) -> FAISS:
+ faiss = dependable_faiss_import()
+ if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ index = faiss.IndexFlatIP(len(embeddings[0]))
+ else:
+ # Default to L2, currently other metric types not initialized.
+ index = faiss.IndexFlatL2(len(embeddings[0]))
+ vecstore = cls(
+ embedding,
+ index,
+ InMemoryDocstore(),
+ {},
+ normalize_L2=normalize_L2,
+ distance_strategy=distance_strategy,
+ **kwargs,
+ )
+ vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+ return vecstore
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> FAISS:
+ """Construct FAISS wrapper from raw documents.
+
+        This is a user-friendly interface that:
+            1. Embeds documents.
+            2. Creates an in-memory docstore.
+            3. Initializes the FAISS database.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ faiss = FAISS.from_texts(texts, embeddings)
+ """
+ embeddings = embedding.embed_documents(texts)
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ async def afrom_texts(
+ cls,
+ texts: list[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> FAISS:
+ """Construct FAISS wrapper from raw documents asynchronously.
+
+        This is a user-friendly interface that:
+            1. Embeds documents.
+            2. Creates an in-memory docstore.
+            3. Initializes the FAISS database.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ faiss = await FAISS.afrom_texts(texts, embeddings)
+ """
+ embeddings = await embedding.aembed_documents(texts)
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[Iterable[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> FAISS:
+ """Construct FAISS wrapper from raw documents.
+
+        This is a user-friendly interface that:
+            1. Creates an in-memory docstore from the provided texts and embeddings.
+            2. Initializes the FAISS database.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = zip(texts, text_embeddings)
+ faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ async def afrom_embeddings(
+ cls,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[Iterable[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> FAISS:
+ """Construct FAISS wrapper from raw documents asynchronously."""
+ return cls.from_embeddings(
+ text_embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ def save_local(self, folder_path: str, index_name: str = "index") -> None:
+ """Save FAISS index, docstore, and index_to_docstore_id to disk.
+
+ Args:
+ folder_path: folder path to save index, docstore,
+ and index_to_docstore_id to.
+ index_name: for saving with a specific index file name
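+
+        Example (illustrative sketch; assumes ``db`` is an existing ``FAISS``
+        store and ``embeddings`` is the ``Embeddings`` object used to build it):
+            .. code-block:: python
+
+                db.save_local("faiss_index")
+                reloaded = FAISS.load_local("faiss_index", embeddings)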
+ """
+ path = Path(folder_path)
+ path.mkdir(exist_ok=True, parents=True)
+
+ # save index separately since it is not picklable
+ faiss = dependable_faiss_import()
+ faiss.write_index(
+ self.index, str(path / "{index_name}.faiss".format(index_name=index_name))
+ )
+
+ # save docstore and index_to_docstore_id
+ with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
+ pickle.dump((self.docstore, self.index_to_docstore_id), f)
+
+ @classmethod
+ def load_local(
+ cls,
+ folder_path: str,
+ embeddings: Embeddings,
+ index_name: str = "index",
+ **kwargs: Any,
+ ) -> FAISS:
+ """Load FAISS index, docstore, and index_to_docstore_id from disk.
+
+ Args:
+ folder_path: folder path to load index, docstore,
+ and index_to_docstore_id from.
+ embeddings: Embeddings to use when generating queries
+ index_name: for saving with a specific index file name
+ """
+ path = Path(folder_path)
+ # load index separately since it is not picklable
+ faiss = dependable_faiss_import()
+ index = faiss.read_index(
+ str(path / "{index_name}.faiss".format(index_name=index_name))
+ )
+
+ # load docstore and index_to_docstore_id
+ with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
+ docstore, index_to_docstore_id = pickle.load(f)
+ return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
+
+ def serialize_to_bytes(self) -> bytes:
+ """Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
+ return pickle.dumps((self.index, self.docstore, self.index_to_docstore_id))
+
+ @classmethod
+ def deserialize_from_bytes(
+ cls,
+ serialized: bytes,
+ embeddings: Embeddings,
+ **kwargs: Any,
+ ) -> FAISS:
+ """Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
+ index, docstore, index_to_docstore_id = pickle.loads(serialized)
+ return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
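+
+        A custom function can be supplied at construction time (illustrative
+        sketch; assumes ``embeddings``, ``index``, ``docstore`` and
+        ``index_to_docstore_id`` already exist):
+            .. code-block:: python
+
+                db = FAISS(
+                    embeddings,
+                    index,
+                    docstore,
+                    index_to_docstore_id,
+                    relevance_score_fn=lambda distance: 1.0 / (1.0 + distance),
+                )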
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided in
+ # vectorstore constructor
+ if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ # Default behavior is to use euclidean distance relevancy
+ return self._euclidean_relevance_score_fn
+ elif self.distance_strategy == DistanceStrategy.COSINE:
+ return self._cosine_relevance_score_fn
+ else:
+ raise ValueError(
+ "Unknown distance strategy, must be cosine, max_inner_product,"
+ " or euclidean"
+ )
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores on a scale from 0 to 1."""
+ # Pop score threshold so that only relevancy scores, not raw scores, are
+ # filtered.
+ relevance_score_fn = self._select_relevance_score_fn()
+ if relevance_score_fn is None:
+ raise ValueError(
+ "normalize_score_fn must be provided to"
+ " FAISS constructor to normalize scores"
+ )
+ docs_and_scores = self.similarity_search_with_score(
+ query,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ docs_and_rel_scores = [
+ (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
+ ]
+ return docs_and_rel_scores
+
+ async def _asimilarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores on a scale from 0 to 1."""
+ # Pop score threshold so that only relevancy scores, not raw scores, are
+ # filtered.
+ relevance_score_fn = self._select_relevance_score_fn()
+ if relevance_score_fn is None:
+ raise ValueError(
+ "normalize_score_fn must be provided to"
+ " FAISS constructor to normalize scores"
+ )
+ docs_and_scores = await self.asimilarity_search_with_score(
+ query,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ docs_and_rel_scores = [
+ (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
+ ]
+ return docs_and_rel_scores
diff --git a/libs/community/langchain_community/vectorstores/hippo.py b/libs/community/langchain_community/vectorstores/hippo.py
new file mode 100644
index 00000000000..328a2f1e992
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/hippo.py
@@ -0,0 +1,677 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from transwarp_hippo_api.hippo_client import HippoClient
+
+# Default connection
+DEFAULT_HIPPO_CONNECTION = {
+ "host": "localhost",
+ "port": "7788",
+ "username": "admin",
+ "password": "admin",
+}
+
+logger = logging.getLogger(__name__)
+
+
+class Hippo(VectorStore):
+ """`Hippo` vector store.
+
+ You need to install `hippo-api` and run Hippo.
+
+ Please visit our official website for how to run a Hippo instance:
+ https://www.transwarp.cn/starwarp
+
+ Args:
+ embedding_function (Embeddings): Function used to embed the text.
+ table_name (str): Which Hippo table to use. Defaults to
+ "test".
+ database_name (str): Which Hippo database to use. Defaults to
+ "default".
+        number_of_shards (int): The number of shards for the Hippo table. Defaults
+            to 1.
+        number_of_replicas (int): The number of replicas for the Hippo table.
+            Defaults to 1.
+        connection_args (Optional[dict[str, any]]): The connection args used for
+            this class, in the form of a dict.
+ index_params (Optional[dict]): Which index params to use. Defaults to
+ IVF_FLAT.
+ drop_old (Optional[bool]): Whether to drop the current collection. Defaults
+ to False.
+ primary_field (str): Name of the primary key field. Defaults to "pk".
+ text_field (str): Name of the text field. Defaults to "text".
+ vector_field (str): Name of the vector field. Defaults to "vector".
+
+    The connection args used for this class come in the form of a dict;
+    here are a few of the options:
+        host (str): The host of the Hippo instance. Defaults to "localhost".
+        port (str/int): The port of the Hippo instance. Defaults to 7788.
+        user (str): Which user to use when connecting to the Hippo instance. If user
+            and password are provided, a related header will be added to every
+            RPC call.
+        password (str): Required when user is provided. The password
+            corresponding to the user.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Hippo
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+            embeddings = OpenAIEmbeddings()
+ # Connect to a hippo instance on localhost
+ vector_store = Hippo.from_documents(
+ docs,
+ embedding=embeddings,
+ table_name="langchain_test",
+ connection_args=HIPPO_CONNECTION
+ )
+
+ Raises:
+ ValueError: If the hippo-api python package is not installed.
+ """
+
+ def __init__(
+ self,
+ embedding_function: Embeddings,
+ table_name: str = "test",
+ database_name: str = "default",
+ number_of_shards: int = 1,
+ number_of_replicas: int = 1,
+ connection_args: Optional[Dict[str, Any]] = None,
+ index_params: Optional[dict] = None,
+ drop_old: Optional[bool] = False,
+ ):
+ self.number_of_shards = number_of_shards
+ self.number_of_replicas = number_of_replicas
+ self.embedding_func = embedding_function
+ self.table_name = table_name
+ self.database_name = database_name
+ self.index_params = index_params
+
+ # In order for a collection to be compatible,
+ # 'pk' should be an auto-increment primary key and string
+ self._primary_field = "pk"
+ # In order for compatibility, the text field will need to be called "text"
+ self._text_field = "text"
+ # In order for compatibility, the vector field needs to be called "vector"
+ self._vector_field = "vector"
+ self.fields: List[str] = []
+ # Create the connection to the server
+ if connection_args is None:
+ connection_args = DEFAULT_HIPPO_CONNECTION
+ self.hc = self._create_connection_alias(connection_args)
+ self.col: Any = None
+
+ # If the collection exists, delete it
+ try:
+ if (
+ self.hc.check_table_exists(self.table_name, self.database_name)
+ and drop_old
+ ):
+ self.hc.delete_table(self.table_name, self.database_name)
+ except Exception as e:
+ logging.error(
+ f"An error occurred while deleting the table " f"{self.table_name}: {e}"
+ )
+ raise
+
+ try:
+ if self.hc.check_table_exists(self.table_name, self.database_name):
+ self.col = self.hc.get_table(self.table_name, self.database_name)
+ except Exception as e:
+ logging.error(
+ f"An error occurred while getting the table " f"{self.table_name}: {e}"
+ )
+ raise
+
+ # Initialize the vector database
+ self._get_env()
+
+ def _create_connection_alias(self, connection_args: dict) -> HippoClient:
+ """Create the connection to the Hippo server."""
+ # Grab the connection arguments that are used for checking existing connection
+ try:
+ from transwarp_hippo_api.hippo_client import HippoClient
+ except ImportError as e:
+ raise ImportError(
+                "Unable to import transwarp_hippo_api, please install with "
+ "`pip install hippo-api`."
+ ) from e
+
+ host: str = connection_args.get("host", None)
+ port: int = connection_args.get("port", None)
+ username: str = connection_args.get("username", "shiva")
+ password: str = connection_args.get("password", "shiva")
+
+ # Order of use is host/port, uri, address
+ if host is not None and port is not None:
+ if "," in host:
+ hosts = host.split(",")
+ given_address = ",".join([f"{h}:{port}" for h in hosts])
+ else:
+ given_address = str(host) + ":" + str(port)
+ else:
+ raise ValueError("Missing standard address type for reuse attempt")
+
+ try:
+ logger.info(f"create HippoClient[{given_address}]")
+ return HippoClient([given_address], username=username, pwd=password)
+ except Exception as e:
+ logger.error("Failed to create new connection")
+ raise e
+
+ def _get_env(
+ self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
+ ) -> None:
+ logger.info("init ...")
+ if embeddings is not None:
+ logger.info("create collection")
+ self._create_collection(embeddings, metadatas)
+ self._extract_fields()
+ self._create_index()
+
+ def _create_collection(
+ self, embeddings: list, metadatas: Optional[List[dict]] = None
+ ) -> None:
+ from transwarp_hippo_api.hippo_client import HippoField
+ from transwarp_hippo_api.hippo_type import HippoType
+
+ # Determine embedding dim
+ dim = len(embeddings[0])
+ logger.debug(f"[_create_collection] dim: {dim}")
+ fields = []
+
+ # Create the primary key field
+ fields.append(HippoField(self._primary_field, True, HippoType.STRING))
+
+ # Create the text field
+
+ fields.append(HippoField(self._text_field, False, HippoType.STRING))
+
+        # Create the vector field, which supports binary or float vectors.
+        # The binary vector type is yet to be developed.
+ fields.append(
+ HippoField(
+ self._vector_field,
+ False,
+ HippoType.FLOAT_VECTOR,
+ type_params={"dimension": dim},
+ )
+ )
+        # In Hippo, there is no method similar to infer_type_data, so currently
+        # all non-vector data is converted to the string type.
+
+ if metadatas:
+ # # Create FieldSchema for each entry in metadata.
+ for key, value in metadatas[0].items():
+ # # Infer the corresponding datatype of the metadata
+ if isinstance(value, list):
+ value_dim = len(value)
+ fields.append(
+ HippoField(
+ key,
+ False,
+ HippoType.FLOAT_VECTOR,
+ type_params={"dimension": value_dim},
+ )
+ )
+ else:
+ fields.append(HippoField(key, False, HippoType.STRING))
+
+ logger.debug(f"[_create_collection] fields: {fields}")
+
+ # Create the collection
+ self.hc.create_table(
+ name=self.table_name,
+ auto_id=True,
+ fields=fields,
+ database_name=self.database_name,
+ number_of_shards=self.number_of_shards,
+ number_of_replicas=self.number_of_replicas,
+ )
+ self.col = self.hc.get_table(self.table_name, self.database_name)
+ logger.info(
+ f"[_create_collection] : "
+ f"create table {self.table_name} in {self.database_name} successfully"
+ )
+
+ def _extract_fields(self) -> None:
+ """Grab the existing fields from the Collection"""
+ from transwarp_hippo_api.hippo_client import HippoTable
+
+ if isinstance(self.col, HippoTable):
+ schema = self.col.schema
+ logger.debug(f"[_extract_fields] schema:{schema}")
+ for x in schema:
+ self.fields.append(x.name)
+ logger.debug(f"04 [_extract_fields] fields:{self.fields}")
+
+    # Currently, only the field named 'vector' (the automatically created
+    # vector field) is checked for indexing. Indexes need to be created
+    # manually for other vector-type columns.
+ def _get_index(self) -> Optional[Dict[str, Any]]:
+ """Return the vector index information if it exists"""
+ from transwarp_hippo_api.hippo_client import HippoTable
+
+ if isinstance(self.col, HippoTable):
+ table_info = self.hc.get_table_info(
+ self.table_name, self.database_name
+ ).get(self.table_name, {})
+ embedding_indexes = table_info.get("embedding_indexes", None)
+ if embedding_indexes is None:
+ return None
+ else:
+ for x in self.hc.get_table_info(self.table_name, self.database_name)[
+ self.table_name
+ ]["embedding_indexes"]:
+ logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
+ if x["column"] == self._vector_field:
+ return x
+ return None
+
+    # Indexes can only be created for the self._vector_field field.
+ def _create_index(self) -> None:
+        """Create an index on the collection."""
+ from transwarp_hippo_api.hippo_client import HippoTable
+ from transwarp_hippo_api.hippo_type import IndexType, MetricType
+
+ if isinstance(self.col, HippoTable) and self._get_index() is None:
+ if self._get_index() is None:
+ if self.index_params is None:
+ self.index_params = {
+ "index_name": "langchain_auto_create",
+ "metric_type": MetricType.L2,
+ "index_type": IndexType.IVF_FLAT,
+ "nlist": 10,
+ }
+
+ self.col.create_index(
+ self._vector_field,
+ self.index_params["index_name"],
+ self.index_params["index_type"],
+ self.index_params["metric_type"],
+ nlist=self.index_params["nlist"],
+ )
+ logger.debug(
+ self.col.activate_index(self.index_params["index_name"])
+ )
+ logger.info("create index successfully")
+ else:
+ index_dict = {
+ "IVF_FLAT": IndexType.IVF_FLAT,
+ "FLAT": IndexType.FLAT,
+ "IVF_SQ": IndexType.IVF_SQ,
+ "IVF_PQ": IndexType.IVF_PQ,
+ "HNSW": IndexType.HNSW,
+ }
+
+ metric_dict = {
+ "ip": MetricType.IP,
+ "IP": MetricType.IP,
+ "l2": MetricType.L2,
+ "L2": MetricType.L2,
+ }
+ self.index_params["metric_type"] = metric_dict[
+ self.index_params["metric_type"]
+ ]
+
+ if self.index_params["index_type"] == "FLAT":
+ self.index_params["index_type"] = index_dict[
+ self.index_params["index_type"]
+ ]
+ self.col.create_index(
+ self._vector_field,
+ self.index_params["index_name"],
+ self.index_params["index_type"],
+ self.index_params["metric_type"],
+ )
+ logger.debug(
+ self.col.activate_index(self.index_params["index_name"])
+ )
+ elif (
+ self.index_params["index_type"] == "IVF_FLAT"
+ or self.index_params["index_type"] == "IVF_SQ"
+ ):
+ self.index_params["index_type"] = index_dict[
+ self.index_params["index_type"]
+ ]
+ self.col.create_index(
+ self._vector_field,
+ self.index_params["index_name"],
+ self.index_params["index_type"],
+ self.index_params["metric_type"],
+ nlist=self.index_params.get("nlist", 10),
+ nprobe=self.index_params.get("nprobe", 10),
+ )
+ logger.debug(
+ self.col.activate_index(self.index_params["index_name"])
+ )
+ elif self.index_params["index_type"] == "IVF_PQ":
+ self.index_params["index_type"] = index_dict[
+ self.index_params["index_type"]
+ ]
+ self.col.create_index(
+ self._vector_field,
+ self.index_params["index_name"],
+ self.index_params["index_type"],
+ self.index_params["metric_type"],
+ nlist=self.index_params.get("nlist", 10),
+ nprobe=self.index_params.get("nprobe", 10),
+ nbits=self.index_params.get("nbits", 8),
+ m=self.index_params.get("m"),
+ )
+ logger.debug(
+ self.col.activate_index(self.index_params["index_name"])
+ )
+ elif self.index_params["index_type"] == "HNSW":
+ self.index_params["index_type"] = index_dict[
+ self.index_params["index_type"]
+ ]
+ self.col.create_index(
+ self._vector_field,
+ self.index_params["index_name"],
+ self.index_params["index_type"],
+ self.index_params["metric_type"],
+ M=self.index_params.get("M"),
+ ef_construction=self.index_params.get("ef_construction"),
+ ef_search=self.index_params.get("ef_search"),
+ )
+ logger.debug(
+ self.col.activate_index(self.index_params["index_name"])
+ )
+ else:
+ raise ValueError(
+                        "Index type does not match, "
+                        "please enter the correct index type. "
+                        "(FLAT, IVF_FLAT, IVF_PQ, IVF_SQ, HNSW)"
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ timeout: Optional[int] = None,
+ batch_size: int = 1000,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Add text to the collection.
+
+ Args:
+ texts: An iterable that contains the text to be added.
+ metadatas: An optional list of dictionaries,
+ each dictionary contains the metadata associated with a text.
+ timeout: Optional timeout, in seconds.
+ batch_size: The number of texts inserted in each batch, defaults to 1000.
+ **kwargs: Other optional parameters.
+
+ Returns:
+ A list of strings, containing the unique identifiers of the inserted texts.
+
+ Note:
+ If the collection has not yet been created,
+ this method will create a new collection.
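+
+        Example (illustrative sketch; assumes ``vector_store`` is an existing
+        ``Hippo`` instance):
+            .. code-block:: python
+
+                vector_store.add_texts(
+                    ["Hippo stores vectors"],
+                    metadatas=[{"source": "notes"}],
+                )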
+ """
+ from transwarp_hippo_api.hippo_client import HippoTable
+
+ if not texts or all(t == "" for t in texts):
+ logger.debug("Nothing to insert, skipping.")
+ return []
+ texts = list(texts)
+
+ logger.debug(f"[add_texts] texts: {texts}")
+
+ try:
+ embeddings = self.embedding_func.embed_documents(texts)
+ except NotImplementedError:
+ embeddings = [self.embedding_func.embed_query(x) for x in texts]
+
+ if len(embeddings) == 0:
+ logger.debug("Nothing to insert, skipping.")
+ return []
+
+ logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")
+
+        # Create the collection if it has not been created yet.
+ if not isinstance(self.col, HippoTable):
+ self._get_env(embeddings, metadatas)
+
+ # Dict to hold all insert columns
+ insert_dict: Dict[str, list] = {
+ self._text_field: texts,
+ self._vector_field: embeddings,
+ }
+ logger.debug(f"[add_texts] metadatas:{metadatas}")
+ logger.debug(f"[add_texts] fields:{self.fields}")
+ if metadatas is not None:
+ for d in metadatas:
+ for key, value in d.items():
+ if key in self.fields:
+ insert_dict.setdefault(key, []).append(value)
+
+ logger.debug(insert_dict[self._text_field])
+
+ # Total insert count
+ vectors: list = insert_dict[self._vector_field]
+ total_count = len(vectors)
+
+ if "pk" in self.fields:
+ self.fields.remove("pk")
+
+ logger.debug(f"[add_texts] total_count:{total_count}")
+ for i in range(0, total_count, batch_size):
+ # Grab end index
+ end = min(i + batch_size, total_count)
+ # Convert dict to list of lists batch for insertion
+ insert_list = [insert_dict[x][i:end] for x in self.fields]
+ try:
+ res = self.col.insert_rows(insert_list)
+ logger.info(f"05 [add_texts] insert {res}")
+ except Exception as e:
+ logger.error(
+ "Failed to insert batch starting at entity: %s/%s", i, total_count
+ )
+ raise e
+ return [""]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Perform a similarity search on the query string.
+
+ Args:
+ query (str): The text to search for.
+ k (int, optional): The number of results to return. Default is 4.
+ param (dict, optional): Specifies the search parameters for the index.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): Time to wait before a timeout error.
+ Defaults to None.
+ kwargs: Keyword arguments for Collection.search().
+
+ Returns:
+ List[Document]: The document results of the search.
+ """
+
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+ res = self.similarity_search_with_score(
+ query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return [doc for doc, _ in res]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Performs a search on the query string and returns results with scores.
+
+ Args:
+ query (str): The text being searched.
+ k (int, optional): The number of results to return.
+ Default is 4.
+ param (dict): Specifies the search parameters for the index.
+ Default is None.
+ expr (str, optional): Filtering expression. Default is None.
+ timeout (int, optional): The waiting time before a timeout error.
+ Default is None.
+ kwargs: Keyword arguments for Collection.search().
+
+ Returns:
+            List[Tuple[Document, float]]: Resulting documents and scores.
+ """
+
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ # Embed the query text.
+ embedding = self.embedding_func.embed_query(query)
+
+ ret = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return ret
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Performs a search on the query string and returns results with scores.
+
+ Args:
+ embedding (List[float]): The embedding vector being searched.
+ k (int, optional): The number of results to return.
+ Default is 4.
+ param (dict): Specifies the search parameters for the index.
+ Default is None.
+ expr (str, optional): Filtering expression. Default is None.
+ timeout (int, optional): The waiting time before a timeout error.
+ Default is None.
+ kwargs: Keyword arguments for Collection.search().
+
+ Returns:
+ List[Tuple[Document, float]]: Resulting documents and scores.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ # if param is None:
+ # param = self.search_params
+
+ # Determine result metadata fields.
+ output_fields = self.fields[:]
+ output_fields.remove(self._vector_field)
+
+ # Perform the search.
+ logger.debug(f"search_field:{self._vector_field}")
+ logger.debug(f"vectors:{[embedding]}")
+ logger.debug(f"output_fields:{output_fields}")
+ logger.debug(f"topk:{k}")
+ logger.debug(f"dsl:{expr}")
+
+ res = self.col.query(
+ search_field=self._vector_field,
+ vectors=[embedding],
+ output_fields=output_fields,
+ topk=k,
+ dsl=expr,
+ )
+ # Organize results.
+ logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
+ score_col = self._text_field + "%scores"
+ ret = []
+ count = 0
+ for items in zip(*[res[0][field] for field in output_fields]):
+ meta = {field: value for field, value in zip(output_fields, items)}
+ doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+ logger.debug(
+ f"[similarity_search_with_score_by_vector] "
+ f"res[0][score_col]:{res[0][score_col]}"
+ )
+ score = res[0][score_col][count]
+ count += 1
+ ret.append((doc, score))
+
+ return ret
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ table_name: str = "test",
+ database_name: str = "default",
+ connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
+ index_params: Optional[Dict[Any, Any]] = None,
+ search_params: Optional[Dict[str, Any]] = None,
+ drop_old: bool = False,
+ **kwargs: Any,
+ ) -> "Hippo":
+ """
+ Creates an instance of the VST class from the given texts.
+
+ Args:
+ texts (List[str]): List of texts to be added.
+ embedding (Embeddings): Embedding model for the texts.
+ metadatas (List[dict], optional):
+                List of metadata dictionaries for each text. Defaults to None.
+ table_name (str): Name of the table. Defaults to "test".
+ database_name (str): Name of the database. Defaults to "default".
+ connection_args (dict[str, Any]): Connection parameters.
+ Defaults to DEFAULT_HIPPO_CONNECTION.
+ index_params (dict): Indexing parameters. Defaults to None.
+ search_params (dict): Search parameters. Defaults to an empty dictionary.
+ drop_old (bool): Whether to drop the old collection. Defaults to False.
+ kwargs: Other arguments.
+
+ Returns:
+ Hippo: An instance of the VST class.
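+
+        Example (illustrative sketch; assumes ``embeddings`` is an ``Embeddings``
+        instance and Hippo is reachable with the default connection args):
+            .. code-block:: python
+
+                vector_store = Hippo.from_texts(
+                    ["Hippo stores vectors"],
+                    embedding=embeddings,
+                    table_name="langchain_test",
+                )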
+ """
+
+ if search_params is None:
+ search_params = {}
+ logger.info("00 [from_texts] init the class of Hippo")
+ vector_db = cls(
+ embedding_function=embedding,
+ table_name=table_name,
+ database_name=database_name,
+ connection_args=connection_args,
+ index_params=index_params,
+ drop_old=drop_old,
+ **kwargs,
+ )
+ logger.debug(f"[from_texts] texts:{texts}")
+ logger.debug(f"[from_texts] metadatas:{metadatas}")
+ vector_db.add_texts(texts=texts, metadatas=metadatas)
+ return vector_db
diff --git a/libs/community/langchain_community/vectorstores/hologres.py b/libs/community/langchain_community/vectorstores/hologres.py
new file mode 100644
index 00000000000..b2572f40c3b
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/hologres.py
@@ -0,0 +1,421 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+ADA_TOKEN_COUNT = 1536
+_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_pg_embedding"
+
+
+class Hologres(VectorStore):
+ """`Hologres API` vector store.
+
+ - `connection_string` is a hologres connection string.
+ - `embedding_function` any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ - `ndims` is the number of dimensions of the embedding output.
+ - `table_name` is the name of the table to store embeddings and data.
+ (default: langchain_pg_embedding)
+    - NOTE: The table will be created when initializing the store (if it does not
+      exist), so make sure the user has the right permissions to create tables.
+ - `pre_delete_table` if True, will delete the table if it exists.
+ (default: False)
+ - Useful for testing.
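+
+    Example (illustrative sketch; assumes ``texts`` is a list of strings and the
+    HOLOGRES_CONNECTION_STRING environment variable is set):
+        .. code-block:: python
+
+            from langchain_community.vectorstores import Hologres
+            from langchain_community.embeddings import OpenAIEmbeddings
+
+            embeddings = OpenAIEmbeddings()
+            vector_store = Hologres.from_texts(texts, embeddings)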
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ pre_delete_table: bool = False,
+ logger: Optional[logging.Logger] = None,
+ ) -> None:
+ self.connection_string = connection_string
+ self.ndims = ndims
+ self.table_name = table_name
+ self.embedding_function = embedding_function
+ self.pre_delete_table = pre_delete_table
+ self.logger = logger or logging.getLogger(__name__)
+ self.__post_init__()
+
+ def __post_init__(
+ self,
+ ) -> None:
+ """
+ Initialize the store.
+ """
+ from hologres_vector import HologresVector
+
+ self.storage = HologresVector(
+ self.connection_string,
+ ndims=self.ndims,
+ table_name=self.table_name,
+ table_schema={"document": "text"},
+ pre_delete_table=self.pre_delete_table,
+ )
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding_function: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ pre_delete_table: bool = False,
+ **kwargs: Any,
+ ) -> Hologres:
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ embedding_function=embedding_function,
+ ndims=ndims,
+ table_name=table_name,
+ pre_delete_table=pre_delete_table,
+ )
+
+ store.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ def add_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: List[dict],
+ ids: List[str],
+ **kwargs: Any,
+ ) -> None:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+            metadatas: List of metadatas associated with the texts.
+            ids: List of unique ids for the texts.
+            kwargs: vectorstore specific parameters
+ """
+ try:
+ schema_datas = [{"document": t} for t in texts]
+ self.storage.upsert_vectors(embeddings, ids, metadatas, schema_datas)
+ except Exception as e:
+ self.logger.exception(e)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ embeddings = self.embedding_function.embed_documents(list(texts))
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ self.add_embeddings(texts, embeddings, metadatas, ids, **kwargs)
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with Hologres with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query.
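+
+        Example (illustrative sketch; assumes ``vector_store`` is an existing
+        ``Hologres`` instance and the stored metadata contains a ``source`` key):
+            .. code-block:: python
+
+                docs = vector_store.similarity_search(
+                    "hologres vector search", k=2, filter={"source": "notes"}
+                )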
+ """
+ embedding = self.embedding_function.embed_query(text=query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ results: List[dict[str, Any]] = self.storage.search(
+ embedding, k=k, select_columns=["document"], metadata_filters=filter
+ )
+
+ docs = [
+ (
+ Document(
+ page_content=result["document"],
+ metadata=result["metadata"],
+ ),
+ result["distance"],
+ )
+ for result in results
+ ]
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls: Type[Hologres],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_table: bool = False,
+ **kwargs: Any,
+ ) -> Hologres:
+ """
+ Return VectorStore initialized from texts and embeddings.
+ A Hologres connection string is required. Either pass it as a parameter
+ or set the HOLOGRES_CONNECTION_STRING environment variable.
+ Create the connection string by calling
+ HologresVector.connection_string_from_db_params.
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ ndims=ndims,
+ table_name=table_name,
+ pre_delete_table=pre_delete_table,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_table: bool = False,
+ **kwargs: Any,
+ ) -> Hologres:
+ """Construct Hologres wrapper from raw documents and pre-
+ generated embeddings.
+
+ Return VectorStore initialized from documents and embeddings.
+ Hologres connection string is required
+ "Either pass it as a parameter
+ or set the HOLOGRES_CONNECTION_STRING environment variable.
+ Create the connection string by calling
+ HologresVector.connection_string_from_db_params
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Hologres
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ hologres = Hologres.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ ndims=ndims,
+ table_name=table_name,
+ pre_delete_table=pre_delete_table,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[Hologres],
+ embedding: Embeddings,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ pre_delete_table: bool = False,
+ **kwargs: Any,
+ ) -> Hologres:
+ """
+ Get an instance of an existing Hologres store. This method will
+ return the instance of the store without inserting any new
+ embeddings.
+ """
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ ndims=ndims,
+ table_name=table_name,
+ embedding_function=embedding,
+ pre_delete_table=pre_delete_table,
+ )
+
+ return store
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="HOLOGRES_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+ "Hologres connection string is required"
+ "Either pass it as a parameter"
+ "or set the HOLOGRES_CONNECTION_STRING environment variable."
+ "Create the connection string by calling"
+ "HologresVector.connection_string_from_db_params"
+ )
+
+ return connection_string
+
+ @classmethod
+ def from_documents(
+ cls: Type[Hologres],
+ documents: List[Document],
+ embedding: Embeddings,
+ ndims: int = ADA_TOKEN_COUNT,
+ table_name: str = _LANGCHAIN_DEFAULT_TABLE_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> Hologres:
+ """
+ Return VectorStore initialized from documents and embeddings.
+ A Hologres connection string is required. Either pass it as a parameter
+ or set the HOLOGRES_CONNECTION_STRING environment variable.
+ Create the connection string by calling
+ HologresVector.connection_string_from_db_params.
+ """
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ connection_string = cls.get_connection_string(kwargs)
+
+ kwargs["connection_string"] = connection_string
+
+ return cls.from_texts(
+ texts=texts,
+ pre_delete_collection=pre_delete_collection,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ ndims=ndims,
+ table_name=table_name,
+ **kwargs,
+ )
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return (
+ f"dbname={database} user={user} password={password} host={host} port={port}"
+ )
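+
+
+# Hedged usage sketch (illustrative only, not part of the library): shows one
+# way to wire the helpers above together. The host/database/user values are
+# placeholders and OpenAIEmbeddings is just one possible embedding model;
+# running this needs the hologres-vector package and a reachable instance.
+if __name__ == "__main__":
+    from langchain_community.embeddings import OpenAIEmbeddings
+
+    connection_string = Hologres.connection_string_from_db_params(
+        host="hologres.example.com",
+        port=80,
+        database="langchain",
+        user="user",
+        password="password",
+    )
+    store = Hologres.from_texts(
+        texts=["hello hologres"],
+        embedding=OpenAIEmbeddings(),
+        connection_string=connection_string,
+    )
+    print(store.similarity_search("hello", k=1))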
diff --git a/libs/community/langchain_community/vectorstores/lancedb.py b/libs/community/langchain_community/vectorstores/lancedb.py
new file mode 100644
index 00000000000..4ca68c92ca6
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/lancedb.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import uuid
+from typing import Any, Iterable, List, Optional
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+
+class LanceDB(VectorStore):
+ """`LanceDB` vector store.
+
+ To use, you should have ``lancedb`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ import lancedb
+
+ db = lancedb.connect('./lancedb')
+ table = db.open_table('my_table')
+ vectorstore = LanceDB(table, embedding_function)
+ vectorstore.add_texts(['text1', 'text2'])
+ result = vectorstore.similarity_search('text1')
+ """
+
+ def __init__(
+ self,
+ connection: Any,
+ embedding: Embeddings,
+ vector_key: Optional[str] = "vector",
+ id_key: Optional[str] = "id",
+ text_key: Optional[str] = "text",
+ ):
+ """Initialize with Lance DB connection"""
+ try:
+ import lancedb
+ except ImportError:
+ raise ImportError(
+ "Could not import lancedb python package. "
+ "Please install it with `pip install lancedb`."
+ )
+ if not isinstance(connection, lancedb.db.LanceTable):
+ raise ValueError(
+ "connection should be an instance of lancedb.db.LanceTable, ",
+ f"got {type(connection)}",
+ )
+ self._connection = connection
+ self._embedding = embedding
+ self._vector_key = vector_key
+ self._id_key = id_key
+ self._text_key = text_key
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Turn texts into embedding and add it to the database
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+
+ Returns:
+ List of ids of the added texts.
+ """
+ # Embed texts and create documents
+ docs = []
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ embeddings = self._embedding.embed_documents(list(texts))
+ for idx, text in enumerate(texts):
+ embedding = embeddings[idx]
+ metadata = metadatas[idx] if metadatas else {}
+ docs.append(
+ {
+ self._vector_key: embedding,
+ self._id_key: ids[idx],
+ self._text_key: text,
+ **metadata,
+ }
+ )
+
+ self._connection.add(docs)
+ return ids
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return documents most similar to the query
+
+ Args:
+ query: String to query the vectorstore with.
+ k: Number of documents to return.
+
+ Returns:
+ List of documents most similar to the query.
+ """
+ embedding = self._embedding.embed_query(query)
+ docs = self._connection.search(embedding).limit(k).to_df()
+ return [
+ Document(
+ page_content=row[self._text_key],
+ metadata=row[docs.columns != self._text_key],
+ )
+ for _, row in docs.iterrows()
+ ]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ connection: Any = None,
+ vector_key: Optional[str] = "vector",
+ id_key: Optional[str] = "id",
+ text_key: Optional[str] = "text",
+ **kwargs: Any,
+ ) -> LanceDB:
+ instance = LanceDB(
+ connection,
+ embedding,
+ vector_key,
+ id_key,
+ text_key,
+ )
+ instance.add_texts(texts, metadatas=metadatas, **kwargs)
+
+ return instance
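+
+
+# Hedged usage sketch (illustrative only, not part of the library): assumes a
+# local LanceDB database and an embedding model. The seed row is included
+# because LanceDB infers the table schema from the first batch of data; all
+# names and values here are placeholders.
+if __name__ == "__main__":
+    import lancedb
+
+    from langchain_community.embeddings import OpenAIEmbeddings
+
+    embeddings = OpenAIEmbeddings()
+    db = lancedb.connect("./lancedb")
+    table = db.create_table(
+        "my_table",
+        data=[
+            {
+                "vector": embeddings.embed_query("seed"),
+                "id": "0",
+                "text": "seed",
+            }
+        ],
+    )
+    store = LanceDB.from_texts(["hello lancedb"], embeddings, connection=table)
+    print(store.similarity_search("hello", k=1))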
diff --git a/libs/community/langchain_community/vectorstores/llm_rails.py b/libs/community/langchain_community/vectorstores/llm_rails.py
new file mode 100644
index 00000000000..a46d74a32e0
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/llm_rails.py
@@ -0,0 +1,244 @@
+"""Wrapper around LLMRails vector database."""
+from __future__ import annotations
+
+import json
+import logging
+import os
+import uuid
+from typing import Any, Iterable, List, Optional, Tuple
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import Field
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+
+
+class LLMRails(VectorStore):
+ """Implementation of Vector Store using LLMRails.
+
+ See https://llmrails.com/
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import LLMRails
+
+ vectorstore = LLMRails(
+ api_key=llm_rails_api_key,
+ datastore_id=datastore_id
+ )
+ """
+
+ def __init__(
+ self,
+ datastore_id: Optional[str] = None,
+ api_key: Optional[str] = None,
+ ):
+ """Initialize with LLMRails API."""
+ self._datastore_id = datastore_id or os.environ.get("LLM_RAILS_DATASTORE_ID")
+ self._api_key = api_key or os.environ.get("LLM_RAILS_API_KEY")
+ if self._api_key is None:
+ logging.warning("Can't find Rails credentials in environment.")
+
+ self._session = requests.Session() # to reuse connections
+ self.datastore_id = datastore_id
+ self.base_url = "https://api.llmrails.com/v1"
+
+ def _get_post_headers(self) -> dict:
+ """Returns headers that should be attached to each post request."""
+ return {"X-API-KEY": self._api_key}
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
+ """
+ names: List[str] = []
+ for text in texts:
+ doc_name = str(uuid.uuid4())
+ response = self._session.post(
+ f"{self.base_url}/datastores/{self._datastore_id}/text",
+ json={"name": doc_name, "text": text},
+ verify=True,
+ headers=self._get_post_headers(),
+ )
+
+ if response.status_code != 200:
+ logging.error(
+ f"Create request failed for doc_name = {doc_name} with status code "
+ f"{response.status_code}, reason {response.reason}, text "
+ f"{response.text}"
+ )
+
+ return names
+
+ names.append(doc_name)
+
+ return names
+
+ def add_files(
+ self,
+ files_list: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> bool:
+ """
+ LLMRails provides a way to add documents directly via its API, where
+ pre-processing and chunking occur internally in an optimal way.
+ This method provides a way to use that API in LangChain.
+
+ Args:
+ files_list: Iterable of strings, each representing a local file path.
+ Files could be text, HTML, PDF, markdown, doc/docx, ppt/pptx, etc.
+ see API docs for full list
+
+ Returns:
+ True if all files were uploaded successfully, False otherwise.
+ """
+ files = []
+
+ for file in files_list:
+ if not os.path.exists(file):
+ logging.error(f"File {file} does not exist, skipping")
+ continue
+
+ files.append(("file", (os.path.basename(file), open(file, "rb"))))
+
+ response = self._session.post(
+ f"{self.base_url}/datastores/{self._datastore_id}/file",
+ files=files,
+ verify=True,
+ headers=self._get_post_headers(),
+ )
+
+ if response.status_code != 200:
+ logging.error(
+ f"Create request failed for datastore = {self._datastore_id} "
+ f"with status code {response.status_code}, reason {response.reason}, "
+ f"text {response.text}"
+ )
+
+ return False
+
+ return True
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 5
+ ) -> List[Tuple[Document, float]]:
+ """Return LLMRails documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 5. Maximum 10.
+ alpha: parameter for hybrid search.
+
+ Returns:
+ List of Documents most similar to the query and score for each.
+ """
+ response = self._session.post(
+ headers=self._get_post_headers(),
+ url=f"{self.base_url}/datastores/{self._datastore_id}/search",
+ data=json.dumps({"k": k, "text": query}),
+ timeout=10,
+ )
+
+ if response.status_code != 200:
+ logging.error(
+ "Query failed %s",
+ f"(code {response.status_code}, reason {response.reason}, details "
+ f"{response.text})",
+ )
+ return []
+
+ results = response.json()["results"]
+ docs = [
+ (
+ Document(
+ page_content=x["text"],
+ metadata={
+ key: value
+ for key, value in x["metadata"].items()
+ if key != "score"
+ },
+ ),
+ x["metadata"]["score"],
+ )
+ for x in results
+ ]
+
+ return docs
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return LLMRails documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k=k)
+
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> LLMRails:
+ """Construct LLMRails wrapper from raw documents.
+ This is intended to be a quick way to get started.
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import LLMRails
+ llm_rails = LLMRails.from_texts(
+ texts,
+ datastore_id=datastore_id,
+ api_key=llm_rails_api_key
+ )
+ """
+ # Note: LLMRails generates its own embeddings, so we ignore the provided
+ # embeddings (required by interface)
+ llm_rails = cls(**kwargs)
+ llm_rails.add_texts(texts)
+ return llm_rails
+
+ def as_retriever(self, **kwargs: Any) -> LLMRailsRetriever:
+ return LLMRailsRetriever(vectorstore=self, **kwargs)
+
+
+class LLMRailsRetriever(VectorStoreRetriever):
+ """Retriever for LLMRails."""
+
+ vectorstore: LLMRails
+ search_kwargs: dict = Field(default_factory=lambda: {"k": 5})
+ """Search params.
+ k: Number of Documents to return. Defaults to 5.
+ alpha: parameter for hybrid search.
+ """
+
+ def add_texts(self, texts: List[str]) -> None:
+ """Add text to the datastore.
+
+ Args:
+ texts (List[str]): The text
+ """
+ self.vectorstore.add_texts(texts)
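+
+
+# Hedged usage sketch (illustrative only, not part of the library): assumes
+# LLM_RAILS_API_KEY and LLM_RAILS_DATASTORE_ID are set in the environment
+# (or pass api_key/datastore_id explicitly). LLMRails embeds and chunks text
+# server side, so no Embeddings object is needed.
+if __name__ == "__main__":
+    store = LLMRails()
+    store.add_texts(["LLMRails handles chunking and embedding server side."])
+    for doc, score in store.similarity_search_with_score("chunking", k=2):
+        print(score, doc.page_content)
+    retriever = store.as_retriever(search_kwargs={"k": 2})
+    print(retriever.get_relevant_documents("chunking"))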
diff --git a/libs/community/langchain_community/vectorstores/marqo.py b/libs/community/langchain_community/vectorstores/marqo.py
new file mode 100644
index 00000000000..d9777533ada
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/marqo.py
@@ -0,0 +1,470 @@
+from __future__ import annotations
+
+import json
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ import marqo
+
+
+class Marqo(VectorStore):
+ """`Marqo` vector store.
+
+ Marqo indexes have their own models associated with them to generate your
+ embeddings. This means that you can select from a range of different models
+ and also use CLIP models to create multimodal indexes
+ with images and text together.
+
+ Marqo also supports more advanced queries with multiple weighted terms, see
+ https://docs.marqo.ai/latest/#searching-using-weights-in-queries.
+ This class can flexibly take strings or dictionaries for weighted queries
+ in its similarity search methods.
+
+ To use, you should have the `marqo` python package installed, you can do this with
+ `pip install marqo`.
+
+ Example:
+ .. code-block:: python
+
+ import marqo
+ from langchain_community.vectorstores import Marqo
+ client = marqo.Client(url=os.environ["MARQO_URL"], ...)
+ vectorstore = Marqo(client, index_name)
+
+ """
+
+ def __init__(
+ self,
+ client: marqo.Client,
+ index_name: str,
+ add_documents_settings: Optional[Dict[str, Any]] = None,
+ searchable_attributes: Optional[List[str]] = None,
+ page_content_builder: Optional[Callable[[Dict[str, Any]], str]] = None,
+ ):
+ """Initialize with Marqo client."""
+ try:
+ import marqo
+ except ImportError:
+ raise ImportError(
+ "Could not import marqo python package. "
+ "Please install it with `pip install marqo`."
+ )
+ if not isinstance(client, marqo.Client):
+ raise ValueError(
+ f"client should be an instance of marqo.Client, got {type(client)}"
+ )
+ self._client = client
+ self._index_name = index_name
+ self._add_documents_settings = (
+ {} if add_documents_settings is None else add_documents_settings
+ )
+ self._searchable_attributes = searchable_attributes
+ self.page_content_builder = page_content_builder
+
+ self.tensor_fields = ["text"]
+
+ self._document_batch_size = 1024
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return None
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Upload texts with metadata (properties) to Marqo.
+
+ You can either have marqo generate ids for each document or you can provide
+ your own by including a "_id" field in the metadata objects.
+
+ Args:
+ texts (Iterable[str]): an iterable of texts - assumed to preserve an
+ order that matches the metadatas.
+ metadatas (Optional[List[dict]], optional): a list of metadatas.
+
+ Raises:
+ ValueError: if metadatas is provided and the number of metadatas differs
+ from the number of texts.
+
+ Returns:
+ List[str]: The list of ids that were added.
+ """
+
+ if self._client.index(self._index_name).get_settings()["index_defaults"][
+ "treat_urls_and_pointers_as_images"
+ ]:
+ raise ValueError(
+ "Marqo.add_texts is disabled for multimodal indexes. To add documents "
+ "with a multimodal index use the Python client for Marqo directly."
+ )
+ documents: List[Dict[str, str]] = []
+
+ num_docs = 0
+ for i, text in enumerate(texts):
+ doc = {
+ "text": text,
+ "metadata": json.dumps(metadatas[i]) if metadatas else json.dumps({}),
+ }
+ documents.append(doc)
+ num_docs += 1
+
+ ids = []
+ for i in range(0, num_docs, self._document_batch_size):
+ response = self._client.index(self._index_name).add_documents(
+ documents[i : i + self._document_batch_size],
+ tensor_fields=self.tensor_fields,
+ **self._add_documents_settings,
+ )
+ if response["errors"]:
+ err_msg = (
+ f"Error in upload for documents in index range [{i},"
+ f"{i + self._document_batch_size}], "
+ f"check Marqo logs."
+ )
+ raise RuntimeError(err_msg)
+
+ ids += [item["_id"] for item in response["items"]]
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: Union[str, Dict[str, float]],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Search the marqo index for the most similar documents.
+
+ Args:
+ query (Union[str, Dict[str, float]]): The query for the search, either
+ as a string or a weighted query.
+ k (int, optional): The number of documents to return. Defaults to 4.
+
+ Returns:
+ List[Document]: k documents ordered from best to worst match.
+ """
+ results = self.marqo_similarity_search(query=query, k=k)
+
+ documents = self._construct_documents_from_results_without_score(results)
+ return documents
+
+ def similarity_search_with_score(
+ self,
+ query: Union[str, Dict[str, float]],
+ k: int = 4,
+ ) -> List[Tuple[Document, float]]:
+ """Return documents from Marqo that are similar to the query as well
+ as their scores.
+
+ Args:
+ query (str): The query to search with, either as a string or a weighted
+ query.
+ k (int, optional): The number of documents to return. Defaults to 4.
+
+ Returns:
+ List[Tuple[Document, float]]: The matching documents and their scores,
+ ordered by descending score.
+ """
+ results = self.marqo_similarity_search(query=query, k=k)
+
+ scored_documents = self._construct_documents_from_results_with_score(results)
+ return scored_documents
+
+ def bulk_similarity_search(
+ self,
+ queries: Iterable[Union[str, Dict[str, float]]],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[List[Document]]:
+ """Search the marqo index for the most similar documents in bulk with multiple
+ queries.
+
+ Args:
+ queries (Iterable[Union[str, Dict[str, float]]]): An iterable of queries to
+ execute in bulk, queries in the list can be strings or dictionaries of
+ weighted queries.
+ k (int, optional): The number of documents to return for each query.
+ Defaults to 4.
+
+ Returns:
+ List[List[Document]]: A list of results for each query.
+ """
+ bulk_results = self.marqo_bulk_similarity_search(queries=queries, k=k)
+ bulk_documents: List[List[Document]] = []
+ for results in bulk_results["result"]:
+ documents = self._construct_documents_from_results_without_score(results)
+ bulk_documents.append(documents)
+
+ return bulk_documents
+
+ def bulk_similarity_search_with_score(
+ self,
+ queries: Iterable[Union[str, Dict[str, float]]],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[List[Tuple[Document, float]]]:
+ """Return documents from Marqo that are similar to the query as well as
+ their scores using a batch of queries.
+
+ Args:
+ queries (Iterable[Union[str, Dict[str, float]]]): An iterable of queries
+ to execute in bulk, queries in the list can be strings or dictionaries
+ of weighted queries.
+ k (int, optional): The number of documents to return. Defaults to 4.
+
+ Returns:
+ List[List[Tuple[Document, float]]]: A list of lists of the matching
+ documents and their scores for each query.
+ """
+ bulk_results = self.marqo_bulk_similarity_search(queries=queries, k=k)
+ bulk_documents: List[List[Tuple[Document, float]]] = []
+ for results in bulk_results["result"]:
+ documents = self._construct_documents_from_results_with_score(results)
+ bulk_documents.append(documents)
+
+ return bulk_documents
+
+ def _construct_documents_from_results_with_score(
+ self, results: Dict[str, List[Dict[str, str]]]
+ ) -> List[Tuple[Document, Any]]:
+ """Helper to convert Marqo results into documents.
+
+ Args:
+ results (List[dict]): A marqo results object with the 'hits'.
+ include_scores (bool, optional): Include scores alongside documents.
+ Defaults to False.
+
+ Returns:
+ Union[List[Document], List[Tuple[Document, float]]]: The documents or
+ document score pairs if `include_scores` is true.
+ """
+ documents: List[Tuple[Document, Any]] = []
+ for res in results["hits"]:
+ if self.page_content_builder is None:
+ text = res["text"]
+ else:
+ text = self.page_content_builder(res)
+
+ metadata = json.loads(res.get("metadata", "{}"))
+ documents.append(
+ (Document(page_content=text, metadata=metadata), res["_score"])
+ )
+ return documents
+
+ def _construct_documents_from_results_without_score(
+ self, results: Dict[str, List[Dict[str, str]]]
+ ) -> List[Document]:
+ """Helper to convert Marqo results into documents.
+
+ Args:
+ results (List[dict]): A marqo results object with the 'hits'.
+ include_scores (bool, optional): Include scores alongside documents.
+ Defaults to False.
+
+ Returns:
+ Union[List[Document], List[Tuple[Document, float]]]: The documents or
+ document score pairs if `include_scores` is true.
+ """
+ documents: List[Document] = []
+ for res in results["hits"]:
+ if self.page_content_builder is None:
+ text = res["text"]
+ else:
+ text = self.page_content_builder(res)
+
+ metadata = json.loads(res.get("metadata", "{}"))
+ documents.append(Document(page_content=text, metadata=metadata))
+ return documents
+
+ def marqo_similarity_search(
+ self,
+ query: Union[str, Dict[str, float]],
+ k: int = 4,
+ ) -> Dict[str, List[Dict[str, str]]]:
+ """Return documents from Marqo exposing Marqo's output directly
+
+ Args:
+ query (str): The query to search with.
+ k (int, optional): The number of documents to return. Defaults to 4.
+
+ Returns:
+ Dict[str, List[Dict[str, str]]]: The hits from Marqo.
+ """
+ results = self._client.index(self._index_name).search(
+ q=query, searchable_attributes=self._searchable_attributes, limit=k
+ )
+ return results
+
+ def marqo_bulk_similarity_search(
+ self, queries: Iterable[Union[str, Dict[str, float]]], k: int = 4
+ ) -> Dict[str, List[Dict[str, List[Dict[str, str]]]]]:
+ """Return documents from Marqo using a bulk search, exposes Marqo's
+ output directly
+
+ Args:
+ queries (Iterable[Union[str, Dict[str, float]]]): A list of queries.
+ k (int, optional): The number of documents to return for each query.
+ Defaults to 4.
+
+ Returns:
+ Dict[str, List[Dict[str, List[Dict[str, str]]]]]: A bulk search results
+ object.
+ """
+ bulk_results = {
+ "result": [
+ self._client.index(self._index_name).search(
+ q=query, searchable_attributes=self._searchable_attributes, limit=k
+ )
+ for query in queries
+ ]
+ }
+
+ return bulk_results
+
+ @classmethod
+ def from_documents(
+ cls: Type[Marqo],
+ documents: List[Document],
+ embedding: Union[Embeddings, None] = None,
+ **kwargs: Any,
+ ) -> Marqo:
+ """Return VectorStore initialized from documents. Note that Marqo does not
+ need embeddings, we retain the parameter to adhere to the Liskov substitution
+ principle.
+
+
+ Args:
+ documents (List[Document]): Input documents
+ embedding (Any, optional): Embeddings (not required). Defaults to None.
+
+ Returns:
+ VectorStore: A Marqo vectorstore
+ """
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ return cls.from_texts(texts, metadatas=metadatas, **kwargs)
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Any = None,
+ metadatas: Optional[List[dict]] = None,
+ index_name: str = "",
+ url: str = "http://localhost:8882",
+ api_key: str = "",
+ add_documents_settings: Optional[Dict[str, Any]] = None,
+ searchable_attributes: Optional[List[str]] = None,
+ page_content_builder: Optional[Callable[[Dict[str, str]], str]] = None,
+ index_settings: Optional[Dict[str, Any]] = None,
+ verbose: bool = True,
+ **kwargs: Any,
+ ) -> Marqo:
+ """Return Marqo initialized from texts. Note that Marqo does not need
+ embeddings, we retain the parameter to adhere to the Liskov
+ substitution principle.
+
+ This is a quick way to get started with marqo - simply provide your texts and
+ metadatas and this will create an instance of the data store and index the
+ provided data.
+
+ To know the ids of your documents with this approach you will need to include
+ them under the key "_id" in your metadatas for each text.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Marqo
+
+ datastore = Marqo.from_texts(['text'], index_name='my-first-index',
+ url='http://localhost:8882')
+
+ Args:
+ texts (List[str]): A list of texts to index into marqo upon creation.
+ embedding (Any, optional): Embeddings (not required). Defaults to None.
+ index_name (str, optional): The name of the index to use, if none is
+ provided then one will be created with a UUID. Defaults to "".
+ url (str, optional): The URL for Marqo. Defaults to "http://localhost:8882".
+ api_key (str, optional): The API key for Marqo. Defaults to "".
+ metadatas (Optional[List[dict]], optional): A list of metadatas, to
+ accompany the texts. Defaults to None.
+ add_documents_settings (Optional[Dict[str, Any]], optional): Settings
+ for adding documents, see
+ https://docs.marqo.ai/0.0.16/API-Reference/documents/#query-parameters.
+ Defaults to {}.
+ index_settings (Optional[Dict[str, Any]], optional): Index settings if
+ the index doesn't exist, see
+ https://docs.marqo.ai/0.0.16/API-Reference/indexes/#index-defaults-object.
+ Defaults to {}.
+
+ Returns:
+ Marqo: An instance of the Marqo vector store
+ """
+ try:
+ import marqo
+ except ImportError:
+ raise ImportError(
+ "Could not import marqo python package. "
+ "Please install it with `pip install marqo`."
+ )
+
+ if not index_name:
+ index_name = str(uuid.uuid4())
+
+ client = marqo.Client(url=url, api_key=api_key)
+
+ try:
+ client.create_index(index_name, settings_dict=index_settings or {})
+ if verbose:
+ print(f"Created {index_name} successfully.")
+ except Exception:
+ if verbose:
+ print(f"Index {index_name} exists.")
+
+ instance: Marqo = cls(
+ client,
+ index_name,
+ searchable_attributes=searchable_attributes,
+ add_documents_settings=add_documents_settings or {},
+ page_content_builder=page_content_builder,
+ )
+ instance.add_texts(texts, metadatas)
+ return instance
+
+ def get_indexes(self) -> List[Dict[str, str]]:
+ """Helper to see your available indexes in marqo, useful if the
+ from_texts method was used without an index name specified
+
+ Returns:
+ List[Dict[str, str]]: The list of indexes
+ """
+ return self._client.get_indexes()["results"]
+
+ def get_number_of_documents(self) -> int:
+ """Helper to see the number of documents in the index
+
+ Returns:
+ int: The number of documents
+ """
+ return self._client.index(self._index_name).get_stats()["numberOfDocuments"]
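+
+
+# Hedged usage sketch (illustrative only, not part of the library): assumes a
+# Marqo server reachable at localhost:8882. Marqo generates its own
+# embeddings, so no Embeddings object is passed; the index name is a
+# placeholder. The dict query shows the weighted form accepted by the
+# similarity search methods above.
+if __name__ == "__main__":
+    store = Marqo.from_texts(
+        ["a photo of a dog", "a photo of a cat"],
+        index_name="langchain-marqo-demo",
+        url="http://localhost:8882",
+    )
+    # Plain string query
+    print(store.similarity_search("dog", k=1))
+    # Weighted query: favour "dog", penalise "cat"
+    print(store.similarity_search({"dog": 1.0, "cat": -0.5}, k=1))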
diff --git a/libs/community/langchain_community/vectorstores/matching_engine.py b/libs/community/langchain_community/vectorstores/matching_engine.py
new file mode 100644
index 00000000000..fce60648255
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/matching_engine.py
@@ -0,0 +1,579 @@
+from __future__ import annotations
+
+import json
+import logging
+import time
+import uuid
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.utilities.vertexai import get_client_info
+
+if TYPE_CHECKING:
+ from google.cloud import storage
+ from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
+ from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+ Namespace,
+ )
+ from google.oauth2.service_account import Credentials
+
+ from langchain_community.embeddings import TensorflowHubEmbeddings
+
+logger = logging.getLogger()
+
+
+class MatchingEngine(VectorStore):
+ """`Google Vertex AI Vector Search` (previously Matching Engine) vector store.
+
+ While the embeddings are stored in the Matching Engine, the embedded
+ documents will be stored in GCS.
+
+ An existing Index and corresponding Endpoint are preconditions for
+ using this module.
+
+ See usage in docs/integrations/vectorstores/google_vertex_ai_vector_search.ipynb
+
+ Note that this implementation is mostly meant for reading if you are
+ planning a real-time application: while reading is a real-time
+ operation, updating the index takes close to one hour."""
+
+ def __init__(
+ self,
+ project_id: str,
+ index: MatchingEngineIndex,
+ endpoint: MatchingEngineIndexEndpoint,
+ embedding: Embeddings,
+ gcs_client: storage.Client,
+ gcs_bucket_name: str,
+ credentials: Optional[Credentials] = None,
+ ):
+ """Google Vertex AI Vector Search (previously Matching Engine)
+ implementation of the vector store.
+
+ While the embeddings are stored in the Matching Engine, the embedded
+ documents will be stored in GCS.
+
+ An existing Index and corresponding Endpoint are preconditions for
+ using this module.
+
+ See usage in
+ docs/integrations/vectorstores/google_vertex_ai_vector_search.ipynb.
+
+ Note that this implementation is mostly meant for reading if you are
+ planning a real-time application: while reading is a real-time
+ operation, updating the index takes close to one hour.
+
+ Attributes:
+ project_id: The GCP project id.
+ index: The created index class. See
+ ~:func:`MatchingEngine.from_components`.
+ endpoint: The created endpoint class. See
+ ~:func:`MatchingEngine.from_components`.
+ embedding: A :class:`Embeddings` that will be used for
+ embedding the text sent. If none is sent, then the
+ multilingual Tensorflow Universal Sentence Encoder will be used.
+ gcs_client: The GCS client.
+ gcs_bucket_name: The GCS bucket name.
+ credentials (Optional): Created GCP credentials.
+ """
+ super().__init__()
+ self._validate_google_libraries_installation()
+
+ self.project_id = project_id
+ self.index = index
+ self.endpoint = endpoint
+ self.embedding = embedding
+ self.gcs_client = gcs_client
+ self.credentials = credentials
+ self.gcs_bucket_name = gcs_bucket_name
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ def _validate_google_libraries_installation(self) -> None:
+ """Validates that Google libraries that are needed are installed."""
+ try:
+ from google.cloud import aiplatform, storage # noqa: F401
+ from google.oauth2 import service_account # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "You must run `pip install --upgrade "
+ "google-cloud-aiplatform google-cloud-storage`"
+ "to use the MatchingEngine Vectorstore."
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ texts = list(texts)
+ if metadatas is not None and len(texts) != len(metadatas):
+ raise ValueError(
+ "texts and metadatas do not have the same length. Received "
+ f"{len(texts)} texts and {len(metadatas)} metadatas."
+ )
+ logger.debug("Embedding documents.")
+ embeddings = self.embedding.embed_documents(texts)
+ jsons = []
+ ids = []
+ # Could be improved with async.
+ for idx, (embedding, text) in enumerate(zip(embeddings, texts)):
+ id = str(uuid.uuid4())
+ ids.append(id)
+ json_: dict = {"id": id, "embedding": embedding}
+ if metadatas is not None:
+ json_["metadata"] = metadatas[idx]
+ jsons.append(json_)
+ self._upload_to_gcs(text, f"documents/{id}")
+
+ logger.debug(f"Uploaded {len(ids)} documents to GCS.")
+
+ # Creating json lines from the embedded documents.
+ result_str = "\n".join([json.dumps(x) for x in jsons])
+
+ filename_prefix = f"indexes/{uuid.uuid4()}"
+ filename = f"{filename_prefix}/{time.time()}.json"
+ self._upload_to_gcs(result_str, filename)
+ logger.debug(
+ f"Uploaded updated json with embeddings to "
+ f"{self.gcs_bucket_name}/{filename}."
+ )
+
+ self.index = self.index.update_embeddings(
+ contents_delta_uri=f"gs://{self.gcs_bucket_name}/{filename_prefix}/"
+ )
+
+ logger.debug("Updated index with new configuration.")
+
+ return ids
+
+ def _upload_to_gcs(self, data: str, gcs_location: str) -> None:
+ """Uploads data to gcs_location.
+
+ Args:
+ data: The data that will be stored.
+ gcs_location: The location where the data will be stored.
+ """
+ bucket = self.gcs_client.get_bucket(self.gcs_bucket_name)
+ blob = bucket.blob(gcs_location)
+ blob.upload_from_string(data)
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[List[Namespace]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query and their cosine distance from the query.
+
+ Args:
+ query: String query look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Optional. A list of Namespaces for filtering
+ the matching results.
+ For example:
+ [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])]
+ will match datapoints that satisfy "red color" but not include
+ datapoints with "squared shape". Please refer to
+ https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json
+ for more detail.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to
+ the query text and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ logger.debug(f"Embedding query {query}.")
+ embedding_query = self.embedding.embed_query(query)
+ return self.similarity_search_by_vector_with_score(
+ embedding_query, k=k, filter=filter
+ )
+
+ def similarity_search_by_vector_with_score(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[List[Namespace]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to the embedding and their cosine distance.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Optional. A list of Namespaces for filtering
+ the matching results.
+ For example:
+ [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])]
+ will match datapoints that satisfy "red color" but not include
+ datapoints with "squared shape". Please refer to
+ https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json
+ for more detail.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to
+ the query text and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ filter = filter or []
+
+ # If the endpoint is public we use the find_neighbors function.
+ if hasattr(self.endpoint, "_public_match_client") and (
+ self.endpoint._public_match_client
+ ):
+ response = self.endpoint.find_neighbors(
+ deployed_index_id=self._get_index_id(),
+ queries=[embedding],
+ num_neighbors=k,
+ filter=filter,
+ )
+ else:
+ response = self.endpoint.match(
+ deployed_index_id=self._get_index_id(),
+ queries=[embedding],
+ num_neighbors=k,
+ filter=filter,
+ )
+
+ logger.debug(f"Found {len(response)} matches.")
+
+ if len(response) == 0:
+ return []
+
+ results = []
+
+ # Only the first element of the response is used because `queries`
+ # receives an array while similarity_search handles a single query, so
+ # the match method always returns an array with exactly one element.
+ for doc in response[0]:
+ page_content = self._download_from_gcs(f"documents/{doc.id}")
+ results.append((Document(page_content=page_content), doc.distance))
+
+ logger.debug("Downloaded documents for query.")
+
+ return results
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[List[Namespace]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: The string that will be used to search for similar documents.
+ k: The amount of neighbors that will be retrieved.
+ filter: Optional. A list of Namespaces for filtering the matching results.
+ For example:
+ [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])]
+ will match datapoints that satisfy "red color" but not include
+ datapoints with "squared shape". Please refer to
+ https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json
+ for more detail.
+
+ Returns:
+ A list of k matching documents.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k=k, filter=filter, **kwargs
+ )
+
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[List[Namespace]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to the embedding.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: The amount of neighbors that will be retrieved.
+ filter: Optional. A list of Namespaces for filtering the matching results.
+ For example:
+ [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])]
+ will match datapoints that satisfy "red color" but not include
+ datapoints with "squared shape". Please refer to
+ https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json
+ for more detail.
+
+ Returns:
+ A list of k matching documents.
+ """
+ docs_and_scores = self.similarity_search_by_vector_with_score(
+ embedding, k=k, filter=filter, **kwargs
+ )
+
+ return [doc for doc, _ in docs_and_scores]
+
+ def _get_index_id(self) -> str:
+ """Gets the correct index id for the endpoint.
+
+ Returns:
+ The index id if found; raises ValueError otherwise.
+ """
+ for index in self.endpoint.deployed_indexes:
+ if index.index == self.index.resource_name:
+ return index.id
+
+ raise ValueError(
+ f"No index with id {self.index.resource_name} "
+ f"deployed on endpoint "
+ f"{self.endpoint.display_name}."
+ )
+
+ def _download_from_gcs(self, gcs_location: str) -> str:
+ """Downloads from GCS in text format.
+
+ Args:
+ gcs_location: The location where the file is located.
+
+ Returns:
+ The string contents of the file.
+ """
+ bucket = self.gcs_client.get_bucket(self.gcs_bucket_name)
+ blob = bucket.blob(gcs_location)
+ return blob.download_as_string()
+
+ @classmethod
+ def from_texts(
+ cls: Type["MatchingEngine"],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> "MatchingEngine":
+ """Use from components instead."""
+ raise NotImplementedError(
+ "This method is not implemented. Instead, you should initialize the class"
+ " with `MatchingEngine.from_components(...)` and then call "
+ "`add_texts`"
+ )
+
+ @classmethod
+ def from_components(
+ cls: Type["MatchingEngine"],
+ project_id: str,
+ region: str,
+ gcs_bucket_name: str,
+ index_id: str,
+ endpoint_id: str,
+ credentials_path: Optional[str] = None,
+ embedding: Optional[Embeddings] = None,
+ ) -> "MatchingEngine":
+ """Takes the object creation out of the constructor.
+
+ Args:
+ project_id: The GCP project id.
+ region: The default location making the API calls. It must have
+ the same location as the GCS bucket and must be regional.
+ gcs_bucket_name: The location where the vectors will be stored in
+ order for the index to be created.
+ index_id: The id of the created index.
+ endpoint_id: The id of the created endpoint.
+ credentials_path: (Optional) The path of the Google credentials on
+ the local file system.
+ embedding: The :class:`Embeddings` that will be used for
+ embedding the texts.
+
+ Returns:
+ A configured MatchingEngine with the texts added to the index.
+ """
+ gcs_bucket_name = cls._validate_gcs_bucket(gcs_bucket_name)
+ credentials = cls._create_credentials_from_file(credentials_path)
+ index = cls._create_index_by_id(index_id, project_id, region, credentials)
+ endpoint = cls._create_endpoint_by_id(
+ endpoint_id, project_id, region, credentials
+ )
+
+ gcs_client = cls._get_gcs_client(credentials, project_id)
+ cls._init_aiplatform(project_id, region, gcs_bucket_name, credentials)
+
+ return cls(
+ project_id=project_id,
+ index=index,
+ endpoint=endpoint,
+ embedding=embedding or cls._get_default_embeddings(),
+ gcs_client=gcs_client,
+ credentials=credentials,
+ gcs_bucket_name=gcs_bucket_name,
+ )
+
+ @classmethod
+ def _validate_gcs_bucket(cls, gcs_bucket_name: str) -> str:
+ """Validates the gcs_bucket_name as a bucket name.
+
+ Args:
+ gcs_bucket_name: The received bucket uri.
+
+ Returns:
+ The validated bucket name; raises ValueError if a full path is
+ provided.
+ """
+ gcs_bucket_name = gcs_bucket_name.replace("gs://", "")
+ if "/" in gcs_bucket_name:
+ raise ValueError(
+ f"The argument gcs_bucket_name should only be "
+ f"the bucket name. Received {gcs_bucket_name}"
+ )
+ return gcs_bucket_name
+
+ @classmethod
+ def _create_credentials_from_file(
+ cls, json_credentials_path: Optional[str]
+ ) -> Optional[Credentials]:
+ """Creates credentials for GCP.
+
+ Args:
+ json_credentials_path: The path on the file system where the
+ credentials are stored.
+
+ Returns:
+ An optional of Credentials or None, in which case the default
+ will be used.
+ """
+
+ from google.oauth2 import service_account
+
+ credentials = None
+ if json_credentials_path is not None:
+ credentials = service_account.Credentials.from_service_account_file(
+ json_credentials_path
+ )
+
+ return credentials
+
+ @classmethod
+ def _create_index_by_id(
+ cls, index_id: str, project_id: str, region: str, credentials: "Credentials"
+ ) -> MatchingEngineIndex:
+ """Creates a MatchingEngineIndex object by id.
+
+ Args:
+ index_id: The created index id.
+ project_id: The project to retrieve index from.
+ region: Location to retrieve index from.
+ credentials: GCS credentials.
+
+ Returns:
+ A configured MatchingEngineIndex.
+ """
+
+ from google.cloud import aiplatform
+
+ logger.debug(f"Creating matching engine index with id {index_id}.")
+ return aiplatform.MatchingEngineIndex(
+ index_name=index_id,
+ project=project_id,
+ location=region,
+ credentials=credentials,
+ )
+
+ @classmethod
+ def _create_endpoint_by_id(
+ cls, endpoint_id: str, project_id: str, region: str, credentials: "Credentials"
+ ) -> MatchingEngineIndexEndpoint:
+ """Creates a MatchingEngineIndexEndpoint object by id.
+
+ Args:
+ endpoint_id: The created endpoint id.
+ project_id: The project to retrieve index from.
+ region: Location to retrieve index from.
+ credentials: GCS credentials.
+
+ Returns:
+ A configured MatchingEngineIndexEndpoint.
+ """
+
+ from google.cloud import aiplatform
+
+ logger.debug(f"Creating endpoint with id {endpoint_id}.")
+ return aiplatform.MatchingEngineIndexEndpoint(
+ index_endpoint_name=endpoint_id,
+ project=project_id,
+ location=region,
+ credentials=credentials,
+ )
+
+ @classmethod
+ def _get_gcs_client(
+ cls, credentials: "Credentials", project_id: str
+ ) -> "storage.Client":
+ """Lazily creates a GCS client.
+
+ Returns:
+ A configured GCS client.
+ """
+
+ from google.cloud import storage
+
+ return storage.Client(
+ credentials=credentials,
+ project=project_id,
+ client_info=get_client_info(module="vertex-ai-matching-engine"),
+ )
+
+ @classmethod
+ def _init_aiplatform(
+ cls,
+ project_id: str,
+ region: str,
+ gcs_bucket_name: str,
+ credentials: "Credentials",
+ ) -> None:
+ """Configures the aiplatform library.
+
+ Args:
+ project_id: The GCP project id.
+ region: The default location making the API calls. It must have
+ the same location as the GCS bucket and must be regional.
+ gcs_bucket_name: GCS staging location.
+ credentials: The GCS Credentials object.
+ """
+
+ from google.cloud import aiplatform
+
+ logger.debug(
+ f"Initializing AI Platform for project {project_id} on "
+ f"{region} and for {gcs_bucket_name}."
+ )
+ aiplatform.init(
+ project=project_id,
+ location=region,
+ staging_bucket=gcs_bucket_name,
+ credentials=credentials,
+ )
+
+ @classmethod
+ def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings":
+ """This function returns the default embedding.
+
+ Returns:
+ Default TensorflowHubEmbeddings to use.
+ """
+
+ from langchain_community.embeddings import TensorflowHubEmbeddings
+
+ return TensorflowHubEmbeddings()
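+
+
+# Hedged usage sketch (illustrative only, not part of the library): all ids,
+# names and regions below are placeholders, and an existing Vertex AI index,
+# endpoint and GCS bucket are required. Omitting `embedding` falls back to
+# TensorflowHubEmbeddings as shown above.
+if __name__ == "__main__":
+    engine = MatchingEngine.from_components(
+        project_id="my-project",
+        region="us-central1",
+        gcs_bucket_name="my-bucket",
+        index_id="1234567890",
+        endpoint_id="0987654321",
+    )
+    # Note: per the docstrings above, updating the index can take ~1 hour.
+    engine.add_texts(["vector search on Vertex AI"])
+    print(engine.similarity_search("vertex", k=1))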
diff --git a/libs/community/langchain_community/vectorstores/meilisearch.py b/libs/community/langchain_community/vectorstores/meilisearch.py
new file mode 100644
index 00000000000..b34a990cce2
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/meilisearch.py
@@ -0,0 +1,311 @@
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from meilisearch import Client
+
+
+def _create_client(
+ client: Optional[Client] = None,
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+) -> Client:
+ try:
+ import meilisearch
+ except ImportError:
+ raise ImportError(
+ "Could not import meilisearch python package. "
+ "Please install it with `pip install meilisearch`."
+ )
+ if not client:
+ url = url or get_from_env("url", "MEILI_HTTP_ADDR")
+ try:
+ api_key = api_key or get_from_env("api_key", "MEILI_MASTER_KEY")
+ except Exception:
+ pass
+ client = meilisearch.Client(url=url, api_key=api_key)
+ elif not isinstance(client, meilisearch.Client):
+ raise ValueError(
+ f"client should be an instance of meilisearch.Client, "
+ f"got {type(client)}"
+ )
+ try:
+ client.version()
+ except ValueError as e:
+ raise ValueError(f"Failed to connect to Meilisearch: {e}")
+ return client
+
+
+class Meilisearch(VectorStore):
+ """`Meilisearch` vector store.
+
+ To use this, you need to have `meilisearch` python package installed,
+ and a running Meilisearch instance.
+
+ To learn more about Meilisearch Python, refer to the in-depth
+ Meilisearch Python documentation: https://meilisearch.github.io/meilisearch-python/.
+
+ See the following documentation for how to run a Meilisearch instance:
+ https://www.meilisearch.com/docs/learn/getting_started/quick_start.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Meilisearch
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ import meilisearch
+
+ # api_key is optional; provide it if your meilisearch instance requires it
+ client = meilisearch.Client(url='http://127.0.0.1:7700', api_key='***')
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Meilisearch(
+ embedding=embeddings,
+ client=client,
+ index_name='langchain_demo',
+ text_key='text')
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ client: Optional[Client] = None,
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ index_name: str = "langchain-demo",
+ text_key: str = "text",
+ metadata_key: str = "metadata",
+ ):
+ """Initialize with Meilisearch client."""
+ client = _create_client(client=client, url=url, api_key=api_key)
+
+ self._client = client
+ self._index_name = index_name
+ self._embedding = embedding
+ self._text_key = text_key
+ self._metadata_key = metadata_key
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embedding and add them to the vector store.
+
+ Args:
+ texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
+ metadatas (Optional[List[dict]]): Optional list of metadata.
+ Defaults to None.
+ ids (Optional[List[str]]): Optional list of IDs.
+ Defaults to None.
+
+ Returns:
+ List[str]: List of IDs of the texts added to the vectorstore.
+ """
+ texts = list(texts)
+
+ # Embed and create the documents
+ docs = []
+ if ids is None:
+ ids = [uuid.uuid4().hex for _ in texts]
+ if metadatas is None:
+ metadatas = [{} for _ in texts]
+ embedding_vectors = self._embedding.embed_documents(texts)
+
+ for i, text in enumerate(texts):
+ id = ids[i]
+ metadata = metadatas[i]
+ metadata[self._text_key] = text
+ embedding = embedding_vectors[i]
+ docs.append(
+ {
+ "id": id,
+ "_vectors": embedding,
+ f"{self._metadata_key}": metadata,
+ }
+ )
+
+ # Send to Meilisearch
+ self._client.index(str(self._index_name)).add_documents(docs)
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return meilisearch documents most similar to the query.
+
+ Args:
+ query (str): Query text for which to find similar documents.
+ k (int): Number of documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata.
+ Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query text.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query=query,
+ k=k,
+ filter=filter,
+ kwargs=kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return meilisearch documents most similar to the query, along with scores.
+
+ Args:
+ query (str): Query text for which to find similar documents.
+ k (int): Number of documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata.
+ Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query
+ text and score for each.
+ """
+ _query = self._embedding.embed_query(query)
+
+ docs = self.similarity_search_by_vector_with_scores(
+ embedding=_query,
+ k=k,
+ filter=filter,
+ kwargs=kwargs,
+ )
+ return docs
+
+ def similarity_search_by_vector_with_scores(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return meilisearch documents most similar to embedding vector.
+
+ Args:
+ embedding (List[float]): Embedding to look up similar documents.
+ k (int): Number of documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata.
+ Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query
+ vector and score for each.
+ """
+ docs = []
+ results = self._client.index(str(self._index_name)).search(
+ "", {"vector": embedding, "limit": k, "filter": filter}
+ )
+
+ for result in results["hits"]:
+ metadata = result[self._metadata_key]
+ if self._text_key in metadata:
+ text = metadata.pop(self._text_key)
+ semantic_score = result["_semanticScore"]
+ docs.append(
+ (Document(page_content=text, metadata=metadata), semantic_score)
+ )
+
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return meilisearch documents most similar to embedding vector.
+
+ Args:
+ embedding (List[float]): Embedding to look up similar documents.
+ k (int): Number of documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata.
+ Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents most similar to the query vector.
+ """
+ docs = self.similarity_search_by_vector_with_scores(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ kwargs=kwargs,
+ )
+ return [doc for doc, _ in docs]
+
+ @classmethod
+ def from_texts(
+ cls: Type[Meilisearch],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ client: Optional[Client] = None,
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ index_name: str = "langchain-demo",
+ ids: Optional[List[str]] = None,
+ text_key: Optional[str] = "text",
+ metadata_key: Optional[str] = "metadata",
+ **kwargs: Any,
+ ) -> Meilisearch:
+ """Construct Meilisearch wrapper from raw documents.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Adds the documents to a provided Meilisearch index.
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Meilisearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+ import meilisearch
+
+ # api_key is optional; provide it if your meilisearch instance requires it
+ client = meilisearch.Client(url='http://127.0.0.1:7700', api_key='***')
+ embeddings = OpenAIEmbeddings()
+ docsearch = Meilisearch.from_texts(
+ texts,
+ embedding=embeddings,
+ client=client,
+ )
+ """
+ client = _create_client(client=client, url=url, api_key=api_key)
+
+ vectorstore = cls(
+ embedding=embedding,
+ client=client,
+ index_name=index_name,
+ )
+ vectorstore.add_texts(
+ texts=texts,
+ metadatas=metadatas,
+ ids=ids,
+ text_key=text_key,
+ metadata_key=metadata_key,
+ )
+ return vectorstore
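+
+
+# Hedged usage sketch (illustrative only, not part of the library): assumes a
+# Meilisearch instance with vector search enabled at localhost:7700 and an
+# OpenAI key for the embeddings; the master key and index name are
+# placeholders.
+if __name__ == "__main__":
+    from langchain_community.embeddings import OpenAIEmbeddings
+
+    store = Meilisearch.from_texts(
+        texts=["hello meilisearch"],
+        embedding=OpenAIEmbeddings(),
+        url="http://127.0.0.1:7700",
+        api_key="masterKey",
+        index_name="langchain-demo",
+    )
+    print(store.similarity_search("hello", k=1))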
diff --git a/libs/community/langchain_community/vectorstores/milvus.py b/libs/community/langchain_community/vectorstores/milvus.py
new file mode 100644
index 00000000000..213e86c7b7a
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/milvus.py
@@ -0,0 +1,828 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Iterable, List, Optional, Tuple, Union
+from uuid import uuid4
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MILVUS_CONNECTION = {
+ "host": "localhost",
+ "port": "19530",
+ "user": "",
+ "password": "",
+ "secure": False,
+}
+
+
+class Milvus(VectorStore):
+ """`Milvus` vector store.
+
+ You need to install `pymilvus` and run Milvus.
+
+ See the following documentation for how to run a Milvus instance:
+ https://milvus.io/docs/install_standalone-docker.md
+
+ If looking for a hosted Milvus, take a look at this documentation:
+ https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
+ this project.
+
+ IF USING L2/IP metric, IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
+
+ Args:
+ embedding_function (Embeddings): Function used to embed the text.
+ collection_name (str): Which Milvus collection to use. Defaults to
+ "LangChainCollection".
+ connection_args (Optional[dict[str, any]]): The connection args used for
+ this class come in the form of a dict.
+ consistency_level (str): The consistency level to use for a collection.
+ Defaults to "Session".
+ index_params (Optional[dict]): Which index params to use. Defaults to
+ HNSW/AUTOINDEX depending on service.
+ search_params (Optional[dict]): Which search params to use. Defaults to
+ default of index.
+ drop_old (Optional[bool]): Whether to drop the current collection. Defaults
+ to False.
+ primary_field (str): Name of the primary key field. Defaults to "pk".
+ text_field (str): Name of the text field. Defaults to "text".
+ vector_field (str): Name of the vector field. Defaults to "vector".
+
+ The connection args used for this class come in the form of a dict;
+ here are a few of the options (a minimal sketch follows the list):
+ address (str): The actual address of Milvus
+ instance. Example address: "localhost:19530"
+ uri (str): The uri of Milvus instance. Example uri:
+ "http://randomwebsite:19530",
+ "tcp:foobarsite:19530",
+ "https://ok.s3.south.com:19530".
+ host (str): The host of Milvus instance. Default at "localhost",
+ PyMilvus will fill in the default host if only port is provided.
+ port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
+ will fill in the default port if only host is provided.
+ user (str): Use which user to connect to Milvus instance. If user and
+ password are provided, we will add related header in every RPC call.
+ password (str): Required when user is provided. The password
+ corresponding to the user.
+ secure (bool): Default is false. If set to true, tls will be enabled.
+ client_key_path (str): If use tls two-way authentication, need to
+ write the client.key path.
+ client_pem_path (str): If use tls two-way authentication, need to
+ write the client.pem path.
+ ca_pem_path (str): If use tls two-way authentication, need to write
+ the ca.pem path.
+ server_pem_path (str): If use tls one-way authentication, need to
+ write the server.pem path.
+ server_name (str): If use tls, need to write the common name.
+
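+ A minimal sketch of ``connection_args`` for a local standalone instance
+ (the values below mirror the module's defaults and are illustrative):
+
+ .. code-block:: python
+
+ connection_args = {
+ "host": "localhost",
+ "port": "19530",
+ "user": "",
+ "password": "",
+ "secure": False,
+ }
+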
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Milvus
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embedding = OpenAIEmbeddings()
+ # Connect to a milvus instance on localhost
+ milvus_store = Milvus(
+ embedding_function = embedding,
+ collection_name = "LangChainCollection",
+ drop_old = True,
+ )
+
+ Raises:
+ ValueError: If the pymilvus python package is not installed.
+ """
+
+ def __init__(
+ self,
+ embedding_function: Embeddings,
+ collection_name: str = "LangChainCollection",
+ connection_args: Optional[dict[str, Any]] = None,
+ consistency_level: str = "Session",
+ index_params: Optional[dict] = None,
+ search_params: Optional[dict] = None,
+ drop_old: Optional[bool] = False,
+ *,
+ primary_field: str = "pk",
+ text_field: str = "text",
+ vector_field: str = "vector",
+ ):
+ """Initialize the Milvus vector store."""
+ try:
+ from pymilvus import Collection, utility
+ except ImportError:
+ raise ValueError(
+ "Could not import pymilvus python package. "
+ "Please install it with `pip install pymilvus`."
+ )
+
+ # Default search params when one is not provided.
+ self.default_search_params = {
+ "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
+ "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
+ "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
+ "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
+ "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
+ "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
+ "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
+ "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
+ "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
+ "AUTOINDEX": {"metric_type": "L2", "params": {}},
+ }
+
+ self.embedding_func = embedding_function
+ self.collection_name = collection_name
+ self.index_params = index_params
+ self.search_params = search_params
+ self.consistency_level = consistency_level
+
+ # In order for a collection to be compatible, pk needs to be an auto-id int
+ self._primary_field = primary_field
+ # In order for compatibility, the text field will need to be called "text"
+ self._text_field = text_field
+ # In order for compatibility, the vector field needs to be called "vector"
+ self._vector_field = vector_field
+ self.fields: list[str] = []
+ # Create the connection to the server
+ if connection_args is None:
+ connection_args = DEFAULT_MILVUS_CONNECTION
+ self.alias = self._create_connection_alias(connection_args)
+ self.col: Optional[Collection] = None
+
+ # Grab the existing collection if it exists
+ if utility.has_collection(self.collection_name, using=self.alias):
+ self.col = Collection(
+ self.collection_name,
+ using=self.alias,
+ )
+ # If need to drop old, drop it
+ if drop_old and isinstance(self.col, Collection):
+ self.col.drop()
+ self.col = None
+
+ # Initialize the vector store
+ self._init()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_func
+
+ def _create_connection_alias(self, connection_args: dict) -> str:
+ """Create the connection to the Milvus server."""
+ from pymilvus import MilvusException, connections
+
+ # Grab the connection arguments that are used for checking existing connection
+ host: str = connection_args.get("host", None)
+ port: Union[str, int] = connection_args.get("port", None)
+ address: str = connection_args.get("address", None)
+ uri: str = connection_args.get("uri", None)
+ user = connection_args.get("user", None)
+
+ # Order of use is host/port, uri, address
+ if host is not None and port is not None:
+ given_address = str(host) + ":" + str(port)
+ elif uri is not None:
+ # Strip any scheme prefix (e.g. "https://") to get the bare address
+ given_address = uri.split("://")[-1]
+ elif address is not None:
+ given_address = address
+ else:
+ given_address = None
+ logger.debug("Missing standard address type for reuse attempt")
+
+ # User defaults to empty string when getting connection info
+ if user is not None:
+ tmp_user = user
+ else:
+ tmp_user = ""
+
+ # If a valid address was given, then check if a connection exists
+ if given_address is not None:
+ for con in connections.list_connections():
+ addr = connections.get_connection_addr(con[0])
+ if (
+ con[1]
+ and ("address" in addr)
+ and (addr["address"] == given_address)
+ and ("user" in addr)
+ and (addr["user"] == tmp_user)
+ ):
+ logger.debug("Using previous connection: %s", con[0])
+ return con[0]
+
+ # Generate a new connection if one doesn't exist
+ alias = uuid4().hex
+ try:
+ connections.connect(alias=alias, **connection_args)
+ logger.debug("Created new connection using: %s", alias)
+ return alias
+ except MilvusException as e:
+ logger.error("Failed to create new connection using: %s", alias)
+ raise e
+
+ def _init(
+ self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
+ ) -> None:
+ if embeddings is not None:
+ self._create_collection(embeddings, metadatas)
+ self._extract_fields()
+ self._create_index()
+ self._create_search_params()
+ self._load()
+
+ def _create_collection(
+ self, embeddings: list, metadatas: Optional[list[dict]] = None
+ ) -> None:
+ from pymilvus import (
+ Collection,
+ CollectionSchema,
+ DataType,
+ FieldSchema,
+ MilvusException,
+ )
+ from pymilvus.orm.types import infer_dtype_bydata
+
+ # Determine embedding dim
+ dim = len(embeddings[0])
+ fields = []
+ # Determine metadata schema
+ if metadatas:
+ # Create FieldSchema for each entry in metadata.
+ for key, value in metadatas[0].items():
+ # Infer the corresponding datatype of the metadata
+ dtype = infer_dtype_bydata(value)
+ # Datatype isn't compatible
+ if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
+ logger.error(
+ "Failure to create collection, unrecognized dtype for key: %s",
+ key,
+ )
+ raise ValueError(f"Unrecognized datatype for {key}.")
+ # Datatype is a string/varchar equivalent
+ elif dtype == DataType.VARCHAR:
+ fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
+ else:
+ fields.append(FieldSchema(key, dtype))
+
+ # Create the text field
+ fields.append(
+ FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
+ )
+ # Create the primary key field
+ fields.append(
+ FieldSchema(
+ self._primary_field, DataType.INT64, is_primary=True, auto_id=True
+ )
+ )
+ # Create the vector field, supports binary or float vectors
+ fields.append(
+ FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
+ )
+
+ # Create the schema for the collection
+ schema = CollectionSchema(fields)
+
+ # Create the collection
+ try:
+ self.col = Collection(
+ name=self.collection_name,
+ schema=schema,
+ consistency_level=self.consistency_level,
+ using=self.alias,
+ )
+ except MilvusException as e:
+ logger.error(
+ "Failed to create collection: %s error: %s", self.collection_name, e
+ )
+ raise e
+
+ def _extract_fields(self) -> None:
+ """Grab the existing fields from the Collection"""
+ from pymilvus import Collection
+
+ if isinstance(self.col, Collection):
+ schema = self.col.schema
+ for x in schema.fields:
+ self.fields.append(x.name)
+ # Since primary field is auto-id, no need to track it
+ self.fields.remove(self._primary_field)
+
+ def _get_index(self) -> Optional[dict[str, Any]]:
+ """Return the vector index information if it exists"""
+ from pymilvus import Collection
+
+ if isinstance(self.col, Collection):
+ for x in self.col.indexes:
+ if x.field_name == self._vector_field:
+ return x.to_dict()
+ return None
+
+ def _create_index(self) -> None:
+ """Create a index on the collection"""
+ from pymilvus import Collection, MilvusException
+
+ if isinstance(self.col, Collection) and self._get_index() is None:
+ try:
+ # If no index params, use a default HNSW based one
+ if self.index_params is None:
+ self.index_params = {
+ "metric_type": "L2",
+ "index_type": "HNSW",
+ "params": {"M": 8, "efConstruction": 64},
+ }
+
+ try:
+ self.col.create_index(
+ self._vector_field,
+ index_params=self.index_params,
+ using=self.alias,
+ )
+
+ # If default did not work, most likely on Zilliz Cloud
+ except MilvusException:
+ # Use AUTOINDEX based index
+ self.index_params = {
+ "metric_type": "L2",
+ "index_type": "AUTOINDEX",
+ "params": {},
+ }
+ self.col.create_index(
+ self._vector_field,
+ index_params=self.index_params,
+ using=self.alias,
+ )
+ logger.debug(
+ "Successfully created an index on collection: %s",
+ self.collection_name,
+ )
+
+ except MilvusException as e:
+ logger.error(
+ "Failed to create an index on collection: %s", self.collection_name
+ )
+ raise e
+
+ def _create_search_params(self) -> None:
+ """Generate search params based on the current index type"""
+ from pymilvus import Collection
+
+ if isinstance(self.col, Collection) and self.search_params is None:
+ index = self._get_index()
+ if index is not None:
+ index_type: str = index["index_param"]["index_type"]
+ metric_type: str = index["index_param"]["metric_type"]
+ self.search_params = self.default_search_params[index_type]
+ self.search_params["metric_type"] = metric_type
+
+ def _load(self) -> None:
+ """Load the collection if available."""
+ from pymilvus import Collection
+
+ if isinstance(self.col, Collection) and self._get_index() is not None:
+ self.col.load()
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ timeout: Optional[int] = None,
+ batch_size: int = 1000,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Insert text data into Milvus.
+
+ Inserting data when the collection has not been made yet will result
+ in creating a new Collection. The data of the first entity decides
+ the schema of the new collection, the dim is extracted from the first
+ embedding and the columns are decided by the first metadata dict.
+ Metadata keys will need to be present for all inserted values. At
+ the moment there is no None equivalent in Milvus.
+
+ Args:
+ texts (Iterable[str]): The texts to embed, it is assumed
+ that they all fit in memory.
+ metadatas (Optional[List[dict]]): Metadata dicts attached to each of
+ the texts. Defaults to None.
+ timeout (Optional[int]): Timeout for each batch insert. Defaults
+ to None.
+ batch_size (int, optional): Batch size to use for insertion.
+ Defaults to 1000.
+
+ Raises:
+ MilvusException: Failure to add texts
+
+ Returns:
+ List[str]: The resulting keys for each inserted element.
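+
+ Example:
+ A minimal sketch (assuming ``milvus_store`` was created as in the
+ class docstring example); note that every metadata dict must share
+ the same keys, since the first dict fixes the collection schema:
+
+ .. code-block:: python
+
+ ids = milvus_store.add_texts(
+ texts=["doc one", "doc two"],
+ metadatas=[{"source": "a"}, {"source": "b"}],
+ )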
+ """
+ from pymilvus import Collection, MilvusException
+
+ texts = list(texts)
+
+ try:
+ embeddings = self.embedding_func.embed_documents(texts)
+ except NotImplementedError:
+ embeddings = [self.embedding_func.embed_query(x) for x in texts]
+
+ if len(embeddings) == 0:
+ logger.debug("Nothing to insert, skipping.")
+ return []
+
+ # If the collection hasn't been initialized yet, perform all steps to do so
+ if not isinstance(self.col, Collection):
+ self._init(embeddings, metadatas)
+
+ # Dict to hold all insert columns
+ insert_dict: dict[str, list] = {
+ self._text_field: texts,
+ self._vector_field: embeddings,
+ }
+
+ # Collect the metadata into the insert dict.
+ if metadatas is not None:
+ for d in metadatas:
+ for key, value in d.items():
+ if key in self.fields:
+ insert_dict.setdefault(key, []).append(value)
+
+ # Total insert count
+ vectors: list = insert_dict[self._vector_field]
+ total_count = len(vectors)
+
+ pks: list[str] = []
+
+ assert isinstance(self.col, Collection)
+ for i in range(0, total_count, batch_size):
+ # Grab end index
+ end = min(i + batch_size, total_count)
+ # Convert dict to list of lists batch for insertion
+ insert_list = [insert_dict[x][i:end] for x in self.fields]
+ # Insert into the collection.
+ try:
+ # insert() returns a MutationResult carrying the generated primary keys
+ res = self.col.insert(insert_list, timeout=timeout, **kwargs)
+ pks.extend(res.primary_keys)
+ except MilvusException as e:
+ logger.error(
+ "Failed to insert batch starting at entity: %s/%s", i, total_count
+ )
+ raise e
+ return pks
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search against the query string.
+
+ Args:
+ query (str): The text to search.
+ k (int, optional): How many results to return. Defaults to 4.
+ param (dict, optional): The search params for the index type.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+ Returns:
+ List[Document]: Document results for search.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+ res = self.similarity_search_with_score(
+ query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return [doc for doc, _ in res]
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search against the query string.
+
+ Args:
+ embedding (List[float]): The embedding vector to search.
+ k (int, optional): How many results to return. Defaults to 4.
+ param (dict, optional): The search params for the index type.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+ Returns:
+ List[Document]: Document results for search.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+ res = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return [doc for doc, _ in res]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Perform a search on a query string and return results with score.
+
+ For more information about the search parameters, take a look at the pymilvus
+ documentation found here:
+ https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+ Args:
+ query (str): The text being searched.
+ k (int, optional): The amount of results to return. Defaults to 4.
+ param (dict): The search params for the specified index.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+ Returns:
+ List[Tuple[Document, float]]: Result doc and score.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ # Embed the query text.
+ embedding = self.embedding_func.embed_query(query)
+
+ res = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return res
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Perform a search on a query string and return results with score.
+
+ For more information about the search parameters, take a look at the pymilvus
+ documentation found here:
+ https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+ Args:
+ embedding (List[float]): The embedding vector being searched.
+ k (int, optional): The amount of results to return. Defaults to 4.
+ param (dict): The search params for the specified index.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+ Returns:
+ List[Tuple[Document, float]]: Result doc and score.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ if param is None:
+ param = self.search_params
+
+ # Determine result metadata fields.
+ output_fields = self.fields[:]
+ output_fields.remove(self._vector_field)
+
+ # Perform the search.
+ res = self.col.search(
+ data=[embedding],
+ anns_field=self._vector_field,
+ param=param,
+ limit=k,
+ expr=expr,
+ output_fields=output_fields,
+ timeout=timeout,
+ **kwargs,
+ )
+ # Organize results.
+ ret = []
+ for result in res[0]:
+ meta = {x: result.entity.get(x) for x in output_fields}
+ doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+ pair = (doc, result.score)
+ ret.append(pair)
+
+ return ret
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a search and return results that are reordered by MMR.
+
+ Args:
+ query (str): The text being searched.
+ k (int, optional): How many results to give. Defaults to 4.
+ fetch_k (int, optional): Total results to select k from.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5
+ param (dict, optional): The search params for the specified index.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+
+ Returns:
+ List[Document]: Document results for search.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ embedding = self.embedding_func.embed_query(query)
+
+ return self.max_marginal_relevance_search_by_vector(
+ embedding=embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ param=param,
+ expr=expr,
+ timeout=timeout,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: list[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a search and return results that are reordered by MMR.
+
+ Args:
+ embedding (List[float]): The embedding vector being searched.
+ k (int, optional): How many results to give. Defaults to 4.
+ fetch_k (int, optional): Total results to select k from.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5
+ param (dict, optional): The search params for the specified index.
+ Defaults to None.
+ expr (str, optional): Filtering expression. Defaults to None.
+ timeout (int, optional): How long to wait before timeout error.
+ Defaults to None.
+ kwargs: Collection.search() keyword arguments.
+
+ Returns:
+ List[Document]: Document results for search.
+ """
+ if self.col is None:
+ logger.debug("No existing collection to search.")
+ return []
+
+ if param is None:
+ param = self.search_params
+
+ # Determine result metadata fields.
+ output_fields = self.fields[:]
+ output_fields.remove(self._vector_field)
+
+ # Perform the search.
+ res = self.col.search(
+ data=[embedding],
+ anns_field=self._vector_field,
+ param=param,
+ limit=fetch_k,
+ expr=expr,
+ output_fields=output_fields,
+ timeout=timeout,
+ **kwargs,
+ )
+ # Organize results.
+ ids = []
+ documents = []
+ scores = []
+ for result in res[0]:
+ meta = {x: result.entity.get(x) for x in output_fields}
+ doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+ documents.append(doc)
+ scores.append(result.score)
+ ids.append(result.id)
+
+ vectors = self.col.query(
+ expr=f"{self._primary_field} in {ids}",
+ output_fields=[self._primary_field, self._vector_field],
+ timeout=timeout,
+ )
+ # Reorganize the results from query to match search order.
+ vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
+
+ ordered_result_embeddings = [vectors[x] for x in ids]
+
+ # Get the new order of results.
+ new_ordering = maximal_marginal_relevance(
+ np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
+ )
+
+ # Reorder the values and return.
+ ret = []
+ for x in new_ordering:
+ # Function can return -1 index
+ if x == -1:
+ break
+ else:
+ ret.append(documents[x])
+ return ret
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = "LangChainCollection",
+ connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
+ consistency_level: str = "Session",
+ index_params: Optional[dict] = None,
+ search_params: Optional[dict] = None,
+ drop_old: bool = False,
+ **kwargs: Any,
+ ) -> Milvus:
+ """Create a Milvus collection, indexes it with HNSW, and insert data.
+
+ Args:
+ texts (List[str]): Text data.
+ embedding (Embeddings): Embedding function.
+ metadatas (Optional[List[dict]]): Metadata for each text if it exists.
+ Defaults to None.
+ collection_name (str, optional): Collection name to use. Defaults to
+ "LangChainCollection".
+ connection_args (dict[str, Any], optional): Connection args to use. Defaults
+ to DEFAULT_MILVUS_CONNECTION.
+ consistency_level (str, optional): Which consistency level to use. Defaults
+ to "Session".
+ index_params (Optional[dict], optional): Which index_params to use. Defaults
+ to None.
+ search_params (Optional[dict], optional): Which search params to use.
+ Defaults to None.
+ drop_old (Optional[bool], optional): Whether to drop the collection with
+ that name if it exists. Defaults to False.
+
+ Returns:
+ Milvus: Milvus Vector Store
+ """
+ vector_db = cls(
+ embedding_function=embedding,
+ collection_name=collection_name,
+ connection_args=connection_args,
+ consistency_level=consistency_level,
+ index_params=index_params,
+ search_params=search_params,
+ drop_old=drop_old,
+ **kwargs,
+ )
+ vector_db.add_texts(texts=texts, metadatas=metadatas)
+ return vector_db
diff --git a/libs/community/langchain_community/vectorstores/momento_vector_index.py b/libs/community/langchain_community/vectorstores/momento_vector_index.py
new file mode 100644
index 00000000000..b8f5b3e5512
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/momento_vector_index.py
@@ -0,0 +1,479 @@
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ cast,
+)
+from uuid import uuid4
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import (
+ DistanceStrategy,
+ maximal_marginal_relevance,
+)
+
+VST = TypeVar("VST", bound="VectorStore")
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from momento import PreviewVectorIndexClient
+
+
+class MomentoVectorIndex(VectorStore):
+ """`Momento Vector Index` (MVI) vector store.
+
+ Momento Vector Index is a serverless vector index that can be used to store and
+ search vectors. To use you should have the ``momento`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import MomentoVectorIndex
+ from momento import (
+ CredentialProvider,
+ PreviewVectorIndexClient,
+ VectorIndexConfigurations,
+ )
+
+ vectorstore = MomentoVectorIndex(
+ embedding=OpenAIEmbeddings(),
+ client=PreviewVectorIndexClient(
+ VectorIndexConfigurations.Default.latest(),
+ credential_provider=CredentialProvider.from_environment_variable(
+ "MOMENTO_API_KEY"
+ ),
+ ),
+ index_name="my-index",
+ )
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ client: "PreviewVectorIndexClient",
+ index_name: str = "default",
+ distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
+ text_field: str = "text",
+ ensure_index_exists: bool = True,
+ **kwargs: Any,
+ ):
+ """Initialize a Vector Store backed by Momento Vector Index.
+
+ Args:
+ embedding (Embeddings): The embedding function to use.
+ client (PreviewVectorIndexClient): The Momento Vector Index client
+ to use, already configured with credentials.
+ index_name (str, optional): The name of the index to store the documents in.
+ Defaults to "default".
+ distance_strategy (DistanceStrategy, optional): The distance strategy to
+ use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
+ the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
+ text_field (str, optional): The name of the metadata field to store the
+ original text in. Defaults to "text".
+ ensure_index_exists (bool, optional): Whether to ensure that the index
+ exists before adding documents to it. Defaults to True.
+ """
+ try:
+ from momento import PreviewVectorIndexClient
+ except ImportError:
+ raise ImportError(
+ "Could not import momento python package. "
+ "Please install it with `pip install momento`."
+ )
+
+ self._client: PreviewVectorIndexClient = client
+ self._embedding = embedding
+ self.index_name = index_name
+ self.__validate_distance_strategy(distance_strategy)
+ self.distance_strategy = distance_strategy
+ self.text_field = text_field
+ self._ensure_index_exists = ensure_index_exists
+
+ @staticmethod
+ def __validate_distance_strategy(distance_strategy: DistanceStrategy) -> None:
+ if distance_strategy not in [
+ DistanceStrategy.COSINE,
+ DistanceStrategy.MAX_INNER_PRODUCT,
+ DistanceStrategy.EUCLIDEAN_DISTANCE,
+ ]:
+ raise ValueError(f"Distance strategy {distance_strategy} not implemented.")
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
+ """Create index if it does not exist."""
+ from momento.requests.vector_index import SimilarityMetric
+ from momento.responses.vector_index import CreateIndex
+
+ similarity_metric = None
+ if self.distance_strategy == DistanceStrategy.COSINE:
+ similarity_metric = SimilarityMetric.COSINE_SIMILARITY
+ elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ similarity_metric = SimilarityMetric.INNER_PRODUCT
+ elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
+ else:
+ logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
+ raise ValueError(
+ f"Distance strategy {self.distance_strategy} not implemented."
+ )
+
+ response = self._client.create_index(
+ self.index_name, num_dimensions, similarity_metric
+ )
+ if isinstance(response, CreateIndex.Success):
+ return True
+ elif isinstance(response, CreateIndex.IndexAlreadyExists):
+ return False
+ elif isinstance(response, CreateIndex.Error):
+ logger.error(f"Error creating index: {response.inner_exception}")
+ raise response.inner_exception
+ else:
+ logger.error(f"Unexpected response: {response}")
+ raise Exception(f"Unexpected response: {response}")
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Iterable of strings to add to the vectorstore.
+ metadatas (Optional[List[dict]]): Optional list of metadatas associated with
+ the texts.
+ kwargs (Any): Other optional parameters. Specifically:
+ - ids (List[str], optional): List of ids to use for the texts.
+ Defaults to None, in which case uuids are generated.
+
+ Returns:
+ List[str]: List of ids from adding the texts into the vectorstore.
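+
+ Example:
+ A minimal sketch showing explicit ids (one id per text; the values
+ are illustrative, and ``vectorstore`` is the instance from the class
+ docstring example):
+
+ .. code-block:: python
+
+ vectorstore.add_texts(
+ ["first doc", "second doc"],
+ metadatas=[{"topic": "a"}, {"topic": "b"}],
+ ids=["id-1", "id-2"],
+ )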
+ """
+ from momento.requests.vector_index import Item
+ from momento.responses.vector_index import UpsertItemBatch
+
+ texts = list(texts)
+
+ if len(texts) == 0:
+ return []
+
+ if metadatas is not None:
+ for metadata, text in zip(metadatas, texts):
+ metadata[self.text_field] = text
+ else:
+ metadatas = [{self.text_field: text} for text in texts]
+
+ try:
+ embeddings = self._embedding.embed_documents(texts)
+ except NotImplementedError:
+ embeddings = [self._embedding.embed_query(x) for x in texts]
+
+ # Create index if it does not exist.
+ # We assume that if it does exist, then it was created with the desired number
+ # of dimensions and similarity metric.
+ if self._ensure_index_exists:
+ self._create_index_if_not_exists(len(embeddings[0]))
+
+ if "ids" in kwargs:
+ ids = kwargs["ids"]
+ if len(ids) != len(embeddings):
+ raise ValueError("Number of ids must match number of texts")
+ else:
+ ids = [str(uuid4()) for _ in range(len(embeddings))]
+
+ batch_size = 128
+ for i in range(0, len(embeddings), batch_size):
+ start = i
+ end = min(i + batch_size, len(embeddings))
+ items = [
+ Item(id=id, vector=vector, metadata=metadata)
+ for id, vector, metadata in zip(
+ ids[start:end],
+ embeddings[start:end],
+ metadatas[start:end],
+ )
+ ]
+
+ response = self._client.upsert_item_batch(self.index_name, items)
+ if isinstance(response, UpsertItemBatch.Success):
+ pass
+ elif isinstance(response, UpsertItemBatch.Error):
+ raise response.inner_exception
+ else:
+ raise Exception(f"Unexpected response: {response}")
+
+ return ids
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector ID.
+
+ Args:
+ ids (List[str]): List of ids to delete.
+ kwargs (Any): Other optional parameters (unused)
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+ from momento.responses.vector_index import DeleteItemBatch
+
+ if ids is None:
+ return True
+ response = self._client.delete_item_batch(self.index_name, ids)
+ return isinstance(response, DeleteItemBatch.Success)
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Search for similar documents to the query string.
+
+ Args:
+ query (str): The query string to search for.
+ k (int, optional): The number of results to return. Defaults to 4.
+
+ Returns:
+ List[Document]: A list of documents that are similar to the query.
+ """
+ res = self.similarity_search_with_score(query=query, k=k, **kwargs)
+ return [doc for doc, _ in res]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Search for similar documents to the query string.
+
+ Args:
+ query (str): The query string to search for.
+ k (int, optional): The number of results to return. Defaults to 4.
+ kwargs (Any): Vector Store specific search parameters. The following are
+ forwarded to the Momento Vector Index:
+ - top_k (int, optional): The number of results to return.
+
+ Returns:
+ List[Tuple[Document, float]]: A list of tuples of the form
+ (Document, score).
+ """
+ embedding = self._embedding.embed_query(query)
+
+ results = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, **kwargs
+ )
+ return results
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Search for similar documents to the query vector.
+
+ Args:
+ embedding (List[float]): The query vector to search for.
+ k (int, optional): The number of results to return. Defaults to 4.
+ kwargs (Any): Vector Store specific search parameters. The following are
+ forwarded to the Momento Vector Index:
+ - top_k (int, optional): The number of results to return.
+
+ Returns:
+ List[Tuple[Document, float]]: A list of tuples of the form
+ (Document, score).
+ """
+ from momento.requests.vector_index import ALL_METADATA
+ from momento.responses.vector_index import Search
+
+ if "top_k" in kwargs:
+ k = kwargs["k"]
+ response = self._client.search(
+ self.index_name, embedding, top_k=k, metadata_fields=ALL_METADATA
+ )
+
+ if not isinstance(response, Search.Success):
+ return []
+
+ results = []
+ for hit in response.hits:
+ text = cast(str, hit.metadata.pop(self.text_field))
+ doc = Document(page_content=text, metadata=hit.metadata)
+ pair = (doc, hit.score)
+ results.append(pair)
+
+ return results
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Search for similar documents to the query vector.
+
+ Args:
+ embedding (List[float]): The query vector to search for.
+ k (int, optional): The number of results to return. Defaults to 4.
+
+ Returns:
+ List[Document]: A list of documents that are similar to the query.
+ """
+ results = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, **kwargs
+ )
+ return [doc for doc, _ in results]
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ from momento.requests.vector_index import ALL_METADATA
+ from momento.responses.vector_index import SearchAndFetchVectors
+
+ response = self._client.search_and_fetch_vectors(
+ self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA
+ )
+
+ if isinstance(response, SearchAndFetchVectors.Success):
+ pass
+ elif isinstance(response, SearchAndFetchVectors.Error):
+ logger.error(f"Error searching and fetching vectors: {response}")
+ return []
+ else:
+ logger.error(f"Unexpected response: {response}")
+ raise Exception(f"Unexpected response: {response}")
+
+ mmr_selected = maximal_marginal_relevance(
+ query_embedding=np.array([embedding], dtype=np.float32),
+ embedding_list=[hit.vector for hit in response.hits],
+ lambda_mult=lambda_mult,
+ k=k,
+ )
+ selected = [response.hits[i].metadata for i in mmr_selected]
+ return [
+ Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501
+ for metadata in selected
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding = self._embedding.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult, **kwargs
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[VST],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> VST:
+ """Return the Vector Store initialized from texts and embeddings.
+
+ Args:
+ cls (Type[VST]): The Vector Store class to use to initialize
+ the Vector Store.
+ texts (List[str]): The texts to initialize the Vector Store with.
+ embedding (Embeddings): The embedding function to use.
+ metadatas (Optional[List[dict]], optional): The metadata associated with
+ the texts. Defaults to None.
+ kwargs (Any): Vector Store specific parameters. The following are forwarded
+ to the Vector Store constructor and required:
+ - index_name (str, optional): The name of the index to store the documents
+ in. Defaults to "default".
+ - text_field (str, optional): The name of the metadata field to store the
+ original text in. Defaults to "text".
+ - distance_strategy (DistanceStrategy, optional): The distance strategy to
+ use. Defaults to DistanceStrategy.COSINE. If you select
+ DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
+ Euclidean distance.
+ - ensure_index_exists (bool, optional): Whether to ensure that the index
+ exists before adding documents to it. Defaults to True.
+ Additionally you can either pass in a client or an API key
+ - client (PreviewVectorIndexClient): The Momento Vector Index client to use.
+ - api_key (Optional[str]): The configuration to use to initialize
+ the Vector Index with. Defaults to None. If None, the configuration
+ is initialized from the environment variable `MOMENTO_API_KEY`.
+
+ Returns:
+ VST: Momento Vector Index vector store initialized from texts and
+ embeddings.
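+
+ Example:
+ A minimal sketch, assuming the Momento API key is available in the
+ ``MOMENTO_API_KEY`` environment variable and ``texts`` is a list of
+ strings to index:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import MomentoVectorIndex
+
+ vectorstore = MomentoVectorIndex.from_texts(
+ texts,
+ OpenAIEmbeddings(),
+ index_name="my-index",
+ )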
+ """
+ from momento import (
+ CredentialProvider,
+ PreviewVectorIndexClient,
+ VectorIndexConfigurations,
+ )
+
+ if "client" in kwargs:
+ client = kwargs.pop("client")
+ else:
+ supplied_api_key = kwargs.pop("api_key", None)
+ api_key = supplied_api_key or get_from_env("api_key", "MOMENTO_API_KEY")
+ client = PreviewVectorIndexClient(
+ configuration=VectorIndexConfigurations.Default.latest(),
+ credential_provider=CredentialProvider.from_string(api_key),
+ )
+ vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore
+ vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
+ return vector_db
diff --git a/libs/community/langchain_community/vectorstores/mongodb_atlas.py b/libs/community/langchain_community/vectorstores/mongodb_atlas.py
new file mode 100644
index 00000000000..61c901940fa
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/mongodb_atlas.py
@@ -0,0 +1,357 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ from pymongo.collection import Collection
+
+MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_INSERT_BATCH_SIZE = 100
+
+
+class MongoDBAtlasVectorSearch(VectorStore):
+ """`MongoDB Atlas Vector Search` vector store.
+
+ To use, you should have both:
+ - the ``pymongo`` python package installed
+ - a connection string associated with a MongoDB Atlas Cluster having deployed an
+ Atlas Search index
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from pymongo import MongoClient
+
+ mongo_client = MongoClient("")
+ collection = mongo_client[""][""]
+ embeddings = OpenAIEmbeddings()
+ vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
+ """
+
+ def __init__(
+ self,
+ collection: Collection[MongoDBDocumentType],
+ embedding: Embeddings,
+ *,
+ index_name: str = "default",
+ text_key: str = "text",
+ embedding_key: str = "embedding",
+ ):
+ """
+ Args:
+ collection: MongoDB collection to add the texts to.
+ embedding: Text embedding model to use.
+ text_key: MongoDB field that will contain the text for each
+ document.
+ embedding_key: MongoDB field that will contain the embedding for
+ each document.
+ index_name: Name of the Atlas Search index.
+ """
+ self._collection = collection
+ self._embedding = embedding
+ self._index_name = index_name
+ self._text_key = text_key
+ self._embedding_key = embedding_key
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ @classmethod
+ def from_connection_string(
+ cls,
+ connection_string: str,
+ namespace: str,
+ embedding: Embeddings,
+ **kwargs: Any,
+ ) -> MongoDBAtlasVectorSearch:
+ """Construct a `MongoDB Atlas Vector Search` vector store
+ from a MongoDB connection URI.
+
+ Args:
+ connection_string: A valid MongoDB connection URI.
+ namespace: A valid MongoDB namespace (database and collection).
+ embedding: The text embedding model to use for the vector store.
+
+ Returns:
+ A new MongoDBAtlasVectorSearch instance.
+
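+ Example:
+ A minimal sketch; the URI and namespace below are placeholders for
+ your own cluster, database, and collection:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+
+ vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
+ "mongodb+srv://<user>:<password>@<cluster>.mongodb.net",
+ "<db_name>.<collection_name>",
+ OpenAIEmbeddings(),
+ index_name="default",
+ )
+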
+ """
+ try:
+ from importlib.metadata import version
+
+ from pymongo import MongoClient
+ from pymongo.driver_info import DriverInfo
+ except ImportError:
+ raise ImportError(
+ "Could not import pymongo, please install it with "
+ "`pip install pymongo`."
+ )
+ client: MongoClient = MongoClient(
+ connection_string,
+ driver=DriverInfo(name="Langchain", version=version("langchain")),
+ )
+ db_name, collection_name = namespace.split(".")
+ collection = client[db_name][collection_name]
+ return cls(collection, embedding, **kwargs)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ **kwargs: Any,
+ ) -> List:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
+ _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
+ texts_batch = []
+ metadatas_batch = []
+ result_ids = []
+ for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
+ texts_batch.append(text)
+ metadatas_batch.append(metadata)
+ if (i + 1) % batch_size == 0:
+ result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
+ texts_batch = []
+ metadatas_batch = []
+ if texts_batch:
+ result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
+ return result_ids
+
+ def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
+ if not texts:
+ return []
+ # Embed and create the documents
+ embeddings = self._embedding.embed_documents(texts)
+ to_insert = [
+ {self._text_key: t, self._embedding_key: embedding, **m}
+ for t, m, embedding in zip(texts, metadatas, embeddings)
+ ]
+ # insert the documents in MongoDB Atlas
+ insert_result = self._collection.insert_many(to_insert) # type: ignore
+ return insert_result.inserted_ids
+
+ def _similarity_search_with_score(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ pre_filter: Optional[Dict] = None,
+ post_filter_pipeline: Optional[List[Dict]] = None,
+ ) -> List[Tuple[Document, float]]:
+ params = {
+ "queryVector": embedding,
+ "path": self._embedding_key,
+ "numCandidates": k * 10,
+ "limit": k,
+ "index": self._index_name,
+ }
+ if pre_filter:
+ params["filter"] = pre_filter
+ query = {"$vectorSearch": params}
+
+ pipeline = [
+ query,
+ {"$set": {"score": {"$meta": "vectorSearchScore"}}},
+ ]
+ if post_filter_pipeline is not None:
+ pipeline.extend(post_filter_pipeline)
+ cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
+ docs = []
+ for res in cursor:
+ text = res.pop(self._text_key)
+ score = res.pop("score")
+ docs.append((Document(page_content=text, metadata=res), score))
+ return docs
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ *,
+ k: int = 4,
+ pre_filter: Optional[Dict] = None,
+ post_filter_pipeline: Optional[List[Dict]] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return MongoDB documents most similar to the given query and their scores.
+
+ Uses the $vectorSearch aggregation stage available in MongoDB Atlas
+ Vector Search. This feature was in preview at the time of writing, so
+ check the MongoDB Atlas Vector Search documentation for its current
+ status before relying on it in production.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: (Optional) number of documents to return. Defaults to 4.
+ pre_filter: (Optional) dictionary of argument(s) to prefilter document
+ fields on.
+ post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
+ following the $vectorSearch stage.
+
+ Returns:
+ List of documents most similar to the query and their scores.
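+
+ Example:
+ A minimal sketch; ``post_filter_pipeline`` takes ordinary aggregation
+ stages that run after the vector search (here simply trimming the
+ output), and ``vectorstore`` is the instance from the class docstring:
+
+ .. code-block:: python
+
+ results = vectorstore.similarity_search_with_score(
+ "what is a vector index?",
+ k=10,
+ post_filter_pipeline=[{"$limit": 3}],
+ )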
+ """
+ embedding = self._embedding.embed_query(query)
+ docs = self._similarity_search_with_score(
+ embedding,
+ k=k,
+ pre_filter=pre_filter,
+ post_filter_pipeline=post_filter_pipeline,
+ )
+ return docs
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ pre_filter: Optional[Dict] = None,
+ post_filter_pipeline: Optional[List[Dict]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return MongoDB documents most similar to the given query.
+
+ Uses the $vectorSearch aggregation stage available in MongoDB Atlas
+ Vector Search. This feature was in preview at the time of writing, so
+ check the MongoDB Atlas Vector Search documentation for its current
+ status before relying on it in production.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: (Optional) number of documents to return. Defaults to 4.
+ pre_filter: (Optional) dictionary of argument(s) to prefilter document
+ fields on.
+ post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
+ following the $vectorSearch stage.
+
+ Returns:
+ List of documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query,
+ k=k,
+ pre_filter=pre_filter,
+ post_filter_pipeline=post_filter_pipeline,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ pre_filter: Optional[Dict] = None,
+ post_filter_pipeline: Optional[List[Dict]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return documents selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: (Optional) number of documents to return. Defaults to 4.
+ fetch_k: (Optional) number of documents to fetch before passing to MMR
+ algorithm. Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ pre_filter: (Optional) dictionary of argument(s) to prefilter on document
+ fields.
+ post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
+ following the $vectorSearch stage.
+ Returns:
+ List of documents selected by maximal marginal relevance.
+ """
+ query_embedding = self._embedding.embed_query(query)
+ docs = self._similarity_search_with_score(
+ query_embedding,
+ k=fetch_k,
+ pre_filter=pre_filter,
+ post_filter_pipeline=post_filter_pipeline,
+ )
+ mmr_doc_indexes = maximal_marginal_relevance(
+ np.array(query_embedding),
+ [doc.metadata[self._embedding_key] for doc, _ in docs],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
+ return mmr_docs
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict]] = None,
+ collection: Optional[Collection[MongoDBDocumentType]] = None,
+ **kwargs: Any,
+ ) -> MongoDBAtlasVectorSearch:
+ """Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Adds the documents to a provided MongoDB Atlas Vector Search index
+ (Lucene)
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+ from pymongo import MongoClient
+
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ mongo_client = MongoClient("")
+ collection = mongo_client[""][""]
+ embeddings = OpenAIEmbeddings()
+ vectorstore = MongoDBAtlasVectorSearch.from_texts(
+ texts,
+ embeddings,
+ metadatas=metadatas,
+ collection=collection
+ )
+ """
+ if collection is None:
+ raise ValueError("Must provide 'collection' named parameter.")
+ vectorstore = cls(collection, embedding, **kwargs)
+ vectorstore.add_texts(texts, metadatas=metadatas)
+ return vectorstore
diff --git a/libs/community/langchain_community/vectorstores/myscale.py b/libs/community/langchain_community/vectorstores/myscale.py
new file mode 100644
index 00000000000..e5a18576cde
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/myscale.py
@@ -0,0 +1,614 @@
+from __future__ import annotations
+
+import json
+import logging
+from hashlib import sha1
+from threading import Thread
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseSettings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger()
+
+
+def has_mul_sub_str(s: str, *args: Any) -> bool:
+ """
+ Check if a string contains multiple substrings.
+ Args:
+ s: string to check.
+ *args: substrings to check.
+
+ Returns:
+ True if all substrings are in the string, False otherwise.
+ """
+ for a in args:
+ if a not in s:
+ return False
+ return True
+
+
+class MyScaleSettings(BaseSettings):
+ """MyScale client configuration.
+
+ Attributes:
+ myscale_host (str) : A URL to connect to MyScale backend.
+ Defaults to 'localhost'.
+ myscale_port (int) : URL port to connect with HTTP. Defaults to 8443.
+ username (str) : Username to login. Defaults to None.
+ password (str) : Password to login. Defaults to None.
+ index_type (str): index type string.
+ index_param (dict): index build parameter.
+ database (str) : Database name to find the table. Defaults to 'default'.
+ table (str) : Table name to operate on.
+ Defaults to 'vector_table'.
+ metric (str) : Metric to compute distance,
+ supported are ('L2', 'Cosine', 'IP'). Defaults to 'Cosine'.
+ column_map (Dict) : Column type map to project column names onto langchain
+ semantics. Must have keys: `text`, `id`, `vector`,
+ and must be the same size as the number of columns. For example:
+ .. code-block:: python
+
+ {
+ 'id': 'text_id',
+ 'vector': 'text_embedding',
+ 'text': 'text_plain',
+ 'metadata': 'metadata_dictionary_in_json',
+ }
+
+ Defaults to identity map.
+
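+ Settings can also be read from environment variables with the
+ ``MYSCALE_`` prefix (for example ``MYSCALE_HOST``) or from a ``.env``
+ file. A minimal sketch of constructing them explicitly, with a
+ placeholder host:
+
+ .. code-block:: python
+
+ from langchain_community.vectorstores.myscale import MyScaleSettings
+
+ config = MyScaleSettings(host="<your-cluster-endpoint>", port=8443)
+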
+ """
+
+ host: str = "localhost"
+ port: int = 8443
+
+ username: Optional[str] = None
+ password: Optional[str] = None
+
+ index_type: str = "MSTG"
+ index_param: Optional[Dict[str, str]] = None
+
+ column_map: Dict[str, str] = {
+ "id": "id",
+ "text": "text",
+ "vector": "vector",
+ "metadata": "metadata",
+ }
+
+ database: str = "default"
+ table: str = "langchain"
+ metric: str = "Cosine"
+
+ def __getitem__(self, item: str) -> Any:
+ return getattr(self, item)
+
+ class Config:
+ env_file = ".env"
+ env_prefix = "myscale_"
+ env_file_encoding = "utf-8"
+
+
+class MyScale(VectorStore):
+ """`MyScale` vector store.
+
+ You need a `clickhouse-connect` python package, and a valid account
+ to connect to MyScale.
+
+ MyScale not only supports searches over simple vector indexes;
+ it also supports complex queries with multiple conditions,
+ constraints, and even sub-queries.
+
+ For more information, please visit
+ [myscale official site](https://docs.myscale.com/en/overview/)
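+
+ Example:
+ A minimal sketch, assuming a reachable MyScale endpoint (the host and
+ credentials below are placeholders):
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores.myscale import MyScale, MyScaleSettings
+
+ config = MyScaleSettings(
+ host="<your-cluster-endpoint>",
+ port=8443,
+ username="<user>",
+ password="<password>",
+ )
+ vectorstore = MyScale(OpenAIEmbeddings(), config)
+ vectorstore.add_texts(["hello world"], metadatas=[{"source": "demo"}])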
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ config: Optional[MyScaleSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ """MyScale Wrapper to LangChain
+
+ embedding (Embeddings):
+ config (MyScaleSettings): Configuration to MyScale Client
+ Other keyword arguments will pass into
+ [clickhouse-connect](https://docs.myscale.com/)
+ """
+ try:
+ from clickhouse_connect import get_client
+ except ImportError:
+ raise ImportError(
+ "Could not import clickhouse connect python package. "
+ "Please install it with `pip install clickhouse-connect`."
+ )
+ try:
+ from tqdm import tqdm
+
+ self.pgbar = tqdm
+ except ImportError:
+ # Fall back to a no-op progress bar if tqdm is not installed
+ self.pgbar = lambda x, **kwargs: x
+ super().__init__()
+ if config is not None:
+ self.config = config
+ else:
+ self.config = MyScaleSettings()
+ assert self.config
+ assert self.config.host and self.config.port
+ assert (
+ self.config.column_map
+ and self.config.database
+ and self.config.table
+ and self.config.metric
+ )
+ for k in ["id", "vector", "text", "metadata"]:
+ assert k in self.config.column_map
+ assert self.config.metric.upper() in ["IP", "COSINE", "L2"]
+ if self.config.metric in ["ip", "cosine", "l2"]:
+ logger.warning(
+ "Lower case metric types will be deprecated "
+ "the future. Please use one of ('IP', 'Cosine', 'L2')"
+ )
+
+ # initialize the schema
+ dim = len(embedding.embed_query("try this out"))
+
+ index_params = (
+ ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_param.items()])
+ if self.config.index_param
+ else ""
+ )
+ schema_ = f"""
+ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
+ {self.config.column_map['id']} String,
+ {self.config.column_map['text']} String,
+ {self.config.column_map['vector']} Array(Float32),
+ {self.config.column_map['metadata']} JSON,
+ CONSTRAINT cons_vec_len CHECK length(\
+ {self.config.column_map['vector']}) = {dim},
+ VECTOR INDEX vidx {self.config.column_map['vector']} \
+ TYPE {self.config.index_type}(\
+ 'metric_type={self.config.metric}'{index_params})
+ ) ENGINE = MergeTree ORDER BY {self.config.column_map['id']}
+ """
+ self.dim = dim
+ self.BS = "\\"
+ self.must_escape = ("\\", "'")
+ self._embeddings = embedding
+ self.dist_order = (
+ "ASC" if self.config.metric.upper() in ["COSINE", "L2"] else "DESC"
+ )
+
+ # Create a connection to myscale
+ self.client = get_client(
+ host=self.config.host,
+ port=self.config.port,
+ username=self.config.username,
+ password=self.config.password,
+ **kwargs,
+ )
+ self.client.command("SET allow_experimental_object_type=1")
+ self.client.command(schema_)
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embeddings
+
+ def escape_str(self, value: str) -> str:
+ return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
+
+ def _build_istr(self, transac: Iterable, column_names: Iterable[str]) -> str:
+ ks = ",".join(column_names)
+ _data = []
+ for n in transac:
+ n = ",".join([f"'{self.escape_str(str(_n))}'" for _n in n])
+ _data.append(f"({n})")
+ i_str = f"""
+ INSERT INTO TABLE
+ {self.config.database}.{self.config.table}({ks})
+ VALUES
+ {','.join(_data)}
+ """
+ return i_str
+
+ def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
+ _i_str = self._build_istr(transac, column_names)
+ self.client.command(_i_str)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ batch_size: int = 32,
+ ids: Optional[Iterable[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ ids: Optional list of ids to associate with the texts.
+ batch_size: Batch size of insertion.
+ metadatas: Optional list of metadata dicts associated with the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
+ """
+ # Embed and create the documents
+ ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
+ colmap_ = self.config.column_map
+
+ transac = []
+ column_names = {
+ colmap_["id"]: ids,
+ colmap_["text"]: texts,
+ colmap_["vector"]: map(self._embeddings.embed_query, texts),
+ }
+ metadatas = metadatas or [{} for _ in texts]
+ column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
+ assert len(set(colmap_) - set(column_names)) >= 0
+ keys, values = zip(*column_names.items())
+ try:
+ t = None
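+ # Inserts run on a background thread so the next batch can be prepared
+ # concurrently: when a batch fills up, the previous insert thread is
+ # joined before a new one is started; leftover rows are inserted
+ # synchronously after the loop.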
+ for v in self.pgbar(
+ zip(*values), desc="Inserting data...", total=len(metadatas)
+ ):
+ assert len(v[keys.index(self.config.column_map["vector"])]) == self.dim
+ transac.append(v)
+ if len(transac) == batch_size:
+ if t:
+ t.join()
+ t = Thread(target=self._insert, args=[transac, keys])
+ t.start()
+ transac = []
+ if len(transac) > 0:
+ if t:
+ t.join()
+ self._insert(transac, keys)
+ return [i for i in ids]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: Iterable[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ config: Optional[MyScaleSettings] = None,
+ text_ids: Optional[Iterable[str]] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> MyScale:
+ """Create Myscale wrapper with existing texts
+
+ Args:
+ texts (Iterable[str]): List or tuple of strings to be added
+ embedding (Embeddings): Function to extract text embedding
+ config (MyScaleSettings, Optional): Myscale configuration
+ text_ids (Optional[Iterable], optional): IDs for the texts.
+ Defaults to None.
+ batch_size (int, optional): Batchsize when transmitting data to MyScale.
+ Defaults to 32.
+ metadata (List[dict], optional): metadata to texts. Defaults to None.
+ Other keyword arguments will pass into
+ [clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api)
+ Returns:
+ MyScale Index
+ """
+ ctx = cls(embedding, config, **kwargs)
+ ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
+ return ctx
+
+ def __repr__(self) -> str:
+ """Text representation for myscale, prints backends, username and schemas.
+ Easy to use with `str(Myscale())`
+
+ Returns:
+ repr: string to show connection info and data schema
+ """
+ _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
+ _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
+ _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
+ _repr += "-" * 51 + "\n"
+ for r in self.client.query(
+ f"DESC {self.config.database}.{self.config.table}"
+ ).named_results():
+ _repr += (
+ f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n"
+ )
+ _repr += "-" * 51 + "\n"
+ return _repr
+
+ def _build_qstr(
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
+ ) -> str:
+ q_emb_str = ",".join(map(str, q_emb))
+ if where_str:
+ where_str = f"PREWHERE {where_str}"
+ else:
+ where_str = ""
+
+ q_str = f"""
+ SELECT {self.config.column_map['text']},
+ {self.config.column_map['metadata']}, dist
+ FROM {self.config.database}.{self.config.table}
+ {where_str}
+ ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
+ AS dist {self.dist_order}
+ LIMIT {topk}
+ """
+ return q_str
+
+ def similarity_search(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Perform a similarity search with MyScale
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+ NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadata, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name of this column is `metadata`.
+
+ Returns:
+ List[Document]: List of Documents
+ """
+ return self.similarity_search_by_vector(
+ self._embeddings.embed_query(query), k, where_str, **kwargs
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search with MyScale by vectors
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+ NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadata, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name of this column is `metadata`.
+
+ Returns:
+ List[Document]: List of documents most similar to the query vector
+ """
+ q_str = self._build_qstr(embedding, k, where_str)
+ try:
+ return [
+ Document(
+ page_content=r[self.config.column_map["text"]],
+ metadata=r[self.config.column_map["metadata"]],
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def similarity_search_with_relevance_scores(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with MyScale
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+ NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadata, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name of this column is `metadata`.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to the query text
+ and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
+ try:
+ return [
+ (
+ Document(
+ page_content=r[self.config.column_map["text"]],
+ metadata=r[self.config.column_map["metadata"]],
+ ),
+ r["dist"],
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def drop(self) -> None:
+ """
+ Helper function: Drop data
+ """
+ self.client.command(
+ f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
+ )
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+ assert not (
+ ids is None and where_str is None
+ ), "You need to specify where to be deleted! Either with `ids` or `where_str`"
+ conds = []
+ if ids:
+ conds.extend([f"{self.config.column_map['id']} = '{id}'" for id in ids])
+ if where_str:
+ conds.append(where_str)
+ assert len(conds) > 0
+ where_str_final = " AND ".join(conds)
+ qstr = (
+ f"DELETE FROM {self.config.database}.{self.config.table} "
+ f"WHERE {where_str_final}"
+ )
+ try:
+ self.client.command(qstr)
+ return True
+ except Exception as e:
+ logger.error(str(e))
+ return False
+
+ @property
+ def metadata_column(self) -> str:
+ return self.config.column_map["metadata"]
+
+
+class MyScaleWithoutJSON(MyScale):
+ """MyScale vector store without metadata column
+
+ This is super handy if you are working to a SQL-native table
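+
+ Example (illustrative; the table and column names are placeholders for an
+ existing table):
+ .. code-block:: python
+
+ store = MyScaleWithoutJSON(
+ embedding=OpenAIEmbeddings(),
+ config=MyScaleSettings(table="existing_table"),
+ must_have_cols=["doc_id", "source"],
+ )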
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ config: Optional[MyScaleSettings] = None,
+ must_have_cols: List[str] = [],
+ **kwargs: Any,
+ ) -> None:
+ """Building a myscale vector store without metadata column
+
+ embedding (Embeddings): embedding model
+ config (MyScaleSettings): Configuration to MyScale Client
+ must_have_cols (List[str]): column names to be included in query
+ Other keyword arguments will pass into
+ [clickhouse-connect](https://docs.myscale.com/)
+ """
+ super().__init__(embedding, config, **kwargs)
+ self.must_have_cols: List[str] = must_have_cols
+
+ def _build_qstr(
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
+ ) -> str:
+ q_emb_str = ",".join(map(str, q_emb))
+ if where_str:
+ where_str = f"PREWHERE {where_str}"
+ else:
+ where_str = ""
+
+ q_str = f"""
+ SELECT {self.config.column_map['text']}, dist,
+ {','.join(self.must_have_cols)}
+ FROM {self.config.database}.{self.config.table}
+ {where_str}
+ ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
+ AS dist {self.dist_order}
+ LIMIT {topk}
+ """
+ return q_str
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search with MyScale by vectors
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+ NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadata, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name of this column is `metadata`.
+
+ Returns:
+ List[Document]: List of documents most similar to the query vector
+ """
+ q_str = self._build_qstr(embedding, k, where_str)
+ try:
+ return [
+ Document(
+ page_content=r[self.config.column_map["text"]],
+ metadata={k: r[k] for k in self.must_have_cols},
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def similarity_search_with_relevance_scores(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with MyScale
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+ NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadata, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name of this column is `metadata`.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to the query text
+ and cosine distance in float for each.
+ Lower score represents more similarity.
+ """
+ q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
+ try:
+ return [
+ (
+ Document(
+ page_content=r[self.config.column_map["text"]],
+ metadata={k: r[k] for k in self.must_have_cols},
+ ),
+ r["dist"],
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ @property
+ def metadata_column(self) -> str:
+ return ""
diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py
new file mode 100644
index 00000000000..7ccc2e7d109
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py
@@ -0,0 +1,941 @@
+from __future__ import annotations
+
+import enum
+import logging
+import os
+import uuid
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+DISTANCE_MAPPING = {
+ DistanceStrategy.EUCLIDEAN_DISTANCE: "euclidean",
+ DistanceStrategy.COSINE: "cosine",
+}
+
+
+class SearchType(str, enum.Enum):
+ """Enumerator of the Distance strategies."""
+
+ VECTOR = "vector"
+ HYBRID = "hybrid"
+
+
+DEFAULT_SEARCH_TYPE = SearchType.VECTOR
+
+
+def _get_search_index_query(search_type: SearchType) -> str:
+ type_to_query_map = {
+ SearchType.VECTOR: (
+ "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score "
+ ),
+ SearchType.HYBRID: (
+ "CALL { "
+ "CALL db.index.vector.queryNodes($index, $k, $embedding) "
+ "YIELD node, score "
+ "RETURN node, score UNION "
+ "CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
+ "YIELD node, score "
+ "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+ "UNWIND nodes AS n "
+ "RETURN n.node AS node, (n.score / max) AS score " # We use 0 as min
+ "} "
+ "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " # dedup
+ ),
+ }
+ return type_to_query_map[search_type]
+
+
+def check_if_not_null(props: List[str], values: List[Any]) -> None:
+ """Check if the values are not None or empty string"""
+ for prop, value in zip(props, values):
+ if not value:
+ raise ValueError(f"Parameter `{prop}` must not be None or empty string")
+
+
+def sort_by_index_name(
+ lst: List[Dict[str, Any]], index_name: str
+) -> List[Dict[str, Any]]:
+ """Sort first element to match the index_name if exists"""
+ return sorted(lst, key=lambda x: x.get("index_name") != index_name)
+
+
+class Neo4jVector(VectorStore):
+ """`Neo4j` vector index.
+
+ To use, you should have the ``neo4j`` python package installed.
+
+ Args:
+ url: Neo4j connection url
+ username: Neo4j username.
+ password: Neo4j password
+ database: Optionally provide Neo4j database
+ Defaults to "neo4j"
+ embedding: Any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ distance_strategy: The distance strategy to use. (default: COSINE)
+ pre_delete_collection: If True, will delete existing data if it exists.
+ (default: False). Useful for testing.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores.neo4j_vector import Neo4jVector
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ url="bolt://localhost:7687"
+ username="neo4j"
+ password="pleaseletmein"
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Neo4jVector.from_documents(
+ embedding=embeddings,
+ documents=docs,
+ url=url,
+ username=username,
+ password=password,
+ )
+
+
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ *,
+ search_type: SearchType = SearchType.VECTOR,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ url: Optional[str] = None,
+ keyword_index_name: Optional[str] = "keyword",
+ database: str = "neo4j",
+ index_name: str = "vector",
+ node_label: str = "Chunk",
+ embedding_node_property: str = "embedding",
+ text_node_property: str = "text",
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ logger: Optional[logging.Logger] = None,
+ pre_delete_collection: bool = False,
+ retrieval_query: str = "",
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ ) -> None:
+ try:
+ import neo4j
+ except ImportError:
+ raise ImportError(
+ "Could not import neo4j python package. "
+ "Please install it with `pip install neo4j`."
+ )
+
+ # Allow only cosine and euclidean distance strategies
+ if distance_strategy not in [
+ DistanceStrategy.EUCLIDEAN_DISTANCE,
+ DistanceStrategy.COSINE,
+ ]:
+ raise ValueError(
+ "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'"
+ )
+
+ # Handle if the credentials are environment variables
+
+ # Support URL for backwards compatibility
+ url = os.environ.get("NEO4J_URL", url)
+ url = get_from_env("url", "NEO4J_URI", url)
+ username = get_from_env("username", "NEO4J_USERNAME", username)
+ password = get_from_env("password", "NEO4J_PASSWORD", password)
+ database = get_from_env("database", "NEO4J_DATABASE", database)
+
+ self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
+ self._database = database
+ self.schema = ""
+ # Verify connection
+ try:
+ self._driver.verify_connectivity()
+ except neo4j.exceptions.ServiceUnavailable:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the url is correct"
+ )
+ except neo4j.exceptions.AuthError:
+ raise ValueError(
+ "Could not connect to Neo4j database. "
+ "Please ensure that the username and password are correct"
+ )
+
+ # Verify that the Neo4j version supports vector indexes
+ self.verify_version()
+
+ # Verify that required values are not null
+ check_if_not_null(
+ [
+ "index_name",
+ "node_label",
+ "embedding_node_property",
+ "text_node_property",
+ ],
+ [index_name, node_label, embedding_node_property, text_node_property],
+ )
+
+ self.embedding = embedding
+ self._distance_strategy = distance_strategy
+ self.index_name = index_name
+ self.keyword_index_name = keyword_index_name
+ self.node_label = node_label
+ self.embedding_node_property = embedding_node_property
+ self.text_node_property = text_node_property
+ self.logger = logger or logging.getLogger(__name__)
+ self.override_relevance_score_fn = relevance_score_fn
+ self.retrieval_query = retrieval_query
+ self.search_type = search_type
+ # Calculate embedding dimension
+ self.embedding_dimension = len(embedding.embed_query("foo"))
+
+ # Delete existing data if flagged
+ if pre_delete_collection:
+ from neo4j.exceptions import DatabaseError
+
+ self.query(
+ f"MATCH (n:`{self.node_label}`) "
+ "CALL { WITH n DETACH DELETE n } "
+ "IN TRANSACTIONS OF 10000 ROWS;"
+ )
+ # Delete index
+ try:
+ self.query(f"DROP INDEX {self.index_name}")
+ except DatabaseError: # Index didn't exist yet
+ pass
+
+ def query(
+ self, query: str, *, params: Optional[dict] = None
+ ) -> List[Dict[str, Any]]:
+ """
+ This method sends a Cypher query to the connected Neo4j database
+ and returns the results as a list of dictionaries.
+
+ Args:
+ query (str): The Cypher query to execute.
+ params (dict, optional): Dictionary of query parameters. Defaults to {}.
+
+ Returns:
+ List[Dict[str, Any]]: List of dictionaries containing the query results.
+ """
+ from neo4j.exceptions import CypherSyntaxError
+
+ params = params or {}
+ with self._driver.session(database=self._database) as session:
+ try:
+ data = session.run(query, params)
+ return [r.data() for r in data]
+ except CypherSyntaxError as e:
+ raise ValueError(f"Cypher Statement is not valid\n{e}")
+
+ def verify_version(self) -> None:
+ """
+ Check if the connected Neo4j database version supports vector indexing.
+
+ Queries the Neo4j database to retrieve its version and compares it
+ against a target version (5.11.0) that is known to support vector
+ indexing. Raises a ValueError if the connected Neo4j version is
+ not supported.
+ """
+ version = self.query("CALL dbms.components()")[0]["versions"][0]
+ if "aura" in version:
+ version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
+ else:
+ version_tuple = tuple(map(int, version.split(".")))
+
+ target_version = (5, 11, 0)
+
+ if version_tuple < target_version:
+ raise ValueError(
+ "Version index is only supported in Neo4j version 5.11 or greater"
+ )
+
+ def retrieve_existing_index(self) -> Optional[int]:
+ """
+ Check if the vector index exists in the Neo4j database
+ and returns its embedding dimension.
+
+ This method queries the Neo4j database for existing indexes
+ and attempts to retrieve the dimension of the vector index
+ with the specified name. If the index exists, its dimension is returned.
+ If the index doesn't exist, `None` is returned.
+
+ Returns:
+ int or None: The embedding dimension of the existing index if found.
+ """
+
+ index_information = self.query(
+ "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options "
+ "WHERE type = 'VECTOR' AND (name = $index_name "
+ "OR (labelsOrTypes[0] = $node_label AND "
+ "properties[0] = $embedding_node_property)) "
+ "RETURN name, labelsOrTypes, properties, options ",
+ params={
+ "index_name": self.index_name,
+ "node_label": self.node_label,
+ "embedding_node_property": self.embedding_node_property,
+ },
+ )
+ # sort by index_name
+ index_information = sort_by_index_name(index_information, self.index_name)
+ try:
+ self.index_name = index_information[0]["name"]
+ self.node_label = index_information[0]["labelsOrTypes"][0]
+ self.embedding_node_property = index_information[0]["properties"][0]
+ embedding_dimension = index_information[0]["options"]["indexConfig"][
+ "vector.dimensions"
+ ]
+
+ return embedding_dimension
+ except IndexError:
+ return None
+
+ def retrieve_existing_fts_index(
+ self, text_node_properties: List[str] = []
+ ) -> Optional[str]:
+ """
+ Check if the fulltext index exists in the Neo4j database
+
+ This method queries the Neo4j database for existing fts indexes
+ with the specified name.
+
+ Returns:
+ (str or None): The node label of the keyword index if it exists.
+ """
+
+ index_information = self.query(
+ "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options "
+ "WHERE type = 'FULLTEXT' AND (name = $keyword_index_name "
+ "OR (labelsOrTypes = [$node_label] AND "
+ "properties = $text_node_property)) "
+ "RETURN name, labelsOrTypes, properties, options ",
+ params={
+ "keyword_index_name": self.keyword_index_name,
+ "node_label": self.node_label,
+ "text_node_property": text_node_properties or [self.text_node_property],
+ },
+ )
+ # sort by index_name
+ index_information = sort_by_index_name(index_information, self.index_name)
+ try:
+ self.keyword_index_name = index_information[0]["name"]
+ self.text_node_property = index_information[0]["properties"][0]
+ node_label = index_information[0]["labelsOrTypes"][0]
+ return node_label
+ except IndexError:
+ return None
+
+ def create_new_index(self) -> None:
+ """
+ This method constructs a Cypher query and executes it
+ to create a new vector index in Neo4j.
+ """
+ index_query = (
+ "CALL db.index.vector.createNodeIndex("
+ "$index_name,"
+ "$node_label,"
+ "$embedding_node_property,"
+ "toInteger($embedding_dimension),"
+ "$similarity_metric )"
+ )
+
+ parameters = {
+ "index_name": self.index_name,
+ "node_label": self.node_label,
+ "embedding_node_property": self.embedding_node_property,
+ "embedding_dimension": self.embedding_dimension,
+ "similarity_metric": DISTANCE_MAPPING[self._distance_strategy],
+ }
+ self.query(index_query, params=parameters)
+
+ def create_new_keyword_index(self, text_node_properties: List[str] = []) -> None:
+ """
+ This method constructs a Cypher query and executes it
+ to create a new full text index in Neo4j.
+ """
+ node_props = text_node_properties or [self.text_node_property]
+ fts_index_query = (
+ f"CREATE FULLTEXT INDEX {self.keyword_index_name} "
+ f"FOR (n:`{self.node_label}`) ON EACH "
+ f"[{', '.join(['n.`' + el + '`' for el in node_props])}]"
+ )
+ self.query(fts_index_query)
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ create_id_index: bool = True,
+ search_type: SearchType = SearchType.VECTOR,
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ store = cls(
+ embedding=embedding,
+ search_type=search_type,
+ **kwargs,
+ )
+ # Check if the vector index already exists
+ embedding_dimension = store.retrieve_existing_index()
+
+ # If the vector index doesn't exist yet
+ if not embedding_dimension:
+ store.create_new_index()
+ # If the index already exists, check if embedding dimensions match
+ elif not store.embedding_dimension == embedding_dimension:
+ raise ValueError(
+ f"Index with name {store.index_name} already exists."
+ "The provided embedding function and vector index "
+ "dimensions do not match.\n"
+ f"Embedding function dimension: {store.embedding_dimension}\n"
+ f"Vector index dimension: {embedding_dimension}"
+ )
+
+ if search_type == SearchType.HYBRID:
+ fts_node_label = store.retrieve_existing_fts_index()
+ # If the FTS index doesn't exist yet
+ if not fts_node_label:
+ store.create_new_keyword_index()
+ else: # Validate that FTS and Vector index use the same information
+ if not fts_node_label == store.node_label:
+ raise ValueError(
+ "Vector and keyword index don't index the same node label"
+ )
+
+ # Create unique constraint for faster import
+ if create_id_index:
+ store.query(
+ "CREATE CONSTRAINT IF NOT EXISTS "
+ f"FOR (n:`{store.node_label}`) REQUIRE n.id IS UNIQUE;"
+ )
+
+ store.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ def add_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+ metadatas: List of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
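+ # The import query upserts one node per row keyed by `id`, stores the
+ # embedding via db.create.setVectorProperty, sets the text property, and
+ # merges the metadata, committing in transactions of 1000 rows.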
+ import_query = (
+ "UNWIND $data AS row "
+ "CALL { WITH row "
+ f"MERGE (c:`{self.node_label}` {{id: row.id}}) "
+ "WITH c, row "
+ f"CALL db.create.setVectorProperty(c, "
+ f"'{self.embedding_node_property}', row.embedding) "
+ "YIELD node "
+ f"SET c.`{self.text_node_property}` = row.text "
+ "SET c += row.metadata } IN TRANSACTIONS OF 1000 ROWS"
+ )
+
+ parameters = {
+ "data": [
+ {"text": text, "metadata": metadata, "embedding": embedding, "id": id}
+ for text, metadata, embedding, id in zip(
+ texts, metadatas, embeddings, ids
+ )
+ ]
+ }
+
+ self.query(import_query, params=parameters)
+
+ return ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ embeddings = self.embedding.embed_documents(list(texts))
+ return self.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with Neo4jVector.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ embedding = self.embedding.embed_query(text=query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ query=query,
+ )
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedding = self.embedding.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, query=query
+ )
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """
+ Perform a similarity search in the Neo4j database using a
+ given vector and return the top k similar documents with their scores.
+
+ This method uses a Cypher query to find the top k documents that
+ are most similar to a given embedding. The similarity is measured
+ using a vector index in the Neo4j database. The results are returned
+ as a list of tuples, each containing a Document object and
+ its similarity score.
+
+ Args:
+ embedding (List[float]): The embedding vector to compare against.
+ k (int, optional): The number of top similar documents to retrieve.
+
+ Returns:
+ List[Tuple[Document, float]]: A list of tuples, each containing
+ a Document object and its similarity score.
+ """
+ default_retrieval = (
+ f"RETURN node.`{self.text_node_property}` AS text, score, "
+ f"node {{.*, `{self.text_node_property}`: Null, "
+ f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
+ )
+
+ retrieval_query = (
+ self.retrieval_query if self.retrieval_query else default_retrieval
+ )
+
+ read_query = _get_search_index_query(self.search_type) + retrieval_query
+ parameters = {
+ "index": self.index_name,
+ "k": k,
+ "embedding": embedding,
+ "keyword_index": self.keyword_index_name,
+ "query": kwargs["query"],
+ }
+
+ results = self.query(read_query, params=parameters)
+
+ docs = [
+ (
+ Document(
+ page_content=result["text"],
+ metadata={
+ k: v for k, v in result["metadata"].items() if v is not None
+ },
+ ),
+ result["score"],
+ )
+ for result in results
+ ]
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls: Type[Neo4jVector],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ """
+ Return Neo4jVector initialized from texts and embeddings.
+ Neo4j credentials are required in the form of `url`, `username`,
+ and `password` and optional `database` parameters.
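+
+ Example (a minimal sketch; the connection details are placeholders):
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ store = Neo4jVector.from_texts(
+ texts=["hello world"],
+ embedding=OpenAIEmbeddings(),
+ url="bolt://localhost:7687",
+ username="neo4j",
+ password="password",
+ )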
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ distance_strategy=distance_strategy,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ """Construct Neo4jVector wrapper from raw documents and pre-
+ generated embeddings.
+
+ Return Neo4jVector initialized from documents and embeddings.
+ Neo4j credentials are required in the form of `url`, `username`,
+ and `password` and optional `database` parameters.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores.neo4j_vector import Neo4jVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ vectorstore = Neo4jVector.from_embeddings(
+ text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[Neo4jVector],
+ embedding: Embeddings,
+ index_name: str,
+ search_type: SearchType = DEFAULT_SEARCH_TYPE,
+ keyword_index_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ """
+ Get instance of an existing Neo4j vector index. This method will
+ return the instance of the store without inserting any new
+ embeddings.
+ Neo4j credentials are required in the form of `url`, `username`,
+ and `password` and optional `database` parameters along with
+ the `index_name` definition.
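+
+ Example (a minimal sketch; the index name and connection details are
+ placeholders):
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ store = Neo4jVector.from_existing_index(
+ embedding=OpenAIEmbeddings(),
+ url="bolt://localhost:7687",
+ username="neo4j",
+ password="password",
+ index_name="vector",
+ )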
+ """
+
+ if search_type == SearchType.HYBRID and not keyword_index_name:
+ raise ValueError(
+ "keyword_index name has to be specified "
+ "when using hybrid search option"
+ )
+
+ store = cls(
+ embedding=embedding,
+ index_name=index_name,
+ keyword_index_name=keyword_index_name,
+ search_type=search_type,
+ **kwargs,
+ )
+
+ embedding_dimension = store.retrieve_existing_index()
+
+ if not embedding_dimension:
+ raise ValueError(
+ "The specified vector index name does not exist. "
+ "Make sure to check if you spelled it correctly"
+ )
+
+ # Check if embedding function and vector index dimensions match
+ if not store.embedding_dimension == embedding_dimension:
+ raise ValueError(
+ "The provided embedding function and vector index "
+ "dimensions do not match.\n"
+ f"Embedding function dimension: {store.embedding_dimension}\n"
+ f"Vector index dimension: {embedding_dimension}"
+ )
+
+ if search_type == SearchType.HYBRID:
+ fts_node_label = store.retrieve_existing_fts_index()
+ # If the FTS index doesn't exist yet
+ if not fts_node_label:
+ raise ValueError(
+ "The specified keyword index name does not exist. "
+ "Make sure to check if you spelled it correctly"
+ )
+ else: # Validate that FTS and Vector index use the same information
+ if not fts_node_label == store.node_label:
+ raise ValueError(
+ "Vector and keyword index don't index the same node label"
+ )
+
+ return store
+
+ @classmethod
+ def from_documents(
+ cls: Type[Neo4jVector],
+ documents: List[Document],
+ embedding: Embeddings,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ """
+ Return Neo4jVector initialized from documents and embeddings.
+ Neo4j credentials are required in the form of `url`, `username`,
+ and `password` and optional `database` parameters.
+ """
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_graph(
+ cls: Type[Neo4jVector],
+ embedding: Embeddings,
+ node_label: str,
+ embedding_node_property: str,
+ text_node_properties: List[str],
+ *,
+ keyword_index_name: Optional[str] = "keyword",
+ index_name: str = "vector",
+ search_type: SearchType = DEFAULT_SEARCH_TYPE,
+ retrieval_query: str = "",
+ **kwargs: Any,
+ ) -> Neo4jVector:
+ """
+ Initialize and return a Neo4jVector instance from an existing graph.
+
+ This method initializes a Neo4jVector instance using the provided
+ parameters and the existing graph. It validates the existence of
+ the indices and creates new ones if they don't exist.
+
+ Returns:
+ Neo4jVector: An instance of Neo4jVector initialized with the provided parameters
+ and existing graph.
+
+ Example:
+ >>> neo4j_vector = Neo4jVector.from_existing_graph(
+ ... embedding=my_embedding,
+ ... node_label="Document",
+ ... embedding_node_property="embedding",
+ ... text_node_properties=["title", "content"]
+ ... )
+
+ Note:
+ Neo4j credentials are required in the form of `url`, `username`, and `password`,
+ and optional `database` parameters passed as additional keyword arguments.
+ """
+ # Validate the list is not empty
+ if not text_node_properties:
+ raise ValueError(
+ "Parameter `text_node_properties` must not be an empty list"
+ )
+ # Prefer retrieval query from params, otherwise construct it
+ if not retrieval_query:
+ retrieval_query = (
+ f"RETURN reduce(str='', k IN {text_node_properties} |"
+ " str + '\\n' + k + ': ' + coalesce(node[k], '')) AS text, "
+ "node {.*, `"
+ + embedding_node_property
+ + "`: Null, id: Null, "
+ + ", ".join([f"`{prop}`: Null" for prop in text_node_properties])
+ + "} AS metadata, score"
+ )
+ store = cls(
+ embedding=embedding,
+ index_name=index_name,
+ keyword_index_name=keyword_index_name,
+ search_type=search_type,
+ retrieval_query=retrieval_query,
+ node_label=node_label,
+ embedding_node_property=embedding_node_property,
+ **kwargs,
+ )
+
+ # Check if the vector index already exists
+ embedding_dimension = store.retrieve_existing_index()
+
+ # If the vector index doesn't exist yet
+ if not embedding_dimension:
+ store.create_new_index()
+ # If the index already exists, check if embedding dimensions match
+ elif not store.embedding_dimension == embedding_dimension:
+ raise ValueError(
+ f"Index with name {store.index_name} already exists."
+ "The provided embedding function and vector index "
+ "dimensions do not match.\n"
+ f"Embedding function dimension: {store.embedding_dimension}\n"
+ f"Vector index dimension: {embedding_dimension}"
+ )
+ # FTS index for Hybrid search
+ if search_type == SearchType.HYBRID:
+ fts_node_label = store.retrieve_existing_fts_index(text_node_properties)
+ # If the FTS index doesn't exist yet
+ if not fts_node_label:
+ store.create_new_keyword_index(text_node_properties)
+ else: # Validate that FTS and Vector index use the same information
+ if not fts_node_label == store.node_label:
+ raise ValueError(
+ "Vector and keyword index don't index the same node label"
+ )
+
+ # Populate embeddings
+ while True:
+ fetch_query = (
+ f"MATCH (n:`{node_label}`) "
+ f"WHERE n.{embedding_node_property} IS null "
+ "AND any(k in $props WHERE n[k] IS NOT null) "
+ f"RETURN elementId(n) AS id, reduce(str='',"
+ "k IN $props | str + '\\n' + k + ':' + coalesce(n[k], '')) AS text "
+ "LIMIT 1000"
+ )
+ data = store.query(fetch_query, params={"props": text_node_properties})
+ text_embeddings = embedding.embed_documents([el["text"] for el in data])
+
+ params = {
+ "data": [
+ {"id": el["id"], "embedding": embedding}
+ for el, embedding in zip(data, text_embeddings)
+ ]
+ }
+
+ store.query(
+ "UNWIND $data AS row "
+ f"MATCH (n:`{node_label}`) "
+ "WHERE elementId(n) = row.id "
+ f"CALL db.create.setVectorProperty(n, "
+ f"'{embedding_node_property}', row.embedding) "
+ "YIELD node RETURN count(*)",
+ params=params,
+ )
+ # Stop once all nodes have been processed (fewer than 1000 rows returned)
+ if len(data) < 1000:
+ break
+ return store
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided
+ # in vectorstore constructor
+ if self._distance_strategy == DistanceStrategy.COSINE:
+ return lambda x: x
+ elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ return lambda x: x
+ else:
+ raise ValueError(
+ "No supported normalization function"
+ f" for distance_strategy of {self._distance_strategy}."
+ "Consider providing relevance_score_fn to PGVector constructor."
+ )
diff --git a/libs/community/langchain_community/vectorstores/nucliadb.py b/libs/community/langchain_community/vectorstores/nucliadb.py
new file mode 100644
index 00000000000..d20e51bbc7c
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/nucliadb.py
@@ -0,0 +1,159 @@
+import os
+from typing import Any, Dict, Iterable, List, Optional, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VST, VectorStore
+
+FIELD_TYPES = {
+ "f": "files",
+ "t": "texts",
+ "l": "links",
+}
+
+
+class NucliaDB(VectorStore):
+ """NucliaDB vector store."""
+
+ _config: Dict[str, Any] = {}
+
+ def __init__(
+ self,
+ knowledge_box: str,
+ local: bool,
+ api_key: Optional[str] = None,
+ backend: Optional[str] = None,
+ ) -> None:
+ """Initialize the NucliaDB client.
+
+ Args:
+ knowledge_box: the Knowledge Box id.
+ local: Whether to use a local NucliaDB instance or Nuclia Cloud
+ api_key: A contributor API key for the kb (needed when local is False)
+ backend: The backend url to use when local is True, defaults to
+ http://localhost:8080
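+
+ Example (illustrative; assumes a local NucliaDB instance on the default
+ port and a placeholder Knowledge Box id):
+ .. code-block:: python
+
+ from langchain_community.vectorstores.nucliadb import NucliaDB
+
+ ndb = NucliaDB(knowledge_box="my-kb-id", local=True)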
+ """
+ try:
+ from nuclia.sdk import NucliaAuth
+ except ImportError:
+ raise ValueError(
+ "nuclia python package not found. "
+ "Please install it with `pip install nuclia`."
+ )
+ self._config["LOCAL"] = local
+ zone = os.environ.get("NUCLIA_ZONE", "europe-1")
+ self._kb = knowledge_box
+ if local:
+ if not backend:
+ backend = "http://localhost:8080"
+ self._config["BACKEND"] = f"{backend}/api/v1"
+ self._config["TOKEN"] = None
+ NucliaAuth().nucliadb(url=backend)
+ NucliaAuth().kb(url=self.kb_url, interactive=False)
+ else:
+ self._config["BACKEND"] = f"https://{zone}.nuclia.cloud/api/v1"
+ self._config["TOKEN"] = api_key
+ NucliaAuth().kb(
+ url=self.kb_url, token=self._config["TOKEN"], interactive=False
+ )
+
+ @property
+ def is_local(self) -> bool:
+ return self._config["LOCAL"]
+
+ @property
+ def kb_url(self) -> str:
+ return f"{self._config['BACKEND']}/kb/{self._kb}"
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Upload texts to NucliaDB"""
+ ids = []
+ from nuclia.sdk import NucliaResource
+
+ factory = NucliaResource()
+ for i, text in enumerate(texts):
+ extra: Dict[str, Any] = {"metadata": ""}
+ if metadatas:
+ extra = {"metadata": metadatas[i]}
+ id = factory.create(
+ texts={"text": {"body": text}},
+ extra=extra,
+ url=self.kb_url,
+ api_key=self._config["TOKEN"],
+ )
+ ids.append(id)
+ return ids
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ if not ids:
+ return None
+ from nuclia.sdk import NucliaResource
+
+ factory = NucliaResource()
+ results: List[bool] = []
+ for id in ids:
+ try:
+ factory.delete(rid=id, url=self.kb_url, api_key=self._config["TOKEN"])
+ results.append(True)
+ except ValueError:
+ results.append(False)
+ return all(results)
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ from nuclia.sdk import NucliaSearch
+ from nucliadb_models.search import FindRequest, ResourceProperties
+
+ request = FindRequest(
+ query=query,
+ page_size=k,
+ show=[ResourceProperties.VALUES, ResourceProperties.EXTRA],
+ )
+ search = NucliaSearch()
+ results = search.find(
+ query=request, url=self.kb_url, api_key=self._config["TOKEN"]
+ )
+ paragraphs = []
+ for resource in results.resources.values():
+ for field in resource.fields.values():
+ for paragraph_id, paragraph in field.paragraphs.items():
+ info = paragraph_id.split("/")
+ field_type = FIELD_TYPES.get(info[1], None)
+ field_id = info[2]
+ if not field_type:
+ continue
+ value = getattr(resource.data, field_type, {}).get(field_id, None)
+ paragraphs.append(
+ {
+ "text": paragraph.text,
+ "metadata": {
+ "extra": getattr(
+ getattr(resource, "extra", {}), "metadata", None
+ ),
+ "value": value,
+ },
+ "order": paragraph.order,
+ }
+ )
+ sorted_paragraphs = sorted(paragraphs, key=lambda x: x["order"])
+ return [
+ Document(page_content=paragraph["text"], metadata=paragraph["metadata"])
+ for paragraph in sorted_paragraphs
+ ]
+
+ @classmethod
+ def from_texts(
+ cls: Type[VST],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> VST:
+ """Return VectorStore initialized from texts and embeddings."""
+ raise NotImplementedError
diff --git a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py
new file mode 100644
index 00000000000..99ece00e7a6
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py
@@ -0,0 +1,915 @@
+from __future__ import annotations
+
+import uuid
+import warnings
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+IMPORT_OPENSEARCH_PY_ERROR = (
+ "Could not import OpenSearch. Please install it with `pip install opensearch-py`."
+)
+SCRIPT_SCORING_SEARCH = "script_scoring"
+PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
+MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
+
+
+def _import_opensearch() -> Any:
+ """Import OpenSearch if available, otherwise raise error."""
+ try:
+ from opensearchpy import OpenSearch
+ except ImportError:
+ raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
+ return OpenSearch
+
+
+def _import_bulk() -> Any:
+ """Import bulk if available, otherwise raise error."""
+ try:
+ from opensearchpy.helpers import bulk
+ except ImportError:
+ raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
+ return bulk
+
+
+def _import_not_found_error() -> Any:
+ """Import not found error if available, otherwise raise error."""
+ try:
+ from opensearchpy.exceptions import NotFoundError
+ except ImportError:
+ raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
+ return NotFoundError
+
+
+def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
+ """Get OpenSearch client from the opensearch_url, otherwise raise error."""
+ try:
+ opensearch = _import_opensearch()
+ client = opensearch(opensearch_url, **kwargs)
+ except ValueError as e:
+ raise ImportError(
+ f"OpenSearch client string provided is not in proper format. "
+ f"Got error: {e} "
+ )
+ return client
+
+
+def _validate_embeddings_and_bulk_size(embeddings_length: int, bulk_size: int) -> None:
+ """Validate Embeddings Length and Bulk Size."""
+ if embeddings_length == 0:
+ raise RuntimeError("Embeddings size is zero")
+ if bulk_size < embeddings_length:
+ raise RuntimeError(
+ f"The embeddings count, {embeddings_length} is more than the "
+ f"[bulk_size], {bulk_size}. Increase the value of [bulk_size]."
+ )
+
+
+def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None:
+ """Validate AOSS with the engine."""
+ if is_aoss and engine != "nmslib" and engine != "faiss":
+ raise ValueError(
+ "Amazon OpenSearch Service Serverless only "
+ "supports `nmslib` or `faiss` engines"
+ )
+
+
+def _is_aoss_enabled(http_auth: Any) -> bool:
+ """Check if the service is http_auth is set as `aoss`."""
+ if (
+ http_auth is not None
+ and hasattr(http_auth, "service")
+ and http_auth.service == "aoss"
+ ):
+ return True
+ return False
+
+
+def _bulk_ingest_embeddings(
+ client: Any,
+ index_name: str,
+ embeddings: List[List[float]],
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ vector_field: str = "vector_field",
+ text_field: str = "text",
+ mapping: Optional[Dict] = None,
+ max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
+ is_aoss: bool = False,
+) -> List[str]:
+ """Bulk Ingest Embeddings into given index."""
+ if not mapping:
+ mapping = dict()
+
+ bulk = _import_bulk()
+ not_found_error = _import_not_found_error()
+ requests = []
+ return_ids = []
+
+ try:
+ client.indices.get(index=index_name)
+ except not_found_error:
+ client.indices.create(index=index_name, body=mapping)
+
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ _id = ids[i] if ids else str(uuid.uuid4())
+ request = {
+ "_op_type": "index",
+ "_index": index_name,
+ vector_field: embeddings[i],
+ text_field: text,
+ "metadata": metadata,
+ }
+ if is_aoss:
+ request["id"] = _id
+ else:
+ request["_id"] = _id
+ requests.append(request)
+ return_ids.append(_id)
+ bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
+ if not is_aoss:
+ client.indices.refresh(index=index_name)
+ return return_ids
+
+
+def _default_scripting_text_mapping(
+ dim: int,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Painless Scripting or Script Scoring,the default mapping to create index."""
+ return {
+ "mappings": {
+ "properties": {
+ vector_field: {"type": "knn_vector", "dimension": dim},
+ }
+ }
+ }
+
+
+def _default_text_mapping(
+ dim: int,
+ engine: str = "nmslib",
+ space_type: str = "l2",
+ ef_search: int = 512,
+ ef_construction: int = 512,
+ m: int = 16,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Approximate k-NN Search, this is the default mapping to create index."""
+ return {
+ "settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
+ "mappings": {
+ "properties": {
+ vector_field: {
+ "type": "knn_vector",
+ "dimension": dim,
+ "method": {
+ "name": "hnsw",
+ "space_type": space_type,
+ "engine": engine,
+ "parameters": {"ef_construction": ef_construction, "m": m},
+ },
+ }
+ }
+ },
+ }
+
+
+def _default_approximate_search_query(
+ query_vector: List[float],
+ k: int = 4,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Approximate k-NN Search, this is the default query."""
+ return {
+ "size": k,
+ "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
+ }
+
+
+def _approximate_search_query_with_boolean_filter(
+ query_vector: List[float],
+ boolean_filter: Dict,
+ k: int = 4,
+ vector_field: str = "vector_field",
+ subquery_clause: str = "must",
+) -> Dict:
+ """For Approximate k-NN Search, with Boolean Filter."""
+ return {
+ "size": k,
+ "query": {
+ "bool": {
+ "filter": boolean_filter,
+ subquery_clause: [
+ {"knn": {vector_field: {"vector": query_vector, "k": k}}}
+ ],
+ }
+ },
+ }
+
+
+def _approximate_search_query_with_efficient_filter(
+ query_vector: List[float],
+ efficient_filter: Dict,
+ k: int = 4,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Approximate k-NN Search, with Efficient Filter for Lucene and
+ Faiss Engines."""
+ search_query = _default_approximate_search_query(
+ query_vector, k=k, vector_field=vector_field
+ )
+ search_query["query"]["knn"][vector_field]["filter"] = efficient_filter
+ return search_query
+
+
+def _default_script_query(
+ query_vector: List[float],
+ k: int = 4,
+ space_type: str = "l2",
+ pre_filter: Optional[Dict] = None,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Script Scoring Search, this is the default query."""
+
+ if not pre_filter:
+ pre_filter = MATCH_ALL_QUERY
+
+ return {
+ "size": k,
+ "query": {
+ "script_score": {
+ "query": pre_filter,
+ "script": {
+ "source": "knn_score",
+ "lang": "knn",
+ "params": {
+ "field": vector_field,
+ "query_value": query_vector,
+ "space_type": space_type,
+ },
+ },
+ }
+ },
+ }
+
+
+def __get_painless_scripting_source(
+ space_type: str, vector_field: str = "vector_field"
+) -> str:
+ """For Painless Scripting, it returns the script source based on space type."""
+ source_value = (
+ "(1.0 + " + space_type + "(params.query_value, doc['" + vector_field + "']))"
+ )
+ if space_type == "cosineSimilarity":
+ return source_value
+ else:
+ return "1/" + source_value
+
+
+def _default_painless_scripting_query(
+ query_vector: List[float],
+ k: int = 4,
+ space_type: str = "l2Squared",
+ pre_filter: Optional[Dict] = None,
+ vector_field: str = "vector_field",
+) -> Dict:
+ """For Painless Scripting Search, this is the default query."""
+
+ if not pre_filter:
+ pre_filter = MATCH_ALL_QUERY
+
+ source = __get_painless_scripting_source(space_type, vector_field=vector_field)
+ return {
+ "size": k,
+ "query": {
+ "script_score": {
+ "query": pre_filter,
+ "script": {
+ "source": source,
+ "params": {
+ "field": vector_field,
+ "query_value": query_vector,
+ },
+ },
+ }
+ },
+ }
+
+
+class OpenSearchVectorSearch(VectorStore):
+ """`Amazon OpenSearch Vector Engine` vector store.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import OpenSearchVectorSearch
+ opensearch_vector_search = OpenSearchVectorSearch(
+ "http://localhost:9200",
+ "embeddings",
+ embedding_function
+ )
+
+ """
+
+ def __init__(
+ self,
+ opensearch_url: str,
+ index_name: str,
+ embedding_function: Embeddings,
+ **kwargs: Any,
+ ):
+ """Initialize with necessary components."""
+ self.embedding_function = embedding_function
+ self.index_name = index_name
+ http_auth = kwargs.get("http_auth")
+ self.is_aoss = _is_aoss_enabled(http_auth=http_auth)
+ self.client = _get_opensearch_client(opensearch_url, **kwargs)
+ self.engine = kwargs.get("engine")
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def __add(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ bulk_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+ _validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
+ index_name = kwargs.get("index_name", self.index_name)
+ text_field = kwargs.get("text_field", "text")
+ dim = len(embeddings[0])
+ engine = kwargs.get("engine", "nmslib")
+ space_type = kwargs.get("space_type", "l2")
+ ef_search = kwargs.get("ef_search", 512)
+ ef_construction = kwargs.get("ef_construction", 512)
+ m = kwargs.get("m", 16)
+ vector_field = kwargs.get("vector_field", "vector_field")
+ max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024)
+
+ _validate_aoss_with_engines(self.is_aoss, engine)
+
+ mapping = _default_text_mapping(
+ dim, engine, space_type, ef_search, ef_construction, m, vector_field
+ )
+
+ return _bulk_ingest_embeddings(
+ self.client,
+ index_name,
+ embeddings,
+ texts,
+ metadatas=metadatas,
+ ids=ids,
+ vector_field=vector_field,
+ text_field=text_field,
+ mapping=mapping,
+ max_chunk_bytes=max_chunk_bytes,
+ is_aoss=self.is_aoss,
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ bulk_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+ bulk_size: Bulk API request count; Default: 500
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
+ Optional Args:
+ vector_field: Document field embeddings are stored in. Defaults to
+ "vector_field".
+
+ text_field: Document field the text of the document is stored in. Defaults
+ to "text".
+ """
+ embeddings = self.embedding_function.embed_documents(list(texts))
+ return self.__add(
+ texts,
+ embeddings,
+ metadatas=metadatas,
+ ids=ids,
+ bulk_size=bulk_size,
+ **kwargs,
+ )
+
+ def add_embeddings(
+ self,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ bulk_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add the given texts and embeddings to the vectorstore.
+
+ Args:
+ text_embeddings: Iterable pairs of string and embedding to
+ add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+ bulk_size: Bulk API request count; Default: 500
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
+ Optional Args:
+ vector_field: Document field embeddings are stored in. Defaults to
+ "vector_field".
+
+ text_field: Document field the text of the document is stored in. Defaults
+ to "text".
+ """
+ texts, embeddings = zip(*text_embeddings)
+ return self.__add(
+ list(texts),
+ list(embeddings),
+ metadatas=metadatas,
+ ids=ids,
+ bulk_size=bulk_size,
+ **kwargs,
+ )
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ By default, supports Approximate Search.
+ Also supports Script Scoring and Painless Scripting.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+
+ Optional Args:
+ vector_field: Document field embeddings are stored in. Defaults to
+ "vector_field".
+
+ text_field: Document field the text of the document is stored in. Defaults
+ to "text".
+
+ metadata_field: Document field that metadata is stored in. Defaults to
+ "metadata".
+ Can be set to a special value "*" to include the entire document.
+
+ Optional Args for Approximate Search:
+ search_type: "approximate_search"; default: "approximate_search"
+
+            boolean_filter: A Boolean filter is a post filter that consists of a
+                Boolean query containing a k-NN query and a filter.
+
+ subquery_clause: Query clause on the knn vector field; default: "must"
+
+            lucene_filter: the Lucene Engine decides whether to perform an exact
+                k-NN search with pre-filtering or an approximate search with
+                modified post-filtering. (deprecated; use `efficient_filter` instead)
+
+ efficient_filter: the Lucene Engine or Faiss Engine decides whether to
+ perform an exact k-NN search with pre-filtering or an approximate search
+ with modified post-filtering.
+
+ Optional Args for Script Scoring Search:
+ search_type: "script_scoring"; default: "approximate_search"
+
+ space_type: "l2", "l1", "linf", "cosinesimil", "innerproduct",
+ "hammingbit"; default: "l2"
+
+ pre_filter: script_score query to pre-filter documents before identifying
+ nearest neighbors; default: {"match_all": {}}
+
+ Optional Args for Painless Scripting Search:
+ search_type: "painless_scripting"; default: "approximate_search"
+
+ space_type: "l2Squared", "l1Norm", "cosineSimilarity"; default: "l2Squared"
+
+ pre_filter: script_score query to pre-filter documents before identifying
+ nearest neighbors; default: {"match_all": {}}
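+
+        Example (a minimal sketch; the query text, filter, and space type shown
+        are assumptions for illustration):
+            .. code-block:: python
+
+                # default approximate k-NN search
+                docs = opensearch_vector_search.similarity_search("query", k=4)
+
+                # brute-force search via Script Scoring with a pre-filter
+                docs = opensearch_vector_search.similarity_search(
+                    "query",
+                    k=4,
+                    search_type="script_scoring",
+                    space_type="cosinesimil",
+                    pre_filter={"bool": {"filter": {"term": {"category": "news"}}}},
+                )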
+ """
+ docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
+ return [doc[0] for doc in docs_with_scores]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and it's scores most similar to query.
+
+ By default, supports Approximate Search.
+ Also supports Script Scoring and Painless Scripting.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+            List of Documents along with their scores most similar to the query.
+
+ Optional Args:
+ same as `similarity_search`
+ """
+
+ text_field = kwargs.get("text_field", "text")
+ metadata_field = kwargs.get("metadata_field", "metadata")
+
+ hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
+
+ documents_with_scores = [
+ (
+ Document(
+ page_content=hit["_source"][text_field],
+ metadata=hit["_source"]
+ if metadata_field == "*" or metadata_field not in hit["_source"]
+ else hit["_source"][metadata_field],
+ ),
+ hit["_score"],
+ )
+ for hit in hits
+ ]
+ return documents_with_scores
+
+ def _raw_similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[dict]:
+ """Return raw opensearch documents (dict) including vectors,
+ scores most similar to query.
+
+ By default, supports Approximate Search.
+ Also supports Script Scoring and Painless Scripting.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+            List of dicts, with scores, most similar to the query.
+
+ Optional Args:
+ same as `similarity_search`
+ """
+ embedding = self.embedding_function.embed_query(query)
+ search_type = kwargs.get("search_type", "approximate_search")
+ vector_field = kwargs.get("vector_field", "vector_field")
+ index_name = kwargs.get("index_name", self.index_name)
+ filter = kwargs.get("filter", {})
+
+ if (
+ self.is_aoss
+ and search_type != "approximate_search"
+ and search_type != SCRIPT_SCORING_SEARCH
+ ):
+ raise ValueError(
+ "Amazon OpenSearch Service Serverless only "
+ "supports `approximate_search` and `script_scoring`"
+ )
+
+ if search_type == "approximate_search":
+ boolean_filter = kwargs.get("boolean_filter", {})
+ subquery_clause = kwargs.get("subquery_clause", "must")
+ efficient_filter = kwargs.get("efficient_filter", {})
+ # `lucene_filter` is deprecated, added for Backwards Compatibility
+ lucene_filter = kwargs.get("lucene_filter", {})
+
+ if boolean_filter != {} and efficient_filter != {}:
+ raise ValueError(
+ "Both `boolean_filter` and `efficient_filter` are provided which "
+ "is invalid"
+ )
+
+ if lucene_filter != {} and efficient_filter != {}:
+ raise ValueError(
+ "Both `lucene_filter` and `efficient_filter` are provided which "
+ "is invalid. `lucene_filter` is deprecated"
+ )
+
+ if lucene_filter != {} and boolean_filter != {}:
+ raise ValueError(
+ "Both `lucene_filter` and `boolean_filter` are provided which "
+ "is invalid. `lucene_filter` is deprecated"
+ )
+
+ if (
+ efficient_filter == {}
+ and boolean_filter == {}
+ and lucene_filter == {}
+ and filter != {}
+ ):
+ if self.engine in ["faiss", "lucene"]:
+ efficient_filter = filter
+ else:
+ boolean_filter = filter
+
+ if boolean_filter != {}:
+ search_query = _approximate_search_query_with_boolean_filter(
+ embedding,
+ boolean_filter,
+ k=k,
+ vector_field=vector_field,
+ subquery_clause=subquery_clause,
+ )
+ elif efficient_filter != {}:
+ search_query = _approximate_search_query_with_efficient_filter(
+ embedding, efficient_filter, k=k, vector_field=vector_field
+ )
+ elif lucene_filter != {}:
+ warnings.warn(
+ "`lucene_filter` is deprecated. Please use the keyword argument"
+ " `efficient_filter`"
+ )
+ search_query = _approximate_search_query_with_efficient_filter(
+ embedding, lucene_filter, k=k, vector_field=vector_field
+ )
+ else:
+ search_query = _default_approximate_search_query(
+ embedding, k=k, vector_field=vector_field
+ )
+ elif search_type == SCRIPT_SCORING_SEARCH:
+ space_type = kwargs.get("space_type", "l2")
+ pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY)
+ search_query = _default_script_query(
+ embedding, k, space_type, pre_filter, vector_field
+ )
+ elif search_type == PAINLESS_SCRIPTING_SEARCH:
+ space_type = kwargs.get("space_type", "l2Squared")
+ pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY)
+ search_query = _default_painless_scripting_query(
+ embedding, k, space_type, pre_filter, vector_field
+ )
+ else:
+ raise ValueError("Invalid `search_type` provided as an argument")
+
+ response = self.client.search(index=index_name, body=search_query)
+
+ return [hit for hit in response["hits"]["hits"]]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+    ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
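+
+        Example (a minimal sketch; the query and parameter values shown are
+        assumptions for illustration):
+            .. code-block:: python
+
+                docs = opensearch_vector_search.max_marginal_relevance_search(
+                    "query", k=4, fetch_k=20, lambda_mult=0.5
+                )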
+ """
+
+ vector_field = kwargs.get("vector_field", "vector_field")
+ text_field = kwargs.get("text_field", "text")
+ metadata_field = kwargs.get("metadata_field", "metadata")
+
+ # Get embedding of the user query
+ embedding = self.embedding_function.embed_query(query)
+
+ # Do ANN/KNN search to get top fetch_k results where fetch_k >= k
+ results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
+
+ embeddings = [result["_source"][vector_field] for result in results]
+
+ # Rerank top k results using MMR, (mmr_selected is a list of indices)
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+ )
+
+ return [
+ Document(
+ page_content=results[i]["_source"][text_field],
+ metadata=results[i]["_source"][metadata_field],
+ )
+ for i in mmr_selected
+ ]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ bulk_size: int = 500,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> OpenSearchVectorSearch:
+ """Construct OpenSearchVectorSearch wrapper from raw texts.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import OpenSearchVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ opensearch_vector_search = OpenSearchVectorSearch.from_texts(
+ texts,
+ embeddings,
+ opensearch_url="http://localhost:9200"
+ )
+
+        By default, OpenSearch supports Approximate Search powered by the nmslib,
+        faiss, and lucene engines, which is recommended for large datasets. It also
+        supports brute-force search through Script Scoring and Painless Scripting.
+
+ Optional Args:
+ vector_field: Document field embeddings are stored in. Defaults to
+ "vector_field".
+
+ text_field: Document field the text of the document is stored in. Defaults
+ to "text".
+
+ Optional Keyword Args for Approximate Search:
+ engine: "nmslib", "faiss", "lucene"; default: "nmslib"
+
+ space_type: "l2", "l1", "cosinesimil", "linf", "innerproduct"; default: "l2"
+
+ ef_search: Size of the dynamic list used during k-NN searches. Higher values
+ lead to more accurate but slower searches; default: 512
+
+ ef_construction: Size of the dynamic list used during k-NN graph creation.
+ Higher values lead to more accurate graph but slower indexing speed;
+ default: 512
+
+ m: Number of bidirectional links created for each new element. Large impact
+ on memory consumption. Between 2 and 100; default: 16
+
+ Keyword Args for Script Scoring or Painless Scripting:
+ is_appx_search: False
+
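+        Example with explicit Approximate Search settings (a minimal sketch; the
+        engine choice and tuning values shown are assumptions for illustration):
+            .. code-block:: python
+
+                opensearch_vector_search = OpenSearchVectorSearch.from_texts(
+                    texts,
+                    embeddings,
+                    opensearch_url="http://localhost:9200",
+                    engine="faiss",
+                    space_type="innerproduct",
+                    ef_construction=256,
+                    m=48,
+                )
+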
+ """
+ embeddings = embedding.embed_documents(texts)
+ return cls.from_embeddings(
+ embeddings,
+ texts,
+ embedding,
+ metadatas=metadatas,
+ bulk_size=bulk_size,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ embeddings: List[List[float]],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ bulk_size: int = 500,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> OpenSearchVectorSearch:
+ """Construct OpenSearchVectorSearch wrapper from pre-vectorized embeddings.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import OpenSearchVectorSearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embedder = OpenAIEmbeddings()
+ embeddings = embedder.embed_documents(["foo", "bar"])
+ opensearch_vector_search = OpenSearchVectorSearch.from_embeddings(
+ embeddings,
+ texts,
+ embedder,
+ opensearch_url="http://localhost:9200"
+ )
+
+        By default, OpenSearch supports Approximate Search powered by the nmslib,
+        faiss, and lucene engines, which is recommended for large datasets. It also
+        supports brute-force search through Script Scoring and Painless Scripting.
+
+ Optional Args:
+ vector_field: Document field embeddings are stored in. Defaults to
+ "vector_field".
+
+ text_field: Document field the text of the document is stored in. Defaults
+ to "text".
+
+ Optional Keyword Args for Approximate Search:
+ engine: "nmslib", "faiss", "lucene"; default: "nmslib"
+
+ space_type: "l2", "l1", "cosinesimil", "linf", "innerproduct"; default: "l2"
+
+ ef_search: Size of the dynamic list used during k-NN searches. Higher values
+ lead to more accurate but slower searches; default: 512
+
+ ef_construction: Size of the dynamic list used during k-NN graph creation.
+ Higher values lead to more accurate graph but slower indexing speed;
+ default: 512
+
+ m: Number of bidirectional links created for each new element. Large impact
+ on memory consumption. Between 2 and 100; default: 16
+
+ Keyword Args for Script Scoring or Painless Scripting:
+ is_appx_search: False
+
+ """
+ opensearch_url = get_from_dict_or_env(
+ kwargs, "opensearch_url", "OPENSEARCH_URL"
+ )
+        # List of arguments that need to be removed from kwargs
+        # before passing kwargs to the OpenSearch client
+ keys_list = [
+ "opensearch_url",
+ "index_name",
+ "is_appx_search",
+ "vector_field",
+ "text_field",
+ "engine",
+ "space_type",
+ "ef_search",
+ "ef_construction",
+ "m",
+ "max_chunk_bytes",
+ "is_aoss",
+ ]
+ _validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
+ dim = len(embeddings[0])
+        # Get the index name from either kwargs or the environment variable
+        # before falling back to a randomly generated name
+ index_name = get_from_dict_or_env(
+ kwargs, "index_name", "OPENSEARCH_INDEX_NAME", default=uuid.uuid4().hex
+ )
+ is_appx_search = kwargs.get("is_appx_search", True)
+ vector_field = kwargs.get("vector_field", "vector_field")
+ text_field = kwargs.get("text_field", "text")
+ max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024)
+ http_auth = kwargs.get("http_auth")
+ is_aoss = _is_aoss_enabled(http_auth=http_auth)
+ engine = None
+
+ if is_aoss and not is_appx_search:
+ raise ValueError(
+ "Amazon OpenSearch Service Serverless only "
+ "supports `approximate_search`"
+ )
+
+ if is_appx_search:
+ engine = kwargs.get("engine", "nmslib")
+ space_type = kwargs.get("space_type", "l2")
+ ef_search = kwargs.get("ef_search", 512)
+ ef_construction = kwargs.get("ef_construction", 512)
+ m = kwargs.get("m", 16)
+
+ _validate_aoss_with_engines(is_aoss, engine)
+
+ mapping = _default_text_mapping(
+ dim, engine, space_type, ef_search, ef_construction, m, vector_field
+ )
+ else:
+ mapping = _default_scripting_text_mapping(dim)
+
+ [kwargs.pop(key, None) for key in keys_list]
+ client = _get_opensearch_client(opensearch_url, **kwargs)
+ _bulk_ingest_embeddings(
+ client,
+ index_name,
+ embeddings,
+ texts,
+ ids=ids,
+ metadatas=metadatas,
+ vector_field=vector_field,
+ text_field=text_field,
+ mapping=mapping,
+ max_chunk_bytes=max_chunk_bytes,
+ is_aoss=is_aoss,
+ )
+ kwargs["engine"] = engine
+ return cls(opensearch_url, index_name, embedding, **kwargs)
diff --git a/libs/community/langchain_community/vectorstores/pgembedding.py b/libs/community/langchain_community/vectorstores/pgembedding.py
new file mode 100644
index 00000000000..d5c37d5942e
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/pgembedding.py
@@ -0,0 +1,531 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
+
+import sqlalchemy
+from sqlalchemy import func
+from sqlalchemy.dialects.postgresql import JSON, UUID
+from sqlalchemy.orm import Session, relationship
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+Base = declarative_base() # type: Any
+
+
+ADA_TOKEN_COUNT = 1536
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+class BaseModel(Base):
+ """Base model for all SQL stores."""
+
+ __abstract__ = True
+ uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+
+
+class CollectionStore(BaseModel):
+ """Collection store."""
+
+ __tablename__ = "langchain_pg_collection"
+
+ name = sqlalchemy.Column(sqlalchemy.String)
+ cmetadata = sqlalchemy.Column(JSON)
+
+ embeddings = relationship(
+ "EmbeddingStore",
+ back_populates="collection",
+ passive_deletes=True,
+ )
+
+ @classmethod
+ def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
+ return session.query(cls).filter(cls.name == name).first() # type: ignore
+
+ @classmethod
+ def get_or_create(
+ cls,
+ session: Session,
+ name: str,
+ cmetadata: Optional[dict] = None,
+ ) -> Tuple["CollectionStore", bool]:
+ """
+ Get or create a collection.
+ Returns [Collection, bool] where the bool is True if the collection was created.
+ """
+ created = False
+ collection = cls.get_by_name(session, name)
+ if collection:
+ return collection, created
+
+ collection = cls(name=name, cmetadata=cmetadata)
+ session.add(collection)
+ session.commit()
+ created = True
+ return collection, created
+
+
+class EmbeddingStore(BaseModel):
+ """Embedding store."""
+
+ __tablename__ = "langchain_pg_embedding"
+
+ collection_id = sqlalchemy.Column(
+ UUID(as_uuid=True),
+ sqlalchemy.ForeignKey(
+ f"{CollectionStore.__tablename__}.uuid",
+ ondelete="CASCADE",
+ ),
+ )
+ collection = relationship(CollectionStore, back_populates="embeddings")
+
+ embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+ cmetadata = sqlalchemy.Column(JSON, nullable=True)
+
+ # custom_id : any user defined id
+ custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+
+class QueryResult:
+ """Result from a query."""
+
+ EmbeddingStore: EmbeddingStore
+ distance: float
+
+
+class PGEmbedding(VectorStore):
+ """`Postgres` with the `pg_embedding` extension as a vector store.
+
+    pg_embedding uses sequential scan by default, but you can create an HNSW index
+    using the create_hnsw_index method.
+ - `connection_string` is a postgres connection string.
+ - `embedding_function` any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ - `collection_name` is the name of the collection to use. (default: langchain)
+    - NOTE: This is not the name of the table, but the name of the collection.
+      The tables will be created when initializing the store (if they do not
+      exist), so make sure the user has the right permissions to create tables.
+ - `distance_strategy` is the distance strategy to use. (default: EUCLIDEAN)
+ - `EUCLIDEAN` is the euclidean distance.
+ - `pre_delete_collection` if True, will delete the collection if it exists.
+ (default: False)
+ - Useful for testing.
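+
+    Example (a minimal sketch; the import path, connection string, and
+    collection name are assumptions for illustration):
+        .. code-block:: python
+
+            from langchain_community.vectorstores import PGEmbedding
+            from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+            CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost/db"
+            store = PGEmbedding(
+                connection_string=CONNECTION_STRING,
+                embedding_function=OpenAIEmbeddings(),
+                collection_name="my_collection",
+            )
+            store.create_hnsw_index(max_elements=10000, dims=1536, m=8)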
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ collection_metadata: Optional[dict] = None,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ ) -> None:
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self.collection_name = collection_name
+ self.collection_metadata = collection_metadata
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.__post_init__()
+
+ def __post_init__(
+ self,
+ ) -> None:
+ self._conn = self.connect()
+ self.create_hnsw_extension()
+ self.create_tables_if_not_exists()
+ self.create_collection()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def connect(self) -> sqlalchemy.engine.Connection:
+ engine = sqlalchemy.create_engine(self.connection_string)
+ conn = engine.connect()
+ return conn
+
+ def create_hnsw_extension(self) -> None:
+ try:
+ with Session(self._conn) as session:
+ statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS embedding")
+ session.execute(statement)
+ session.commit()
+ except Exception as e:
+ self.logger.exception(e)
+
+ def create_tables_if_not_exists(self) -> None:
+ with self._conn.begin():
+ Base.metadata.create_all(self._conn)
+
+ def drop_tables(self) -> None:
+ with self._conn.begin():
+ Base.metadata.drop_all(self._conn)
+
+ def create_collection(self) -> None:
+ if self.pre_delete_collection:
+ self.delete_collection()
+ with Session(self._conn) as session:
+ CollectionStore.get_or_create(
+ session, self.collection_name, cmetadata=self.collection_metadata
+ )
+
+ def create_hnsw_index(
+ self,
+ max_elements: int = 10000,
+ dims: int = ADA_TOKEN_COUNT,
+ m: int = 8,
+ ef_construction: int = 16,
+ ef_search: int = 16,
+ ) -> None:
+ create_index_query = sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS langchain_pg_embedding_idx "
+ "ON langchain_pg_embedding USING hnsw (embedding) "
+ "WITH ("
+ "maxelements = {}, "
+ "dims = {}, "
+ "m = {}, "
+ "efconstruction = {}, "
+ "efsearch = {}"
+ ");".format(max_elements, dims, m, ef_construction, ef_search)
+ )
+
+ # Execute the queries
+ try:
+ with Session(self._conn) as session:
+ # Create the HNSW index
+ session.execute(create_index_query)
+ session.commit()
+ print("HNSW extension and index created successfully.")
+ except Exception as e:
+ print(f"Failed to create HNSW extension or index: {e}")
+
+ def delete_collection(self) -> None:
+ self.logger.debug("Trying to delete collection")
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ self.logger.warning("Collection not found")
+ return
+ session.delete(collection)
+ session.commit()
+
+ def get_collection(self, session: Session) -> Optional["CollectionStore"]:
+ return CollectionStore.get_by_name(session, self.collection_name)
+
+ @classmethod
+ def _initialize_from_embeddings(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGEmbedding:
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ pre_delete_collection=pre_delete_collection,
+ )
+
+ store.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ def add_embeddings(
+ self,
+ texts: List[str],
+ embeddings: List[List[float]],
+ metadatas: List[dict],
+ ids: List[str],
+ **kwargs: Any,
+ ) -> None:
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ raise ValueError("Collection not found")
+ for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
+ embedding_store = EmbeddingStore(
+ embedding=embedding,
+ document=text,
+ cmetadata=metadata,
+ custom_id=id,
+ )
+ collection.embeddings.append(embedding_store)
+ session.add(embedding_store)
+ session.commit()
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ embeddings = self.embedding_function.embed_documents(list(texts))
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ raise ValueError("Collection not found")
+ for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
+ embedding_store = EmbeddingStore(
+ embedding=embedding,
+ document=text,
+ cmetadata=metadata,
+ custom_id=id,
+ )
+ collection.embeddings.append(embedding_store)
+ session.add(embedding_store)
+ session.commit()
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ embedding = self.embedding_function.embed_query(text=query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ set_enable_seqscan_stmt = sqlalchemy.text("SET enable_seqscan = off")
+ session.execute(set_enable_seqscan_stmt)
+ if not collection:
+ raise ValueError("Collection not found")
+
+ filter_by = EmbeddingStore.collection_id == collection.uuid
+
+ if filter is not None:
+ filter_clauses = []
+ for key, value in filter.items():
+ IN = "in"
+ if isinstance(value, dict) and IN in map(str.lower, value):
+ value_case_insensitive = {
+ k.lower(): v for k, v in value.items()
+ }
+ filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_(
+ value_case_insensitive[IN]
+ )
+ filter_clauses.append(filter_by_metadata)
+ elif isinstance(value, dict) and "substring" in map(
+ str.lower, value
+ ):
+ filter_by_metadata = EmbeddingStore.cmetadata[key].astext.ilike(
+ f"%{value['substring']}%"
+ )
+ filter_clauses.append(filter_by_metadata)
+ else:
+ filter_by_metadata = EmbeddingStore.cmetadata[
+ key
+ ].astext == str(value)
+ filter_clauses.append(filter_by_metadata)
+
+ filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+
+ results: List[QueryResult] = (
+ session.query(
+ EmbeddingStore,
+ func.abs(EmbeddingStore.embedding.op("<->")(embedding)).label(
+ "distance"
+ ),
+ ) # Specify the columns you need here, e.g., EmbeddingStore.embedding
+ .filter(filter_by)
+ .order_by(
+ func.abs(EmbeddingStore.embedding.op("<->")(embedding)).asc()
+ ) # Using PostgreSQL specific operator with the correct column name
+ .limit(k)
+ .all()
+ )
+
+ docs = [
+ (
+ Document(
+ page_content=result.EmbeddingStore.document,
+ metadata=result.EmbeddingStore.cmetadata,
+ ),
+ result.distance if self.embedding_function is not None else 0.0,
+ )
+ for result in results
+ ]
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls: Type[PGEmbedding],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGEmbedding:
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls._initialize_from_embeddings(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGEmbedding:
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls._initialize_from_embeddings(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[PGEmbedding],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGEmbedding:
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ pre_delete_collection=pre_delete_collection,
+ )
+
+ return store
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="POSTGRES_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+ "Postgres connection string is required"
+ "Either pass it as a parameter"
+ "or set the POSTGRES_CONNECTION_STRING environment variable."
+ )
+
+ return connection_string
+
+ @classmethod
+ def from_documents(
+ cls: Type[PGEmbedding],
+ documents: List[Document],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGEmbedding:
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ connection_string = cls.get_connection_string(kwargs)
+
+ kwargs["connection_string"] = connection_string
+
+ return cls.from_texts(
+ texts=texts,
+ pre_delete_collection=pre_delete_collection,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ **kwargs,
+ )
diff --git a/libs/community/langchain_community/vectorstores/pgvecto_rs.py b/libs/community/langchain_community/vectorstores/pgvecto_rs.py
new file mode 100644
index 00000000000..9efdef27ca9
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/pgvecto_rs.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+
+import uuid
+from typing import Any, Iterable, List, Literal, Optional, Tuple, Type
+
+import numpy as np
+import sqlalchemy
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+from sqlalchemy import insert, select
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from sqlalchemy.orm.session import Session
+
+
+class _ORMBase(DeclarativeBase):
+ __tablename__: str
+ id: Mapped[uuid.UUID]
+ text: Mapped[str]
+ meta: Mapped[dict]
+ embedding: Mapped[np.ndarray]
+
+
+class PGVecto_rs(VectorStore):
+ _engine: sqlalchemy.engine.Engine
+ _table: Type[_ORMBase]
+ _embedding: Embeddings
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ dimension: int,
+ db_url: str,
+ collection_name: str,
+ new_table: bool = False,
+ ) -> None:
+ try:
+ from pgvecto_rs.sqlalchemy import Vector
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import pgvector_rs, please install with "
+ "`pip install pgvector_rs`."
+ ) from e
+
+ class _Table(_ORMBase):
+ __tablename__ = f"collection_{collection_name}"
+ id: Mapped[uuid.UUID] = mapped_column(
+ postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+ )
+ text: Mapped[str] = mapped_column(sqlalchemy.String)
+ meta: Mapped[dict] = mapped_column(postgresql.JSONB)
+ embedding: Mapped[np.ndarray] = mapped_column(Vector(dimension))
+
+ self._engine = sqlalchemy.create_engine(db_url)
+ self._table = _Table
+ self._table.__table__.create(self._engine, checkfirst=not new_table) # type: ignore
+ self._embedding = embedding
+
+ # ================ Create interface =================
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ db_url: str = "",
+ collection_name: str = str(uuid.uuid4().hex),
+ **kwargs: Any,
+ ) -> PGVecto_rs:
+ """Return VectorStore initialized from texts and optional metadatas."""
+ sample_embedding = embedding.embed_query("Hello pgvecto_rs!")
+ dimension = len(sample_embedding)
+        if not db_url:
+ raise ValueError("db_url must be provided")
+ _self: PGVecto_rs = cls(
+ embedding=embedding,
+ dimension=dimension,
+ db_url=db_url,
+ collection_name=collection_name,
+ new_table=True,
+ )
+ _self.add_texts(texts, metadatas, **kwargs)
+ return _self
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Embeddings,
+ db_url: str = "",
+ collection_name: str = str(uuid.uuid4().hex),
+ **kwargs: Any,
+ ) -> PGVecto_rs:
+ """Return VectorStore initialized from documents."""
+ texts = [document.page_content for document in documents]
+ metadatas = [document.metadata for document in documents]
+ return cls.from_texts(
+ texts, embedding, metadatas, db_url, collection_name, **kwargs
+ )
+
+ @classmethod
+ def from_collection_name(
+ cls,
+ embedding: Embeddings,
+ db_url: str,
+ collection_name: str,
+ ) -> PGVecto_rs:
+ """Create new empty vectorstore with collection_name.
+ Or connect to an existing vectorstore in database if exists.
+ Arguments should be the same as when the vectorstore was created."""
+ sample_embedding = embedding.embed_query("Hello pgvecto_rs!")
+ return cls(
+ embedding=embedding,
+ dimension=len(sample_embedding),
+ db_url=db_url,
+ collection_name=collection_name,
+ )
+
+ # ================ Insert interface =================
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids of the added texts.
+
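+        Example (a minimal sketch; ``store`` is assumed to be an existing
+        PGVecto_rs instance and the texts are illustrative):
+            .. code-block:: python
+
+                ids = store.add_texts(
+                    ["foo", "bar"],
+                    metadatas=[{"source": "a"}, {"source": "b"}],
+                )
+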
+ """
+ embeddings = self._embedding.embed_documents(list(texts))
+ with Session(self._engine) as _session:
+ results: List[str] = []
+ for text, embedding, metadata in zip(
+ texts, embeddings, metadatas or [dict()] * len(list(texts))
+ ):
+ t = insert(self._table).values(
+ text=text, meta=metadata, embedding=embedding
+ )
+ id = _session.execute(t).inserted_primary_key[0] # type: ignore
+ results.append(str(id))
+ _session.commit()
+ return results
+
+ def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
+ """Run more documents through the embeddings and add to the vectorstore.
+
+ Args:
+ documents (List[Document]): List of documents to add to the vectorstore.
+
+ Returns:
+ List of ids of the added documents.
+ """
+ return self.add_texts(
+ [document.page_content for document in documents],
+ [document.metadata for document in documents],
+ **kwargs,
+ )
+
+ # ================ Query interface =================
+ def similarity_search_with_score_by_vector(
+ self,
+ query_vector: List[float],
+ k: int = 4,
+ distance_func: Literal[
+ "sqrt_euclid", "neg_dot_prod", "ned_cos"
+ ] = "sqrt_euclid",
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query vector, with its score."""
+ with Session(self._engine) as _session:
+ real_distance_func = (
+ self._table.embedding.squared_euclidean_distance
+ if distance_func == "sqrt_euclid"
+ else self._table.embedding.negative_dot_product_distance
+ if distance_func == "neg_dot_prod"
+ else self._table.embedding.negative_cosine_distance
+ if distance_func == "ned_cos"
+ else None
+ )
+ if real_distance_func is None:
+ raise ValueError("Invalid distance function")
+
+ t = (
+ select(self._table, real_distance_func(query_vector).label("score"))
+ .order_by("score")
+ .limit(k) # type: ignore
+ )
+ return [
+ (Document(page_content=row[0].text, metadata=row[0].meta), row[1])
+ for row in _session.execute(t)
+ ]
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ distance_func: Literal[
+ "sqrt_euclid", "neg_dot_prod", "ned_cos"
+ ] = "sqrt_euclid",
+ **kwargs: Any,
+ ) -> List[Document]:
+ return [
+ doc
+ for doc, score in self.similarity_search_with_score_by_vector(
+ embedding, k, distance_func, **kwargs
+ )
+ ]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ distance_func: Literal[
+ "sqrt_euclid", "neg_dot_prod", "ned_cos"
+ ] = "sqrt_euclid",
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ query_vector = self._embedding.embed_query(query)
+ return self.similarity_search_with_score_by_vector(
+ query_vector, k, distance_func, **kwargs
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ distance_func: Literal[
+ "sqrt_euclid", "neg_dot_prod", "ned_cos"
+ ] = "sqrt_euclid",
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+ query_vector = self._embedding.embed_query(query)
+ return [
+ doc
+ for doc, score in self.similarity_search_with_score_by_vector(
+ query_vector, k, distance_func, **kwargs
+ )
+ ]
diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py
new file mode 100644
index 00000000000..eb20095930c
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/pgvector.py
@@ -0,0 +1,947 @@
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import enum
+import logging
+import uuid
+from functools import partial
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+import numpy as np
+import sqlalchemy
+from sqlalchemy import delete
+from sqlalchemy.dialects.postgresql import JSON, UUID
+from sqlalchemy.orm import Session, relationship
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+
+class DistanceStrategy(str, enum.Enum):
+ """Enumerator of the Distance strategies."""
+
+ EUCLIDEAN = "l2"
+ COSINE = "cosine"
+ MAX_INNER_PRODUCT = "inner"
+
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+
+Base = declarative_base() # type: Any
+
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+class BaseModel(Base):
+ """Base model for the SQL stores."""
+
+ __abstract__ = True
+ uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+
+
+def _get_embedding_collection_store() -> Any:
+ from pgvector.sqlalchemy import Vector
+
+ class CollectionStore(BaseModel):
+ """Collection store."""
+
+ __tablename__ = "langchain_pg_collection"
+
+ name = sqlalchemy.Column(sqlalchemy.String)
+ cmetadata = sqlalchemy.Column(JSON)
+
+ embeddings = relationship(
+ "EmbeddingStore",
+ back_populates="collection",
+ passive_deletes=True,
+ )
+
+ @classmethod
+ def get_by_name(
+ cls, session: Session, name: str
+ ) -> Optional["CollectionStore"]:
+ return session.query(cls).filter(cls.name == name).first() # type: ignore
+
+ @classmethod
+ def get_or_create(
+ cls,
+ session: Session,
+ name: str,
+ cmetadata: Optional[dict] = None,
+ ) -> Tuple["CollectionStore", bool]:
+ """
+ Get or create a collection.
+ Returns [Collection, bool] where the bool is True if the collection was created.
+ """ # noqa: E501
+ created = False
+ collection = cls.get_by_name(session, name)
+ if collection:
+ return collection, created
+
+ collection = cls(name=name, cmetadata=cmetadata)
+ session.add(collection)
+ session.commit()
+ created = True
+ return collection, created
+
+ class EmbeddingStore(BaseModel):
+ """Embedding store."""
+
+ __tablename__ = "langchain_pg_embedding"
+
+ collection_id = sqlalchemy.Column(
+ UUID(as_uuid=True),
+ sqlalchemy.ForeignKey(
+ f"{CollectionStore.__tablename__}.uuid",
+ ondelete="CASCADE",
+ ),
+ )
+ collection = relationship(CollectionStore, back_populates="embeddings")
+
+ embedding: Vector = sqlalchemy.Column(Vector(None))
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+ cmetadata = sqlalchemy.Column(JSON, nullable=True)
+
+ # custom_id : any user defined id
+ custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+ return EmbeddingStore, CollectionStore
+
+
+def _results_to_docs(docs_and_scores: Any) -> List[Document]:
+ """Return docs from docs and scores."""
+ return [doc for doc, _ in docs_and_scores]
+
+
+class PGVector(VectorStore):
+ """`Postgres`/`PGVector` vector store.
+
+ To use, you should have the ``pgvector`` python package installed.
+
+ Args:
+ connection_string: Postgres connection string.
+ embedding_function: Any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ collection_name: The name of the collection to use. (default: langchain)
+ NOTE: This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store (if they do
+            not exist), so make sure the user has the right permissions to
+            create tables.
+ distance_strategy: The distance strategy to use. (default: COSINE)
+ pre_delete_collection: If True, will delete the collection if it exists.
+ (default: False). Useful for testing.
+ engine_args: SQLAlchemy's create engine arguments.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import PGVector
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ CONNECTION_STRING = "postgresql+psycopg2://hwc@localhost:5432/test3"
+ COLLECTION_NAME = "state_of_the_union_test"
+ embeddings = OpenAIEmbeddings()
+            vectorstore = PGVector.from_documents(
+ embedding=embeddings,
+ documents=docs,
+ collection_name=COLLECTION_NAME,
+ connection_string=CONNECTION_STRING,
+ )
+
+
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ collection_metadata: Optional[dict] = None,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ *,
+ connection: Optional[sqlalchemy.engine.Connection] = None,
+ engine_args: Optional[dict[str, Any]] = None,
+ ) -> None:
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self.collection_name = collection_name
+ self.collection_metadata = collection_metadata
+ self._distance_strategy = distance_strategy
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.override_relevance_score_fn = relevance_score_fn
+ self.engine_args = engine_args or {}
+ # Create a connection if not provided, otherwise use the provided connection
+ self._conn = connection if connection else self.connect()
+ self.__post_init__()
+
+ def __post_init__(
+ self,
+ ) -> None:
+ """Initialize the store."""
+ self.create_vector_extension()
+
+ EmbeddingStore, CollectionStore = _get_embedding_collection_store()
+ self.CollectionStore = CollectionStore
+ self.EmbeddingStore = EmbeddingStore
+ self.create_tables_if_not_exists()
+ self.create_collection()
+
+ def __del__(self) -> None:
+ if self._conn:
+ self._conn.close()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def connect(self) -> sqlalchemy.engine.Connection:
+ engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
+ conn = engine.connect()
+ return conn
+
+ def create_vector_extension(self) -> None:
+ try:
+ with Session(self._conn) as session:
+                # The advisory lock fixes issues arising from concurrent
+                # creation of the vector extension.
+ # https://github.com/langchain-ai/langchain/issues/12933
+ # For more information see:
+ # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+ statement = sqlalchemy.text(
+ "BEGIN;"
+ "SELECT pg_advisory_xact_lock(1573678846307946496);"
+ "CREATE EXTENSION IF NOT EXISTS vector;"
+ "COMMIT;"
+ )
+ session.execute(statement)
+ session.commit()
+ except Exception as e:
+ raise Exception(f"Failed to create vector extension: {e}") from e
+
+ def create_tables_if_not_exists(self) -> None:
+ with self._conn.begin():
+ Base.metadata.create_all(self._conn)
+
+ def drop_tables(self) -> None:
+ with self._conn.begin():
+ Base.metadata.drop_all(self._conn)
+
+ def create_collection(self) -> None:
+ if self.pre_delete_collection:
+ self.delete_collection()
+ with Session(self._conn) as session:
+ self.CollectionStore.get_or_create(
+ session, self.collection_name, cmetadata=self.collection_metadata
+ )
+
+ def delete_collection(self) -> None:
+ self.logger.debug("Trying to delete collection")
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ self.logger.warning("Collection not found")
+ return
+ session.delete(collection)
+ session.commit()
+
+ @contextlib.contextmanager
+ def _make_session(self) -> Generator[Session, None, None]:
+ """Create a context manager for the session, bind to _conn string."""
+ yield Session(self._conn)
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Delete vectors by ids or uuids.
+
+ Args:
+ ids: List of ids to delete.
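+
+        Example (a minimal sketch; ``store`` is assumed to be an existing
+        PGVector instance and the ids are illustrative):
+            .. code-block:: python
+
+                store.delete(ids=["id1", "id2"])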
+ """
+ with Session(self._conn) as session:
+ if ids is not None:
+ self.logger.debug(
+ "Trying to delete vectors by ids (represented by the model "
+ "using the custom ids field)"
+ )
+ stmt = delete(self.EmbeddingStore).where(
+ self.EmbeddingStore.custom_id.in_(ids)
+ )
+ session.execute(stmt)
+ session.commit()
+
+ def get_collection(self, session: Session) -> Any:
+ return self.CollectionStore.get_by_name(session, self.collection_name)
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ connection_string: Optional[str] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+ if connection_string is None:
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ store.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ def add_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+ metadatas: List of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ raise ValueError("Collection not found")
+ for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
+ embedding_store = self.EmbeddingStore(
+ embedding=embedding,
+ document=text,
+ cmetadata=metadata,
+ custom_id=id,
+ collection_id=collection.uuid,
+ )
+ session.add(embedding_store)
+ session.commit()
+
+ return ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ embeddings = self.embedding_function.embed_documents(list(texts))
+ return self.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with PGVector with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query.
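+
+        Example (a minimal sketch; the query and filter values shown are
+        assumptions for illustration):
+            .. code-block:: python
+
+                # plain similarity search
+                docs = store.similarity_search("query", k=4)
+
+                # restrict results by metadata; "in" matches any listed value
+                docs = store.similarity_search(
+                    "query",
+                    k=4,
+                    filter={"source": {"in": ["news", "blog"]}},
+                )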
+ """
+ embedding = self.embedding_function.embed_query(text=query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each.
+ """
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return docs
+
+ @property
+ def distance_strategy(self) -> Any:
+ if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+ return self.EmbeddingStore.embedding.l2_distance
+ elif self._distance_strategy == DistanceStrategy.COSINE:
+ return self.EmbeddingStore.embedding.cosine_distance
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self.EmbeddingStore.embedding.max_inner_product
+ else:
+ raise ValueError(
+ f"Got unexpected value for distance: {self._distance_strategy}. "
+ f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
+ )
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ ) -> List[Tuple[Document, float]]:
+ results = self.__query_collection(embedding=embedding, k=k, filter=filter)
+
+ return self._results_to_docs_and_scores(results)
+
+ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
+ """Return docs and scores from results."""
+ docs = [
+ (
+ Document(
+ page_content=result.EmbeddingStore.document,
+ metadata=result.EmbeddingStore.cmetadata,
+ ),
+ result.distance if self.embedding_function is not None else None,
+ )
+ for result in results
+ ]
+ return docs
+
+ def __query_collection(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Any]:
+ """Query the collection."""
+ with Session(self._conn) as session:
+ collection = self.get_collection(session)
+ if not collection:
+ raise ValueError("Collection not found")
+
+ filter_by = self.EmbeddingStore.collection_id == collection.uuid
+
+ if filter is not None:
+ filter_clauses = []
+ IN, NIN = "in", "nin"
+ for key, value in filter.items():
+ if isinstance(value, dict):
+ value_case_insensitive = {
+ k.lower(): v for k, v in value.items()
+ }
+ if IN in map(str.lower, value):
+ filter_by_metadata = self.EmbeddingStore.cmetadata[
+ key
+ ].astext.in_(value_case_insensitive[IN])
+ elif NIN in map(str.lower, value):
+ filter_by_metadata = self.EmbeddingStore.cmetadata[
+ key
+ ].astext.not_in(value_case_insensitive[NIN])
+ else:
+ filter_by_metadata = None
+ if filter_by_metadata is not None:
+ filter_clauses.append(filter_by_metadata)
+ else:
+ filter_by_metadata = self.EmbeddingStore.cmetadata[
+ key
+ ].astext == str(value)
+ filter_clauses.append(filter_by_metadata)
+
+ filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+
+ _type = self.EmbeddingStore
+
+ results: List[Any] = (
+ session.query(
+ self.EmbeddingStore,
+ self.distance_strategy(embedding).label("distance"), # type: ignore
+ )
+ .filter(filter_by)
+ .order_by(sqlalchemy.asc("distance"))
+ .join(
+ self.CollectionStore,
+ self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
+ )
+ .limit(k)
+ .all()
+ )
+ return results
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return _results_to_docs(docs_and_scores)
+
+ @classmethod
+ def from_texts(
+ cls: Type[PGVector],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+        Return VectorStore initialized from texts and embeddings.
+        A Postgres connection string is required:
+        either pass it as a parameter
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
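+
+        Example (a minimal sketch; ``embeddings`` is an Embeddings instance and
+        the connection string is an assumption for illustration):
+            .. code-block:: python
+
+                store = PGVector.from_texts(
+                    texts,
+                    embedding=embeddings,
+                    collection_name="my_collection",
+                    connection_string="postgresql+psycopg2://user:pass@localhost/db",
+                )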
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """Construct PGVector wrapper from raw documents and pre-
+ generated embeddings.
+
+ Return VectorStore initialized from documents and embeddings.
+ Postgres connection string is required
+ "Either pass it as a parameter
+ or set the PGVECTOR_CONNECTION_STRING environment variable.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import PGVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+            vectorstore = PGVector.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[PGVector],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+ Get an instance of an existing PGVector store. This method will
+ return the instance of the store without inserting any new
+ embeddings.
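+
+ Example (a minimal sketch; assumes PGVECTOR_CONNECTION_STRING is set and
+ the collection already holds embeddings):
+ .. code-block:: python
+
+ from langchain_community.vectorstores import PGVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ store = PGVector.from_existing_index(
+ embedding=OpenAIEmbeddings(),
+ collection_name="my_collection",
+ )
+ docs = store.similarity_search("query text", k=4)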
+ """
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ )
+
+ return store
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="PGVECTOR_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+ "Postgres connection string is required"
+ "Either pass it as a parameter"
+ "or set the PGVECTOR_CONNECTION_STRING environment variable."
+ )
+
+ return connection_string
+
+ @classmethod
+ def from_documents(
+ cls: Type[PGVector],
+ documents: List[Document],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+ Return VectorStore initialized from documents and embeddings.
+ Postgres connection string is required.
+ Either pass it as a parameter
+ or set the PGVECTOR_CONNECTION_STRING environment variable.
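+
+ Example (a minimal sketch; the connection string below is illustrative):
+ .. code-block:: python
+
+ from langchain_core.documents import Document
+ from langchain_community.vectorstores import PGVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ docs = [Document(page_content="foo", metadata={"source": "a"})]
+ vectorstore = PGVector.from_documents(
+ documents=docs,
+ embedding=OpenAIEmbeddings(),
+ connection_string="postgresql+psycopg2://user:pass@localhost:5432/postgres",
+ )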
+ """
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ connection_string = cls.get_connection_string(kwargs)
+
+ kwargs["connection_string"] = connection_string
+
+ return cls.from_texts(
+ texts=texts,
+ pre_delete_collection=pre_delete_collection,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ **kwargs,
+ )
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ driver: str,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided
+ # in vectorstore constructor
+ if self._distance_strategy == DistanceStrategy.COSINE:
+ return self._cosine_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+ return self._euclidean_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ else:
+ raise ValueError(
+ "No supported normalization function"
+ f" for distance_strategy of {self._distance_strategy}."
+ "Consider providing relevance_score_fn to PGVector constructor."
+ )
+
+ def max_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs selected using the maximal marginal relevance with score
+ to embedding vector.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Tuple[Document, float]]: List of Documents selected by maximal marginal
+ relevance to the query and score for each.
+ """
+ results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
+
+ embedding_list = [result.EmbeddingStore.embedding for result in results]
+
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding, dtype=np.float32),
+ embedding_list,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+
+ candidates = self._results_to_docs_and_scores(results)
+
+ return [r for i, r in enumerate(candidates) if i in mmr_selected]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query (str): Text to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents selected by maximal marginal relevance.
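+
+ Example (a minimal sketch; ``store`` is assumed to be an initialized
+ PGVector instance):
+ .. code-block:: python
+
+ docs = store.max_marginal_relevance_search(
+ "query text",
+ k=4,
+ fetch_k=20,
+ lambda_mult=0.5,
+ )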
+ """
+ embedding = self.embedding_function.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs selected using the maximal marginal relevance with score.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query (str): Text to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Tuple[Document, float]]: List of Documents selected by maximal marginal
+ relevance to the query and score for each.
+ """
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.max_marginal_relevance_search_with_score_by_vector(
+ embedding=embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return docs
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance
+ to embedding vector.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding (List[float]): Embedding to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Document]: List of Documents selected by maximal marginal relevance.
+ """
+ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+
+ return _results_to_docs(docs_and_scores)
+
+ async def amax_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance."""
+
+ # This is a temporary workaround to make the similarity search
+ # asynchronous. The proper solution is to make the similarity search
+ # asynchronous in the vector store implementations.
+ func = partial(
+ self.max_marginal_relevance_search_by_vector,
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, func)
diff --git a/libs/community/langchain_community/vectorstores/pinecone.py b/libs/community/langchain_community/vectorstores/pinecone.py
new file mode 100644
index 00000000000..0ecfb875bf4
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/pinecone.py
@@ -0,0 +1,477 @@
+from __future__ import annotations
+
+import logging
+import uuid
+import warnings
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils.iter import batch_iterate
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import (
+ DistanceStrategy,
+ maximal_marginal_relevance,
+)
+
+if TYPE_CHECKING:
+ from pinecone import Index
+
+logger = logging.getLogger(__name__)
+
+
+class Pinecone(VectorStore):
+ """`Pinecone` vector store.
+
+ To use, you should have the ``pinecone-client`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Pinecone
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ import pinecone
+
+ # The environment should be the one specified next to the API key
+ # in your Pinecone console
+ pinecone.init(api_key="***", environment="...")
+ index = pinecone.Index("langchain-demo")
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Pinecone(index, embeddings.embed_query, "text")
+ """
+
+ def __init__(
+ self,
+ index: Any,
+ embedding: Union[Embeddings, Callable],
+ text_key: str,
+ namespace: Optional[str] = None,
+ distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE,
+ ):
+ """Initialize with Pinecone client."""
+ try:
+ import pinecone
+ except ImportError:
+ raise ImportError(
+ "Could not import pinecone python package. "
+ "Please install it with `pip install pinecone-client`."
+ )
+ if not isinstance(embedding, Embeddings):
+ warnings.warn(
+ "Passing in `embedding` as a Callable is deprecated. Please pass in an"
+ " Embeddings object instead."
+ )
+ if not isinstance(index, pinecone.index.Index):
+ raise ValueError(
+ f"client should be an instance of pinecone.index.Index, "
+ f"got {type(index)}"
+ )
+ self._index = index
+ self._embedding = embedding
+ self._text_key = text_key
+ self._namespace = namespace
+ self.distance_strategy = distance_strategy
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ """Access the query embedding object if available."""
+ if isinstance(self._embedding, Embeddings):
+ return self._embedding
+ return None
+
+ def _embed_documents(self, texts: Iterable[str]) -> List[List[float]]:
+ """Embed search docs."""
+ if isinstance(self._embedding, Embeddings):
+ return self._embedding.embed_documents(list(texts))
+ return [self._embedding(t) for t in texts]
+
+ def _embed_query(self, text: str) -> List[float]:
+ """Embed query text."""
+ if isinstance(self._embedding, Embeddings):
+ return self._embedding.embed_query(text)
+ return self._embedding(text)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ namespace: Optional[str] = None,
+ batch_size: int = 32,
+ embedding_chunk_size: int = 1000,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Upsert optimization is done by chunking the embeddings and upserting them.
+ This avoids memory issues and optimizes throughput for HTTP-based embeddings.
+ For OpenAI embeddings, use pool_threads>4 when constructing the pinecone.Index,
+ with embedding_chunk_size>1000 and batch_size~64 for best performance.
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+ namespace: Optional pinecone namespace to add the texts to.
+ batch_size: Batch size to use when adding the texts to the vectorstore.
+ embedding_chunk_size: Chunk size to use when embedding the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
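+ Example (a minimal sketch; ``vectorstore`` is assumed to be an existing
+ Pinecone wrapper):
+ .. code-block:: python
+
+ ids = vectorstore.add_texts(
+ ["doc one", "doc two"],
+ metadatas=[{"source": "a"}, {"source": "b"}],
+ batch_size=64,
+ embedding_chunk_size=1000,
+ )
+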
+ """
+ if namespace is None:
+ namespace = self._namespace
+
+ texts = list(texts)
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ metadatas = metadatas or [{} for _ in texts]
+ for metadata, text in zip(metadatas, texts):
+ metadata[self._text_key] = text
+
+ # Loop in chunks to avoid memory issues and to optimize HTTP-based embeddings.
+ # The first loop runs the embeddings; this mainly benefits OpenAI embeddings.
+ # The second loop runs the pinecone upsert asynchronously.
+ for i in range(0, len(texts), embedding_chunk_size):
+ chunk_texts = texts[i : i + embedding_chunk_size]
+ chunk_ids = ids[i : i + embedding_chunk_size]
+ chunk_metadatas = metadatas[i : i + embedding_chunk_size]
+ embeddings = self._embed_documents(chunk_texts)
+ async_res = [
+ self._index.upsert(
+ vectors=batch,
+ namespace=namespace,
+ async_req=True,
+ **kwargs,
+ )
+ for batch in batch_iterate(
+ batch_size, zip(chunk_ids, embeddings, chunk_metadatas)
+ )
+ ]
+ [res.get() for res in async_res]
+
+ return ids
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ namespace: Optional[str] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return pinecone documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Dictionary of argument(s) to filter on metadata
+ namespace: Namespace to search in. Default will search in '' namespace.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ return self.similarity_search_by_vector_with_score(
+ self._embed_query(query), k=k, filter=filter, namespace=namespace
+ )
+
+ def similarity_search_by_vector_with_score(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ namespace: Optional[str] = None,
+ ) -> List[Tuple[Document, float]]:
+ """Return pinecone documents most similar to embedding, along with scores."""
+
+ if namespace is None:
+ namespace = self._namespace
+ docs = []
+ results = self._index.query(
+ [embedding],
+ top_k=k,
+ include_metadata=True,
+ namespace=namespace,
+ filter=filter,
+ )
+ for res in results["matches"]:
+ metadata = res["metadata"]
+ if self._text_key in metadata:
+ text = metadata.pop(self._text_key)
+ score = res["score"]
+ docs.append((Document(page_content=text, metadata=metadata), score))
+ else:
+ logger.warning(
+ f"Found document with no `{self._text_key}` key. Skipping."
+ )
+ return docs
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[dict] = None,
+ namespace: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return pinecone documents most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Dictionary of argument(s) to filter on metadata
+ namespace: Namespace to search in. Default will search in '' namespace.
+
+ Returns:
+ List of Documents most similar to the query.
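+
+ Example (a minimal sketch; ``vectorstore`` and the metadata filter are
+ illustrative):
+ .. code-block:: python
+
+ docs = vectorstore.similarity_search(
+ "query text",
+ k=4,
+ filter={"source": {"$eq": "a"}},
+ )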
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k=k, filter=filter, namespace=namespace, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+
+ if self.distance_strategy == DistanceStrategy.COSINE:
+ return self._cosine_relevance_score_fn
+ elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ return self._euclidean_relevance_score_fn
+ else:
+ raise ValueError(
+ "Unknown distance strategy, must be cosine, max_inner_product "
+ "(dot product), or euclidean"
+ )
+
+ @staticmethod
+ def _cosine_relevance_score_fn(score: float) -> float:
+ """Pinecone returns cosine similarity scores between [-1,1]"""
+ return (score + 1) / 2
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[dict] = None,
+ namespace: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if namespace is None:
+ namespace = self._namespace
+ results = self._index.query(
+ [embedding],
+ top_k=fetch_k,
+ include_values=True,
+ include_metadata=True,
+ namespace=namespace,
+ filter=filter,
+ )
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ [item["values"] for item in results["matches"]],
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ selected = [results["matches"][i]["metadata"] for i in mmr_selected]
+ return [
+ Document(page_content=metadata.pop(self._text_key), metadata=metadata)
+ for metadata in selected
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[dict] = None,
+ namespace: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ embedding = self._embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult, filter, namespace
+ )
+
+ @classmethod
+ def get_pinecone_index(
+ cls,
+ index_name: Optional[str],
+ pool_threads: int = 4,
+ ) -> Index:
+ """Return a Pinecone Index instance.
+
+ Args:
+ index_name: Name of the index to use.
+ pool_threads: Number of threads to use for index upsert.
+ Returns:
+ Pinecone Index instance."""
+
+ try:
+ import pinecone
+ except ImportError:
+ raise ValueError(
+ "Could not import pinecone python package. "
+ "Please install it with `pip install pinecone-client`."
+ )
+
+ indexes = pinecone.list_indexes() # checks if provided index exists
+
+ if index_name in indexes:
+ index = pinecone.Index(index_name, pool_threads=pool_threads)
+ elif len(indexes) == 0:
+ raise ValueError(
+ "No active indexes found in your Pinecone project, "
+ "are you sure you're using the right Pinecone API key and Environment? "
+ "Please double check your Pinecone dashboard."
+ )
+ else:
+ raise ValueError(
+ f"Index '{index_name}' not found in your Pinecone project. "
+ f"Did you mean one of the following indexes: {', '.join(indexes)}"
+ )
+ return index
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 32,
+ text_key: str = "text",
+ namespace: Optional[str] = None,
+ index_name: Optional[str] = None,
+ upsert_kwargs: Optional[dict] = None,
+ pool_threads: int = 4,
+ embeddings_chunk_size: int = 1000,
+ **kwargs: Any,
+ ) -> Pinecone:
+ """Construct Pinecone wrapper from raw documents.
+
+ This is a user friendly interface that:
+ 1. Embeds documents.
+ 2. Adds the documents to a provided Pinecone index
+
+ This is intended to be a quick way to get started.
+
+ The `pool_threads` parameter affects the speed of the upsert operations.
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Pinecone
+ from langchain_community.embeddings import OpenAIEmbeddings
+ import pinecone
+
+ # The environment should be the one specified next to the API key
+ # in your Pinecone console
+ pinecone.init(api_key="***", environment="...")
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Pinecone.from_texts(
+ texts,
+ embeddings,
+ index_name="langchain-demo"
+ )
+ """
+ pinecone_index = cls.get_pinecone_index(index_name, pool_threads)
+ pinecone = cls(pinecone_index, embedding, text_key, namespace, **kwargs)
+
+ pinecone.add_texts(
+ texts,
+ metadatas=metadatas,
+ ids=ids,
+ namespace=namespace,
+ batch_size=batch_size,
+ embedding_chunk_size=embeddings_chunk_size,
+ **(upsert_kwargs or {}),
+ )
+ return pinecone
+
+ @classmethod
+ def from_existing_index(
+ cls,
+ index_name: str,
+ embedding: Embeddings,
+ text_key: str = "text",
+ namespace: Optional[str] = None,
+ pool_threads: int = 4,
+ ) -> Pinecone:
+ """Load pinecone vectorstore from index name."""
+ pinecone_index = cls.get_pinecone_index(index_name, pool_threads)
+ return cls(pinecone_index, embedding, text_key, namespace)
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ delete_all: Optional[bool] = None,
+ namespace: Optional[str] = None,
+ filter: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Delete by vector IDs or filter.
+ Args:
+ ids: List of ids to delete.
+ delete_all: Whether to delete all vectors in the namespace.
+ namespace: Namespace to delete vectors from. Defaults to the instance namespace.
+ filter: Dictionary of conditions to filter vectors to delete.
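+
+ Example (a minimal sketch; ids and filter values are illustrative):
+ .. code-block:: python
+
+ # Delete specific vectors by id
+ vectorstore.delete(ids=["id-1", "id-2"])
+
+ # Or delete everything matching a metadata filter
+ vectorstore.delete(filter={"source": {"$eq": "a"}})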
+ """
+
+ if namespace is None:
+ namespace = self._namespace
+
+ if delete_all:
+ self._index.delete(delete_all=True, namespace=namespace, **kwargs)
+ elif ids is not None:
+ chunk_size = 1000
+ for i in range(0, len(ids), chunk_size):
+ chunk = ids[i : i + chunk_size]
+ self._index.delete(ids=chunk, namespace=namespace, **kwargs)
+ elif filter is not None:
+ self._index.delete(filter=filter, namespace=namespace, **kwargs)
+ else:
+ raise ValueError("Either ids, delete_all, or filter must be provided.")
+
+ return None
diff --git a/libs/community/langchain_community/vectorstores/qdrant.py b/libs/community/langchain_community/vectorstores/qdrant.py
new file mode 100644
index 00000000000..ea881c4cbd6
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/qdrant.py
@@ -0,0 +1,2137 @@
+from __future__ import annotations
+
+import asyncio
+import functools
+import uuid
+import warnings
+from itertools import islice
+from operator import itemgetter
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncGenerator,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ from qdrant_client import grpc # noqa
+ from qdrant_client.conversions import common_types
+ from qdrant_client.http import models as rest
+
+ DictFilter = Dict[str, Union[str, int, bool, dict, list]]
+ MetadataFilter = Union[DictFilter, common_types.Filter]
+
+
+class QdrantException(Exception):
+ """`Qdrant` related exceptions."""
+
+
+def sync_call_fallback(method: Callable) -> Callable:
+ """
+ Decorator to call the synchronous method of the class if the async method is not
+ implemented. This decorator might be only used for the methods that are defined
+ as async in the class.
+ """
+
+ @functools.wraps(method)
+ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return await method(self, *args, **kwargs)
+ except NotImplementedError:
+ # If the async method is not implemented, call the synchronous method
+ # by removing the first letter from the method name. For example,
+ # if the async method is called ``aadd_texts``, the synchronous method
+ # will be called ``add_texts``.
+ sync_method = functools.partial(
+ getattr(self, method.__name__[1:]), *args, **kwargs
+ )
+ return await asyncio.get_event_loop().run_in_executor(None, sync_method)
+
+ return wrapper
+
+
+class Qdrant(VectorStore):
+ """`Qdrant` vector store.
+
+ To use, you should have the ``qdrant-client`` package installed.
+
+ Example:
+ .. code-block:: python
+
+ from qdrant_client import QdrantClient
+ from langchain_community.vectorstores import Qdrant
+
+ client = QdrantClient()
+ collection_name = "MyCollection"
+ qdrant = Qdrant(client, collection_name, embedding_function)
+ """
+
+ CONTENT_KEY = "page_content"
+ METADATA_KEY = "metadata"
+ VECTOR_NAME = None
+
+ def __init__(
+ self,
+ client: Any,
+ collection_name: str,
+ embeddings: Optional[Embeddings] = None,
+ content_payload_key: str = CONTENT_KEY,
+ metadata_payload_key: str = METADATA_KEY,
+ distance_strategy: str = "COSINE",
+ vector_name: Optional[str] = VECTOR_NAME,
+ embedding_function: Optional[Callable] = None, # deprecated
+ ):
+ """Initialize with necessary components."""
+ try:
+ import qdrant_client
+ except ImportError:
+ raise ImportError(
+ "Could not import qdrant-client python package. "
+ "Please install it with `pip install qdrant-client`."
+ )
+
+ if not isinstance(client, qdrant_client.QdrantClient):
+ raise ValueError(
+ f"client should be an instance of qdrant_client.QdrantClient, "
+ f"got {type(client)}"
+ )
+
+ if embeddings is None and embedding_function is None:
+ raise ValueError(
+ "`embeddings` value can't be None. Pass `Embeddings` instance."
+ )
+
+ if embeddings is not None and embedding_function is not None:
+ raise ValueError(
+ "Both `embeddings` and `embedding_function` are passed. "
+ "Use `embeddings` only."
+ )
+
+ self._embeddings = embeddings
+ self._embeddings_function = embedding_function
+ self.client: qdrant_client.QdrantClient = client
+ self.collection_name = collection_name
+ self.content_payload_key = content_payload_key or self.CONTENT_KEY
+ self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
+ self.vector_name = vector_name or self.VECTOR_NAME
+
+ if embedding_function is not None:
+ warnings.warn(
+ "Using `embedding_function` is deprecated. "
+ "Pass `Embeddings` instance to `embeddings` instead."
+ )
+
+ if not isinstance(embeddings, Embeddings):
+ warnings.warn(
+ "`embeddings` should be an instance of `Embeddings`."
+ "Using `embeddings` as `embedding_function` which is deprecated"
+ )
+ self._embeddings_function = embeddings
+ self._embeddings = None
+
+ self.distance_strategy = distance_strategy.upper()
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embeddings
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ batch_size: int = 64,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids:
+ Optional list of ids to associate with the texts. Ids have to be
+ uuid-like strings.
+ batch_size:
+ How many vectors to upload per request.
+ Default: 64
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
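+
+ Example (a minimal sketch; ``qdrant`` is assumed to be an existing Qdrant
+ wrapper):
+ .. code-block:: python
+
+ ids = qdrant.add_texts(
+ ["doc one", "doc two"],
+ metadatas=[{"source": "a"}, {"source": "b"}],
+ batch_size=64,
+ )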
+ """
+ added_ids = []
+ for batch_ids, points in self._generate_rest_batches(
+ texts, metadatas, ids, batch_size
+ ):
+ self.client.upsert(
+ collection_name=self.collection_name, points=points, **kwargs
+ )
+ added_ids.extend(batch_ids)
+
+ return added_ids
+
+ @sync_call_fallback
+ async def aadd_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ batch_size: int = 64,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids:
+ Optional list of ids to associate with the texts. Ids have to be
+ uuid-like strings.
+ batch_size:
+ How many vectors to upload per request.
+ Default: 64
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ from qdrant_client import grpc # noqa
+ from qdrant_client.conversions.conversion import RestToGrpc
+
+ added_ids = []
+ async for batch_ids, points in self._agenerate_rest_batches(
+ texts, metadatas, ids, batch_size
+ ):
+ await self.client.async_grpc_points.Upsert(
+ grpc.UpsertPoints(
+ collection_name=self.collection_name,
+ points=[RestToGrpc.convert_point_struct(point) for point in points],
+ )
+ )
+ added_ids.extend(batch_ids)
+
+ return added_ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+
+ Returns:
+ List of Documents most similar to the query.
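+
+ Example (a minimal sketch; the qdrant-client filter shown is illustrative):
+ .. code-block:: python
+
+ from qdrant_client.http import models as rest
+
+ docs = qdrant.similarity_search(
+ "query text",
+ k=4,
+ filter=rest.Filter(
+ must=[
+ rest.FieldCondition(
+ key="metadata.source",
+ match=rest.MatchValue(value="a"),
+ )
+ ]
+ ),
+ )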
+ """
+ results = self.similarity_search_with_score(
+ query,
+ k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return list(map(itemgetter(0), results))
+
+ @sync_call_fallback
+ async def asimilarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ Returns:
+ List of Documents most similar to the query.
+ """
+ results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
+ return list(map(itemgetter(0), results))
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+
+ Returns:
+ List of documents most similar to the query text and distance for each.
+ """
+ return self.similarity_search_with_score_by_vector(
+ self._embed_query(query),
+ k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+
+ @sync_call_fallback
+ async def asimilarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to
+ QdrantClient.async_grpc_points.Search().
+
+ Returns:
+ List of documents most similar to the query text and distance for each.
+ """
+ return await self.asimilarity_search_with_score_by_vector(
+ self._embed_query(query),
+ k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ results = self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return list(map(itemgetter(0), results))
+
+ @sync_call_fallback
+ async def asimilarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to
+ QdrantClient.async_grpc_points.Search().
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ results = await self.asimilarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return list(map(itemgetter(0), results))
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+
+ Returns:
+ List of documents most similar to the query text and distance for each.
+ """
+ if filter is not None and isinstance(filter, dict):
+ warnings.warn(
+ "Using dict as a `filter` is deprecated. Please use qdrant-client "
+ "filters directly: "
+ "https://qdrant.tech/documentation/concepts/filtering/",
+ DeprecationWarning,
+ )
+ qdrant_filter = self._qdrant_filter_from_dict(filter)
+ else:
+ qdrant_filter = filter
+
+ query_vector = embedding
+ if self.vector_name is not None:
+ query_vector = (self.vector_name, embedding) # type: ignore[assignment]
+
+ results = self.client.search(
+ collection_name=self.collection_name,
+ query_vector=query_vector,
+ query_filter=qdrant_filter,
+ search_params=search_params,
+ limit=k,
+ offset=offset,
+ with_payload=True,
+ with_vectors=False, # Langchain does not expect vectors to be returned
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return [
+ (
+ self._document_from_scored_point(
+ result, self.content_payload_key, self.metadata_payload_key
+ ),
+ result.score,
+ )
+ for result in results
+ ]
+
+ async def _asearch_with_score_by_vector(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ with_vectors: bool = False,
+ **kwargs: Any,
+ ) -> Any:
+ """Return results most similar to embedding vector."""
+ from qdrant_client import grpc # noqa
+ from qdrant_client.conversions.conversion import RestToGrpc
+ from qdrant_client.http import models as rest
+
+ if filter is not None and isinstance(filter, dict):
+ warnings.warn(
+ "Using dict as a `filter` is deprecated. Please use qdrant-client "
+ "filters directly: "
+ "https://qdrant.tech/documentation/concepts/filtering/",
+ DeprecationWarning,
+ )
+ qdrant_filter = self._qdrant_filter_from_dict(filter)
+ else:
+ qdrant_filter = filter
+
+ if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
+ qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
+
+ response = await self.client.async_grpc_points.Search(
+ grpc.SearchPoints(
+ collection_name=self.collection_name,
+ vector_name=self.vector_name,
+ vector=embedding,
+ filter=qdrant_filter,
+ params=search_params,
+ limit=k,
+ offset=offset,
+ with_payload=grpc.WithPayloadSelector(enable=True),
+ with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
+ score_threshold=score_threshold,
+ read_consistency=consistency,
+ **kwargs,
+ )
+ )
+ return response
+
+ @sync_call_fallback
+ async def asimilarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ offset: int = 0,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ offset:
+ Offset of the first result to return.
+ May be used to paginate results.
+ Note: large offset values may cause performance issues.
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to
+ QdrantClient.async_grpc_points.Search().
+
+ Returns:
+ List of documents most similar to the query text and distance for each.
+ """
+ response = await self._asearch_with_score_by_vector(
+ embedding,
+ k=k,
+ filter=filter,
+ search_params=search_params,
+ offset=offset,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+
+ return [
+ (
+ self._document_from_scored_point_grpc(
+ result, self.content_payload_key, self.metadata_payload_key
+ ),
+ result.score,
+ )
+ for result in response.result
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ query_embedding = self._embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ query_embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ search_params=search_params,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+
+ @sync_call_fallback
+ async def amax_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to
+ QdrantClient.async_grpc_points.Search().
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ query_embedding = self._embed_query(query)
+ return await self.amax_marginal_relevance_search_by_vector(
+ query_embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ search_params=search_params,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ results = self.max_marginal_relevance_search_with_score_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ search_params=search_params,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return list(map(itemgetter(0), results))
+
+ @sync_call_fallback
+ async def amax_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or smaller than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to
+ QdrantClient.async_grpc_points.Search().
+ Returns:
+ List of Documents selected by maximal marginal relevance and distance for
+ each.
+ """
+ results = await self.amax_marginal_relevance_search_with_score_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ search_params=search_params,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ return list(map(itemgetter(0), results))
+
+ def max_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter: Filter by metadata. Defaults to None.
+ search_params: Additional search params
+ score_threshold:
+ Define a minimal score threshold for the result.
+ If defined, less similar results will not be returned.
+ Score of the returned result might be higher or lower than the
+ threshold depending on the Distance function used.
+ E.g. for cosine similarity only higher scores will be returned.
+ consistency:
+ Read consistency of the search. Defines how many replicas should be
+ queried before returning the result.
+ Values:
+ - int - number of replicas to query, values should be present in all
+ queried replicas
+ - 'majority' - query all replicas, but return values present in the
+ majority of replicas
+ - 'quorum' - query the majority of replicas, return values present in
+ all of them
+ - 'all' - query all replicas, and return values present in all replicas
+ **kwargs:
+ Any other named arguments to pass through to QdrantClient.search()
+ Returns:
+ List of Documents selected by maximal marginal relevance and distance for
+ each.
+ """
+ query_vector = embedding
+ if self.vector_name is not None:
+ query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
+
+ results = self.client.search(
+ collection_name=self.collection_name,
+ query_vector=query_vector,
+ query_filter=filter,
+ search_params=search_params,
+ limit=fetch_k,
+ with_payload=True,
+ with_vectors=True,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ **kwargs,
+ )
+ embeddings = [
+ result.vector.get(self.vector_name) # type: ignore[index, union-attr]
+ if self.vector_name is not None
+ else result.vector
+ for result in results
+ ]
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+ )
+ return [
+ (
+ self._document_from_scored_point(
+ results[i], self.content_payload_key, self.metadata_payload_key
+ ),
+ results[i].score,
+ )
+ for i in mmr_selected
+ ]
+
+ @sync_call_fallback
+ async def amax_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[MetadataFilter] = None,
+ search_params: Optional[common_types.SearchParams] = None,
+ score_threshold: Optional[float] = None,
+ consistency: Optional[common_types.ReadConsistency] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Defaults to 20.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance and distance for
+ each.
+ """
+ from qdrant_client.conversions.conversion import GrpcToRest
+
+ response = await self._asearch_with_score_by_vector(
+ embedding,
+ k=fetch_k,
+ filter=filter,
+ search_params=search_params,
+ score_threshold=score_threshold,
+ consistency=consistency,
+ with_vectors=True,
+ **kwargs,
+ )
+ results = [
+ GrpcToRest.convert_vectors(result.vectors) for result in response.result
+ ]
+ embeddings: List[List[float]] = [
+ result.get(self.vector_name) # type: ignore
+ if isinstance(result, dict)
+ else result
+ for result in results
+ ]
+ mmr_selected: List[int] = maximal_marginal_relevance(
+ np.array(embedding),
+ embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ return [
+ (
+ self._document_from_scored_point_grpc(
+ response.result[i],
+ self.content_payload_key,
+ self.metadata_payload_key,
+ ),
+ response.result[i].score,
+ )
+ for i in mmr_selected
+ ]
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
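+
+ Example (illustrative sketch; assumes ``added_ids`` were returned by a
+ previous ``add_texts`` call on this instance):
+
+ .. code-block:: python
+
+ qdrant.delete(ids=added_ids)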
+ """
+ from qdrant_client.http import models as rest
+
+ result = self.client.delete(
+ collection_name=self.collection_name,
+ points_selector=ids,
+ )
+ return result.status == rest.UpdateStatus.COMPLETED
+
+ @classmethod
+ def from_texts(
+ cls: Type[Qdrant],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ location: Optional[str] = None,
+ url: Optional[str] = None,
+ port: Optional[int] = 6333,
+ grpc_port: int = 6334,
+ prefer_grpc: bool = False,
+ https: Optional[bool] = None,
+ api_key: Optional[str] = None,
+ prefix: Optional[str] = None,
+ timeout: Optional[float] = None,
+ host: Optional[str] = None,
+ path: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ distance_func: str = "Cosine",
+ content_payload_key: str = CONTENT_KEY,
+ metadata_payload_key: str = METADATA_KEY,
+ vector_name: Optional[str] = VECTOR_NAME,
+ batch_size: int = 64,
+ shard_number: Optional[int] = None,
+ replication_factor: Optional[int] = None,
+ write_consistency_factor: Optional[int] = None,
+ on_disk_payload: Optional[bool] = None,
+ hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+ optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+ wal_config: Optional[common_types.WalConfigDiff] = None,
+ quantization_config: Optional[common_types.QuantizationConfig] = None,
+ init_from: Optional[common_types.InitFrom] = None,
+ on_disk: Optional[bool] = None,
+ force_recreate: bool = False,
+ **kwargs: Any,
+ ) -> Qdrant:
+ """Construct Qdrant wrapper from a list of texts.
+
+ Args:
+ texts: A list of texts to be indexed in Qdrant.
+ embedding: A subclass of `Embeddings`, responsible for text vectorization.
+ metadatas:
+ An optional list of metadata. If provided it has to be of the same
+ length as a list of texts.
+ ids:
+ Optional list of ids to associate with the texts. Ids have to be
+ uuid-like strings.
+ location:
+ If `:memory:` - use in-memory Qdrant instance.
+ If `str` - use it as a `url` parameter.
+ If `None` - fallback to relying on `host` and `port` parameters.
+ url: either host or str of "Optional[scheme], host, Optional[port],
+ Optional[prefix]". Default: `None`
+ port: Port of the REST API interface. Default: 6333
+ grpc_port: Port of the gRPC interface. Default: 6334
+ prefer_grpc:
+ If true - use gRPC interface whenever possible in custom methods.
+ Default: False
+ https: If true - use HTTPS(SSL) protocol. Default: None
+ api_key: API key for authentication in Qdrant Cloud. Default: None
+ prefix:
+ If not None - add prefix to the REST URL path.
+ Example: service/v1 will result in
+ http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
+ Default: None
+ timeout:
+ Timeout for REST and gRPC API requests.
+ Default: 5.0 seconds for REST and unlimited for gRPC
+ host:
+ Host name of Qdrant service. If url and host are None, set to
+ 'localhost'. Default: None
+ path:
+ Path in which the vectors will be stored while using local mode.
+ Default: None
+ collection_name:
+ Name of the Qdrant collection to be used. If not provided,
+ a random name will be generated. Default: None
+ distance_func:
+ Distance function. One of: "Cosine" / "Euclid" / "Dot".
+ Default: "Cosine"
+ content_payload_key:
+ A payload key used to store the content of the document.
+ Default: "page_content"
+ metadata_payload_key:
+ A payload key used to store the metadata of the document.
+ Default: "metadata"
+ vector_name:
+ Name of the vector to be used internally in Qdrant.
+ Default: None
+ batch_size:
+ How many vectors to upload per request.
+ Default: 64
+ shard_number: Number of shards in collection. Default is 1, minimum is 1.
+ replication_factor:
+ Replication factor for collection. Default is 1, minimum is 1.
+ Defines how many copies of each shard will be created.
+ Has effect only in distributed mode.
+ write_consistency_factor:
+ Write consistency factor for collection. Default is 1, minimum is 1.
+ Defines how many replicas should apply the operation for us to consider
+ it successful. Increasing this number will make the collection more
+ resilient to inconsistencies, but will also make it fail if not enough
+ replicas are available.
+ Does not have any performance impact.
+ Has effect only in distributed mode.
+ on_disk_payload:
+ If true - the point's payload will not be stored in memory.
+ It will be read from the disk every time it is requested.
+ This setting saves RAM by (slightly) increasing the response time.
+ Note: those payload values that are involved in filtering and are
+ indexed - remain in RAM.
+ hnsw_config: Params for HNSW index
+ optimizers_config: Params for optimizer
+ wal_config: Params for Write-Ahead-Log
+ quantization_config:
+ Params for quantization, if None - quantization will be disabled
+ init_from:
+ Use data stored in another collection to initialize this collection
+ force_recreate:
+ Force recreating the collection
+ **kwargs:
+ Additional arguments passed directly into REST client initialization
+
+ This is a user-friendly interface that:
+ 1. Creates embeddings, one for each text
+ 2. Initializes the Qdrant database as an in-memory docstore by default
+ (and overridable to a remote docstore)
+ 3. Adds the text embeddings to the Qdrant database
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Qdrant
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ qdrant = Qdrant.from_texts(texts, embeddings, url="localhost")
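+
+ A further illustrative sketch (``texts`` is assumed to be defined), using an
+ in-memory instance with a named vector and forcing collection recreation:
+
+ .. code-block:: python
+
+ qdrant = Qdrant.from_texts(
+ texts,
+ embeddings,
+ location=":memory:",
+ collection_name="my_documents",
+ vector_name="content",
+ force_recreate=True,
+ )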
+ """
+ qdrant = cls.construct_instance(
+ texts,
+ embedding,
+ location,
+ url,
+ port,
+ grpc_port,
+ prefer_grpc,
+ https,
+ api_key,
+ prefix,
+ timeout,
+ host,
+ path,
+ collection_name,
+ distance_func,
+ content_payload_key,
+ metadata_payload_key,
+ vector_name,
+ shard_number,
+ replication_factor,
+ write_consistency_factor,
+ on_disk_payload,
+ hnsw_config,
+ optimizers_config,
+ wal_config,
+ quantization_config,
+ init_from,
+ on_disk,
+ force_recreate,
+ **kwargs,
+ )
+ qdrant.add_texts(texts, metadatas, ids, batch_size)
+ return qdrant
+
+ @classmethod
+ @sync_call_fallback
+ async def afrom_texts(
+ cls: Type[Qdrant],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ location: Optional[str] = None,
+ url: Optional[str] = None,
+ port: Optional[int] = 6333,
+ grpc_port: int = 6334,
+ prefer_grpc: bool = False,
+ https: Optional[bool] = None,
+ api_key: Optional[str] = None,
+ prefix: Optional[str] = None,
+ timeout: Optional[float] = None,
+ host: Optional[str] = None,
+ path: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ distance_func: str = "Cosine",
+ content_payload_key: str = CONTENT_KEY,
+ metadata_payload_key: str = METADATA_KEY,
+ vector_name: Optional[str] = VECTOR_NAME,
+ batch_size: int = 64,
+ shard_number: Optional[int] = None,
+ replication_factor: Optional[int] = None,
+ write_consistency_factor: Optional[int] = None,
+ on_disk_payload: Optional[bool] = None,
+ hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+ optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+ wal_config: Optional[common_types.WalConfigDiff] = None,
+ quantization_config: Optional[common_types.QuantizationConfig] = None,
+ init_from: Optional[common_types.InitFrom] = None,
+ on_disk: Optional[bool] = None,
+ force_recreate: bool = False,
+ **kwargs: Any,
+ ) -> Qdrant:
+ """Construct Qdrant wrapper from a list of texts.
+
+ Args:
+ texts: A list of texts to be indexed in Qdrant.
+ embedding: A subclass of `Embeddings`, responsible for text vectorization.
+ metadatas:
+ An optional list of metadata. If provided it has to be of the same
+ length as a list of texts.
+ ids:
+ Optional list of ids to associate with the texts. Ids have to be
+ uuid-like strings.
+ location:
+ If `:memory:` - use in-memory Qdrant instance.
+ If `str` - use it as a `url` parameter.
+ If `None` - fallback to relying on `host` and `port` parameters.
+ url: either host or str of "Optional[scheme], host, Optional[port],
+ Optional[prefix]". Default: `None`
+ port: Port of the REST API interface. Default: 6333
+ grpc_port: Port of the gRPC interface. Default: 6334
+ prefer_grpc:
+ If true - use gRPC interface whenever possible in custom methods.
+ Default: False
+ https: If true - use HTTPS(SSL) protocol. Default: None
+ api_key: API key for authentication in Qdrant Cloud. Default: None
+ prefix:
+ If not None - add prefix to the REST URL path.
+ Example: service/v1 will result in
+ http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
+ Default: None
+ timeout:
+ Timeout for REST and gRPC API requests.
+ Default: 5.0 seconds for REST and unlimited for gRPC
+ host:
+ Host name of Qdrant service. If url and host are None, set to
+ 'localhost'. Default: None
+ path:
+ Path in which the vectors will be stored while using local mode.
+ Default: None
+ collection_name:
+ Name of the Qdrant collection to be used. If not provided,
+ a random name will be generated. Default: None
+ distance_func:
+ Distance function. One of: "Cosine" / "Euclid" / "Dot".
+ Default: "Cosine"
+ content_payload_key:
+ A payload key used to store the content of the document.
+ Default: "page_content"
+ metadata_payload_key:
+ A payload key used to store the metadata of the document.
+ Default: "metadata"
+ vector_name:
+ Name of the vector to be used internally in Qdrant.
+ Default: None
+ batch_size:
+ How many vectors to upload per request.
+ Default: 64
+ shard_number: Number of shards in collection. Default is 1, minimum is 1.
+ replication_factor:
+ Replication factor for collection. Default is 1, minimum is 1.
+ Defines how many copies of each shard will be created.
+ Has effect only in distributed mode.
+ write_consistency_factor:
+ Write consistency factor for collection. Default is 1, minimum is 1.
+ Defines how many replicas should apply the operation for us to consider
+ it successful. Increasing this number will make the collection more
+ resilient to inconsistencies, but will also make it fail if not enough
+ replicas are available.
+ Does not have any performance impact.
+ Has effect only in distributed mode.
+ on_disk_payload:
+ If true - the point's payload will not be stored in memory.
+ It will be read from the disk every time it is requested.
+ This setting saves RAM by (slightly) increasing the response time.
+ Note: those payload values that are involved in filtering and are
+ indexed - remain in RAM.
+ hnsw_config: Params for HNSW index
+ optimizers_config: Params for optimizer
+ wal_config: Params for Write-Ahead-Log
+ quantization_config:
+ Params for quantization, if None - quantization will be disabled
+ init_from:
+ Use data stored in another collection to initialize this collection
+ force_recreate:
+ Force recreating the collection
+ **kwargs:
+ Additional arguments passed directly into REST client initialization
+
+ This is a user-friendly interface that:
+ 1. Creates embeddings, one for each text
+ 2. Initializes the Qdrant database as an in-memory docstore by default
+ (and overridable to a remote docstore)
+ 3. Adds the text embeddings to the Qdrant database
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Qdrant
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ qdrant = await Qdrant.afrom_texts(texts, embeddings, url="localhost")
+ """
+ qdrant = await cls.aconstruct_instance(
+ texts,
+ embedding,
+ location,
+ url,
+ port,
+ grpc_port,
+ prefer_grpc,
+ https,
+ api_key,
+ prefix,
+ timeout,
+ host,
+ path,
+ collection_name,
+ distance_func,
+ content_payload_key,
+ metadata_payload_key,
+ vector_name,
+ shard_number,
+ replication_factor,
+ write_consistency_factor,
+ on_disk_payload,
+ hnsw_config,
+ optimizers_config,
+ wal_config,
+ quantization_config,
+ init_from,
+ on_disk,
+ force_recreate,
+ **kwargs,
+ )
+ await qdrant.aadd_texts(texts, metadatas, ids, batch_size)
+ return qdrant
+
+ @classmethod
+ def construct_instance(
+ cls: Type[Qdrant],
+ texts: List[str],
+ embedding: Embeddings,
+ location: Optional[str] = None,
+ url: Optional[str] = None,
+ port: Optional[int] = 6333,
+ grpc_port: int = 6334,
+ prefer_grpc: bool = False,
+ https: Optional[bool] = None,
+ api_key: Optional[str] = None,
+ prefix: Optional[str] = None,
+ timeout: Optional[float] = None,
+ host: Optional[str] = None,
+ path: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ distance_func: str = "Cosine",
+ content_payload_key: str = CONTENT_KEY,
+ metadata_payload_key: str = METADATA_KEY,
+ vector_name: Optional[str] = VECTOR_NAME,
+ shard_number: Optional[int] = None,
+ replication_factor: Optional[int] = None,
+ write_consistency_factor: Optional[int] = None,
+ on_disk_payload: Optional[bool] = None,
+ hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+ optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+ wal_config: Optional[common_types.WalConfigDiff] = None,
+ quantization_config: Optional[common_types.QuantizationConfig] = None,
+ init_from: Optional[common_types.InitFrom] = None,
+ on_disk: Optional[bool] = None,
+ force_recreate: bool = False,
+ **kwargs: Any,
+ ) -> Qdrant:
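+ """Create a Qdrant client, validate any existing collection against the
+ requested vector configuration (or recreate it), and return a wrapper
+ instance without adding any texts."""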
+ try:
+ import qdrant_client
+ except ImportError:
+ raise ValueError(
+ "Could not import qdrant-client python package. "
+ "Please install it with `pip install qdrant-client`."
+ )
+ from grpc import RpcError
+ from qdrant_client.http import models as rest
+ from qdrant_client.http.exceptions import UnexpectedResponse
+
+ # Just do a single quick embedding to get vector size
+ partial_embeddings = embedding.embed_documents(texts[:1])
+ vector_size = len(partial_embeddings[0])
+ collection_name = collection_name or uuid.uuid4().hex
+ distance_func = distance_func.upper()
+ client = qdrant_client.QdrantClient(
+ location=location,
+ url=url,
+ port=port,
+ grpc_port=grpc_port,
+ prefer_grpc=prefer_grpc,
+ https=https,
+ api_key=api_key,
+ prefix=prefix,
+ timeout=timeout,
+ host=host,
+ path=path,
+ **kwargs,
+ )
+ try:
+ # Skip any validation in case of forced collection recreate.
+ if force_recreate:
+ raise ValueError
+
+ # Get the vector configuration of the existing collection and vector, if it
+ # was specified. If the old configuration does not match the current one,
+ # an exception is being thrown.
+ collection_info = client.get_collection(collection_name=collection_name)
+ current_vector_config = collection_info.config.params.vectors
+ if isinstance(current_vector_config, dict) and vector_name is not None:
+ if vector_name not in current_vector_config:
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} does not "
+ f"contain vector named {vector_name}. Did you mean one of the "
+ f"existing vectors: {', '.join(current_vector_config.keys())}? "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
+ elif isinstance(current_vector_config, dict) and vector_name is None:
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} uses named vectors. "
+ f"If you want to reuse it, please set `vector_name` to any of the "
+ f"existing named vectors: "
+ f"{', '.join(current_vector_config.keys())}." # noqa
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ elif (
+ not isinstance(current_vector_config, dict) and vector_name is not None
+ ):
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} doesn't use named "
+ f"vectors. If you want to reuse it, please set `vector_name` to "
+ f"`None`. If you want to recreate the collection, set "
+ f"`force_recreate` parameter to `True`."
+ )
+
+ # Check if the vector configuration has the same dimensionality.
+ if current_vector_config.size != vector_size: # type: ignore[union-attr]
+ raise QdrantException(
+ f"Existing Qdrant collection is configured for vectors with "
+ f"{current_vector_config.size} " # type: ignore[union-attr]
+ f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+
+ current_distance_func = (
+ current_vector_config.distance.name.upper() # type: ignore[union-attr]
+ )
+ if current_distance_func != distance_func:
+ raise QdrantException(
+ f"Existing Qdrant collection is configured for "
+ f"{current_distance_func} similarity, but requested "
+ f"{distance_func}. Please set `distance_func` parameter to "
+ f"`{current_distance_func}` if you want to reuse it. "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ except (UnexpectedResponse, RpcError, ValueError):
+ vectors_config = rest.VectorParams(
+ size=vector_size,
+ distance=rest.Distance[distance_func],
+ on_disk=on_disk,
+ )
+
+ # If vector name was provided, we're going to use the named vectors feature
+ # with just a single vector.
+ if vector_name is not None:
+ vectors_config = { # type: ignore[assignment]
+ vector_name: vectors_config,
+ }
+
+ client.recreate_collection(
+ collection_name=collection_name,
+ vectors_config=vectors_config,
+ shard_number=shard_number,
+ replication_factor=replication_factor,
+ write_consistency_factor=write_consistency_factor,
+ on_disk_payload=on_disk_payload,
+ hnsw_config=hnsw_config,
+ optimizers_config=optimizers_config,
+ wal_config=wal_config,
+ quantization_config=quantization_config,
+ init_from=init_from,
+ timeout=timeout, # type: ignore[arg-type]
+ )
+ qdrant = cls(
+ client=client,
+ collection_name=collection_name,
+ embeddings=embedding,
+ content_payload_key=content_payload_key,
+ metadata_payload_key=metadata_payload_key,
+ distance_strategy=distance_func,
+ vector_name=vector_name,
+ )
+ return qdrant
+
+ @classmethod
+ async def aconstruct_instance(
+ cls: Type[Qdrant],
+ texts: List[str],
+ embedding: Embeddings,
+ location: Optional[str] = None,
+ url: Optional[str] = None,
+ port: Optional[int] = 6333,
+ grpc_port: int = 6334,
+ prefer_grpc: bool = False,
+ https: Optional[bool] = None,
+ api_key: Optional[str] = None,
+ prefix: Optional[str] = None,
+ timeout: Optional[float] = None,
+ host: Optional[str] = None,
+ path: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ distance_func: str = "Cosine",
+ content_payload_key: str = CONTENT_KEY,
+ metadata_payload_key: str = METADATA_KEY,
+ vector_name: Optional[str] = VECTOR_NAME,
+ shard_number: Optional[int] = None,
+ replication_factor: Optional[int] = None,
+ write_consistency_factor: Optional[int] = None,
+ on_disk_payload: Optional[bool] = None,
+ hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+ optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+ wal_config: Optional[common_types.WalConfigDiff] = None,
+ quantization_config: Optional[common_types.QuantizationConfig] = None,
+ init_from: Optional[common_types.InitFrom] = None,
+ on_disk: Optional[bool] = None,
+ force_recreate: bool = False,
+ **kwargs: Any,
+ ) -> Qdrant:
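+ """Async variant of ``construct_instance``: embeds a single text to infer
+ the vector size, validates (or recreates) the collection, and returns a
+ wrapper instance without adding any texts."""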
+ try:
+ import qdrant_client
+ except ImportError:
+ raise ValueError(
+ "Could not import qdrant-client python package. "
+ "Please install it with `pip install qdrant-client`."
+ )
+ from grpc import RpcError
+ from qdrant_client.http import models as rest
+ from qdrant_client.http.exceptions import UnexpectedResponse
+
+ # Just do a single quick embedding to get vector size
+ partial_embeddings = await embedding.aembed_documents(texts[:1])
+ vector_size = len(partial_embeddings[0])
+ collection_name = collection_name or uuid.uuid4().hex
+ distance_func = distance_func.upper()
+ client = qdrant_client.QdrantClient(
+ location=location,
+ url=url,
+ port=port,
+ grpc_port=grpc_port,
+ prefer_grpc=prefer_grpc,
+ https=https,
+ api_key=api_key,
+ prefix=prefix,
+ timeout=timeout,
+ host=host,
+ path=path,
+ **kwargs,
+ )
+ try:
+ # Skip any validation in case of forced collection recreate.
+ if force_recreate:
+ raise ValueError
+
+ # Get the vector configuration of the existing collection and vector, if it
+ # was specified. If the old configuration does not match the current one,
+ # an exception is being thrown.
+ collection_info = client.get_collection(collection_name=collection_name)
+ current_vector_config = collection_info.config.params.vectors
+ if isinstance(current_vector_config, dict) and vector_name is not None:
+ if vector_name not in current_vector_config:
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} does not "
+ f"contain vector named {vector_name}. Did you mean one of the "
+ f"existing vectors: {', '.join(current_vector_config.keys())}? "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
+ elif isinstance(current_vector_config, dict) and vector_name is None:
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} uses named vectors. "
+ f"If you want to reuse it, please set `vector_name` to any of the "
+ f"existing named vectors: "
+ f"{', '.join(current_vector_config.keys())}." # noqa
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ elif (
+ not isinstance(current_vector_config, dict) and vector_name is not None
+ ):
+ raise QdrantException(
+ f"Existing Qdrant collection {collection_name} doesn't use named "
+ f"vectors. If you want to reuse it, please set `vector_name` to "
+ f"`None`. If you want to recreate the collection, set "
+ f"`force_recreate` parameter to `True`."
+ )
+
+ # Check if the vector configuration has the same dimensionality.
+ if current_vector_config.size != vector_size: # type: ignore[union-attr]
+ raise QdrantException(
+ f"Existing Qdrant collection is configured for vectors with "
+ f"{current_vector_config.size} " # type: ignore[union-attr]
+ f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+
+ current_distance_func = (
+ current_vector_config.distance.name.upper() # type: ignore[union-attr]
+ )
+ if current_distance_func != distance_func:
+ raise QdrantException(
+ f"Existing Qdrant collection is configured for "
+ f"{current_distance_func} similarity, but requested "
+ f"{distance_func}. Please set `distance_func` parameter to "
+ f"`{current_distance_func}` if you want to reuse it. "
+ f"If you want to recreate the collection, set `force_recreate` "
+ f"parameter to `True`."
+ )
+ except (UnexpectedResponse, RpcError, ValueError):
+ vectors_config = rest.VectorParams(
+ size=vector_size,
+ distance=rest.Distance[distance_func],
+ on_disk=on_disk,
+ )
+
+ # If vector name was provided, we're going to use the named vectors feature
+ # with just a single vector.
+ if vector_name is not None:
+ vectors_config = { # type: ignore[assignment]
+ vector_name: vectors_config,
+ }
+
+ client.recreate_collection(
+ collection_name=collection_name,
+ vectors_config=vectors_config,
+ shard_number=shard_number,
+ replication_factor=replication_factor,
+ write_consistency_factor=write_consistency_factor,
+ on_disk_payload=on_disk_payload,
+ hnsw_config=hnsw_config,
+ optimizers_config=optimizers_config,
+ wal_config=wal_config,
+ quantization_config=quantization_config,
+ init_from=init_from,
+ timeout=timeout, # type: ignore[arg-type]
+ )
+ qdrant = cls(
+ client=client,
+ collection_name=collection_name,
+ embeddings=embedding,
+ content_payload_key=content_payload_key,
+ metadata_payload_key=metadata_payload_key,
+ distance_strategy=distance_func,
+ vector_name=vector_name,
+ )
+ return qdrant
+
+ @staticmethod
+ def _cosine_relevance_score_fn(distance: float) -> float:
+ """Normalize the distance to a score on a scale [0, 1]."""
+ return (distance + 1.0) / 2.0
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
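+
+ As an illustrative example, with cosine similarity Qdrant returns scores in
+ [-1, 1], and ``_cosine_relevance_score_fn`` rescales them onto [0, 1]:
+
+ .. code-block:: python
+
+ Qdrant._cosine_relevance_score_fn(1.0)   # -> 1.0 (same direction)
+ Qdrant._cosine_relevance_score_fn(0.0)   # -> 0.5 (orthogonal)
+ Qdrant._cosine_relevance_score_fn(-1.0)  # -> 0.0 (opposite direction)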
+ """
+
+ if self.distance_strategy == "COSINE":
+ return self._cosine_relevance_score_fn
+ elif self.distance_strategy == "DOT":
+ return self._max_inner_product_relevance_score_fn
+ elif self.distance_strategy == "EUCLID":
+ return self._euclidean_relevance_score_fn
+ else:
+ raise ValueError(
+ "Unknown distance strategy, must be cosine, "
+ "max_inner_product, or euclidean"
+ )
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and relevance scores in the range [0, 1].
+
+ 0 is dissimilar, 1 is most similar.
+
+ Args:
+ query: input text
+ k: Number of Documents to return. Defaults to 4.
+ **kwargs: kwargs to be passed to similarity search. Should include:
+ score_threshold: Optional, a floating point value between 0 and 1 to
+ filter the resulting set of retrieved docs
+
+ Returns:
+ List of Tuples of (doc, similarity_score)
+ """
+ return self.similarity_search_with_score(query, k, **kwargs)
+
+ @classmethod
+ def _build_payloads(
+ cls,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]],
+ content_payload_key: str,
+ metadata_payload_key: str,
+ ) -> List[dict]:
+ payloads = []
+ for i, text in enumerate(texts):
+ if text is None:
+ raise ValueError(
+ "At least one of the texts is None. Please remove it before "
+ "calling .from_texts or .add_texts on Qdrant instance."
+ )
+ metadata = metadatas[i] if metadatas is not None else None
+ payloads.append(
+ {
+ content_payload_key: text,
+ metadata_payload_key: metadata,
+ }
+ )
+
+ return payloads
+
+ @classmethod
+ def _document_from_scored_point(
+ cls,
+ scored_point: Any,
+ content_payload_key: str,
+ metadata_payload_key: str,
+ ) -> Document:
+ return Document(
+ page_content=scored_point.payload.get(content_payload_key),
+ metadata=scored_point.payload.get(metadata_payload_key) or {},
+ )
+
+ @classmethod
+ def _document_from_scored_point_grpc(
+ cls,
+ scored_point: Any,
+ content_payload_key: str,
+ metadata_payload_key: str,
+ ) -> Document:
+ from qdrant_client.conversions.conversion import grpc_to_payload
+
+ payload = grpc_to_payload(scored_point.payload)
+ return Document(
+ page_content=payload[content_payload_key],
+ metadata=payload.get(metadata_payload_key) or {},
+ )
+
+ def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
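+ """Recursively translate a metadata key/value pair into Qdrant field
+ conditions, prefixing each key with the metadata payload key."""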
+ from qdrant_client.http import models as rest
+
+ out = []
+
+ if isinstance(value, dict):
+ for _key, value in value.items():
+ out.extend(self._build_condition(f"{key}.{_key}", value))
+ elif isinstance(value, list):
+ for _value in value:
+ if isinstance(_value, dict):
+ out.extend(self._build_condition(f"{key}[]", _value))
+ else:
+ out.extend(self._build_condition(f"{key}", _value))
+ else:
+ out.append(
+ rest.FieldCondition(
+ key=f"{self.metadata_payload_key}.{key}",
+ match=rest.MatchValue(value=value),
+ )
+ )
+
+ return out
+
+ def _qdrant_filter_from_dict(
+ self, filter: Optional[DictFilter]
+ ) -> Optional[rest.Filter]:
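+ """Build a Qdrant filter requiring every key/value pair in ``filter`` to match."""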
+ from qdrant_client.http import models as rest
+
+ if not filter:
+ return None
+
+ return rest.Filter(
+ must=[
+ condition
+ for key, value in filter.items()
+ for condition in self._build_condition(key, value)
+ ]
+ )
+
+ def _embed_query(self, query: str) -> List[float]:
+ """Embed query text.
+
+ Used to provide backward compatibility with `embedding_function` argument.
+
+ Args:
+ query: Query text.
+
+ Returns:
+ List of floats representing the query embedding.
+ """
+ if self.embeddings is not None:
+ embedding = self.embeddings.embed_query(query)
+ else:
+ if self._embeddings_function is not None:
+ embedding = self._embeddings_function(query)
+ else:
+ raise ValueError("Neither of embeddings or embedding_function is set")
+ return embedding.tolist() if hasattr(embedding, "tolist") else embedding
+
+ def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
+ """Embed search texts.
+
+ Used to provide backward compatibility with `embedding_function` argument.
+
+ Args:
+ texts: Iterable of texts to embed.
+
+ Returns:
+ List of embedding vectors, one for each input text.
+ """
+ if self.embeddings is not None:
+ embeddings = self.embeddings.embed_documents(list(texts))
+ if hasattr(embeddings, "tolist"):
+ embeddings = embeddings.tolist()
+ elif self._embeddings_function is not None:
+ embeddings = []
+ for text in texts:
+ embedding = self._embeddings_function(text)
+ if hasattr(embedding, "tolist"):
+ embedding = embedding.tolist()
+ embeddings.append(embedding)
+ else:
+ raise ValueError("Neither of embeddings or embedding_function is set")
+
+ return embeddings
+
+ async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]:
+ """Embed search texts.
+
+ Used to provide backward compatibility with `embedding_function` argument.
+
+ Args:
+ texts: Iterable of texts to embed.
+
+ Returns:
+ List of embedding vectors, one for each input text.
+ """
+ if self.embeddings is not None:
+ embeddings = await self.embeddings.aembed_documents(list(texts))
+ if hasattr(embeddings, "tolist"):
+ embeddings = embeddings.tolist()
+ elif self._embeddings_function is not None:
+ embeddings = []
+ for text in texts:
+ embedding = self._embeddings_function(text)
+ if hasattr(embedding, "tolist"):
+ embedding = embedding.tolist()
+ embeddings.append(embedding)
+ else:
+ raise ValueError("Neither of embeddings or embedding_function is set")
+
+ return embeddings
+
+ def _generate_rest_batches(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ batch_size: int = 64,
+ ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
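+ """Lazily embed ``texts`` and yield ``(ids, points)`` batches of at most
+ ``batch_size`` REST ``PointStruct`` objects."""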
+ from qdrant_client.http import models as rest
+
+ texts_iterator = iter(texts)
+ metadatas_iterator = iter(metadatas or [])
+ ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+ while batch_texts := list(islice(texts_iterator, batch_size)):
+ # Take the corresponding metadata and id for each text in a batch
+ batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+ batch_ids = list(islice(ids_iterator, batch_size))
+
+ # Generate the embeddings for all the texts in a batch
+ batch_embeddings = self._embed_texts(batch_texts)
+
+ points = [
+ rest.PointStruct(
+ id=point_id,
+ vector=vector
+ if self.vector_name is None
+ else {self.vector_name: vector},
+ payload=payload,
+ )
+ for point_id, vector, payload in zip(
+ batch_ids,
+ batch_embeddings,
+ self._build_payloads(
+ batch_texts,
+ batch_metadatas,
+ self.content_payload_key,
+ self.metadata_payload_key,
+ ),
+ )
+ ]
+
+ yield batch_ids, points
+
+ async def _agenerate_rest_batches(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[Sequence[str]] = None,
+ batch_size: int = 64,
+ ) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]:
+ from qdrant_client.http import models as rest
+
+ texts_iterator = iter(texts)
+ metadatas_iterator = iter(metadatas or [])
+ ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+ while batch_texts := list(islice(texts_iterator, batch_size)):
+ # Take the corresponding metadata and id for each text in a batch
+ batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+ batch_ids = list(islice(ids_iterator, batch_size))
+
+ # Generate the embeddings for all the texts in a batch
+ batch_embeddings = await self._aembed_texts(batch_texts)
+
+ points = [
+ rest.PointStruct(
+ id=point_id,
+ vector=vector
+ if self.vector_name is None
+ else {self.vector_name: vector},
+ payload=payload,
+ )
+ for point_id, vector, payload in zip(
+ batch_ids,
+ batch_embeddings,
+ self._build_payloads(
+ batch_texts,
+ batch_metadatas,
+ self.content_payload_key,
+ self.metadata_payload_key,
+ ),
+ )
+ ]
+
+ yield batch_ids, points
diff --git a/libs/community/langchain_community/vectorstores/redis/__init__.py b/libs/community/langchain_community/vectorstores/redis/__init__.py
new file mode 100644
index 00000000000..dc088facf4f
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/redis/__init__.py
@@ -0,0 +1,16 @@
+from .base import Redis, RedisVectorStoreRetriever
+from .filters import (
+ RedisFilter,
+ RedisNum,
+ RedisTag,
+ RedisText,
+)
+
+__all__ = [
+ "Redis",
+ "RedisFilter",
+ "RedisTag",
+ "RedisText",
+ "RedisNum",
+ "RedisVectorStoreRetriever",
+]
diff --git a/libs/community/langchain_community/vectorstores/redis/base.py b/libs/community/langchain_community/vectorstores/redis/base.py
new file mode 100644
index 00000000000..6b7732bb4e1
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/redis/base.py
@@ -0,0 +1,1475 @@
+"""Wrapper around Redis vector database."""
+
+from __future__ import annotations
+
+import logging
+import os
+import uuid
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+import numpy as np
+import yaml
+from langchain_core._api import deprecated
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+
+from langchain_community.utilities.redis import (
+ _array_to_buffer,
+ _buffer_to_array,
+ check_redis_module_exist,
+ get_client,
+)
+from langchain_community.vectorstores.redis.constants import (
+ REDIS_REQUIRED_MODULES,
+ REDIS_TAG_SEPARATOR,
+)
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from redis.client import Redis as RedisType
+ from redis.commands.search.query import Query
+
+ from langchain_community.vectorstores.redis.filters import RedisFilterExpression
+ from langchain_community.vectorstores.redis.schema import RedisModel
+
+
+def _default_relevance_score(val: float) -> float:
+ return 1 - val
+
+
+def check_index_exists(client: RedisType, index_name: str) -> bool:
+ """Check if Redis index exists."""
+ try:
+ client.ft(index_name).info()
+ except: # noqa: E722
+ logger.debug("Index does not exist")
+ return False
+ logger.debug("Index already exists")
+ return True
+
+
+class Redis(VectorStore):
+ """Redis vector database.
+
+ To use, you should have the ``redis`` python package installed
+ and have a running Redis Enterprise or Redis-Stack server.
+
+ For production use cases, it is recommended to use Redis Enterprise
+ as the scaling, performance, stability, and availability are much
+ better than Redis-Stack.
+
+ For testing and prototyping, however, this is not required.
+ Redis-Stack is available as a docker container with the full vector
+ search API available.
+
+ .. code-block:: bash
+
+ # to run redis stack in docker locally
+ docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
+
+ Once running, you can connect to the redis server with the following url schemas:
+ - redis://<host>:<port> # simple connection
+ - redis://<username>:<password>@<host>:<port> # connection with authentication
+ - rediss://<host>:<port> # connection with SSL
+ - rediss://<username>:<password>@<host>:<port> # connection with SSL and auth
+
+
+ Examples:
+
+ The following examples show various ways to use the Redis VectorStore with
+ LangChain.
+
+ For all the following examples assume we have the following imports:
+
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Redis
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ Initialize, create index, and load Documents
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Redis
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ rds = Redis.from_documents(
+ documents, # a list of Document objects from loaders or created
+ embeddings, # an Embeddings object
+ redis_url="redis://localhost:6379",
+ )
+
+ Initialize, create index, and load Documents with metadata
+ .. code-block:: python
+
+
+ rds = Redis.from_texts(
+ texts, # a list of strings
+ metadata, # a list of metadata dicts
+ embeddings, # an Embeddings object
+ redis_url="redis://localhost:6379",
+ )
+
+ Initialize, create index, and load Documents with metadata and return keys
+
+ .. code-block:: python
+
+ rds, keys = Redis.from_texts_return_keys(
+ texts, # a list of strings
+ metadata, # a list of metadata dicts
+ embeddings, # an Embeddings object
+ redis_url="redis://localhost:6379",
+ )
+
+ For use cases where the index needs to stay alive, you can initialize
+ with an index name such that it's easier to reference later
+
+ .. code-block:: python
+
+ rds = Redis.from_texts(
+ texts, # a list of strings
+ metadata, # a list of metadata dicts
+ embeddings, # an Embeddings object
+ index_name="my-index",
+ redis_url="redis://localhost:6379",
+ )
+
+ Initialize and connect to an existing index (from above)
+
+ .. code-block:: python
+
+ # must pass in schema and key_prefix from another index
+ existing_rds = Redis.from_existing_index(
+ embeddings, # an Embeddings object
+ index_name="my-index",
+ schema=rds.schema, # schema dumped from another index
+ key_prefix=rds.key_prefix, # key prefix from another index
+ redis_url="redis://localhost:6379",
+ )
+
+
+ Advanced examples:
+
+ Custom vector schema can be supplied to change the way that
+ Redis creates the underlying vector schema. This is useful
+ for production use cases where you want to optimize the
+ vector schema for your use case. ex. using HNSW instead of
+ FLAT (knn) which is the default
+
+ .. code-block:: python
+
+ vector_schema = {
+ "algorithm": "HNSW"
+ }
+
+ rds = Redis.from_texts(
+ texts, # a list of strings
+ metadata, # a list of metadata dicts
+ embeddings, # an Embeddings object
+ vector_schema=vector_schema,
+ redis_url="redis://localhost:6379",
+ )
+
+ Custom index schema can be supplied to change the way that the
+ metadata is indexed. This is useful if you would like to use the
+ hybrid querying (filtering) capability of Redis.
+
+ By default, this implementation will automatically generate the index
+ schema according to the following rules:
+ - All strings are indexed as text fields
+ - All numbers are indexed as numeric fields
+ - All lists of strings are indexed as tag fields (joined by
+ langchain_community.vectorstores.redis.constants.REDIS_TAG_SEPARATOR)
+ - All None values are not indexed but still stored in Redis. These are
+ not retrievable through the interface here, but the raw Redis client
+ can be used to retrieve them.
+ - All other types are not indexed
+
+ To override these rules, you can pass in a custom index schema like the following
+
+ .. code-block:: yaml
+
+ tag:
+ - name: credit_score
+ text:
+ - name: user
+ - name: job
+
+ Typically, the ``credit_score`` field would be a text field since it's a string;
+ however, we can override this behavior by specifying the field type as shown with
+ the yaml config (can also be a dictionary) above and the code below.
+
+ .. code-block:: python
+
+ rds = Redis.from_texts(
+ texts, # a list of strings
+ metadata, # a list of metadata dicts
+ embeddings, # an Embeddings object
+ index_schema="path/to/index_schema.yaml", # can also be a dictionary
+ redis_url="redis://localhost:6379",
+ )
+
+ When connecting to an existing index where a custom schema has been applied, it's
+ important to pass in the same schema to the ``from_existing_index`` method.
+ Otherwise, the schema for newly added samples will be incorrect and metadata
+ will not be returned.
+
+ """
+
+ DEFAULT_VECTOR_SCHEMA = {
+ "name": "content_vector",
+ "algorithm": "FLAT",
+ "dims": 1536,
+ "distance_metric": "COSINE",
+ "datatype": "FLOAT32",
+ }
+
+ def __init__(
+ self,
+ redis_url: str,
+ index_name: str,
+ embedding: Embeddings,
+ index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] = None,
+ vector_schema: Optional[Dict[str, Union[str, int]]] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ key_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ """Initialize Redis vector store with necessary components."""
+ self._check_deprecated_kwargs(kwargs)
+ try:
+ # TODO use importlib to check if redis is installed
+ import redis # noqa: F401
+
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+
+ self.index_name = index_name
+ self._embeddings = embedding
+ try:
+ redis_client = get_client(redis_url=redis_url, **kwargs)
+ # check if redis has redisearch module installed
+ check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
+ except ValueError as e:
+ raise ValueError(f"Redis failed to connect: {e}")
+
+ self.client = redis_client
+ self.relevance_score_fn = relevance_score_fn
+ self._schema = self._get_schema_with_defaults(index_schema, vector_schema)
+ self.key_prefix = key_prefix if key_prefix is not None else f"doc:{index_name}"
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ """Access the query embedding object if available."""
+ return self._embeddings
+
+ @classmethod
+ def from_texts_return_keys(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ index_name: Optional[str] = None,
+ index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] = None,
+ vector_schema: Optional[Dict[str, Union[str, int]]] = None,
+ **kwargs: Any,
+ ) -> Tuple[Redis, List[str]]:
+ """Create a Redis vectorstore from raw documents.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Creates a new Redis index if it doesn't already exist
+ 3. Adds the documents to the newly created Redis index.
+ 4. Returns the keys of the newly created documents once stored.
+
+ This method will generate a schema based on the metadata passed in
+ if the `index_schema` is not defined. If the `index_schema` is defined,
+ it will compare against the generated schema and warn if there are
+ differences. If you are purposefully defining the schema for the
+ metadata, then you can ignore that warning.
+
+ To examine the schema options, initialize an instance of this class
+ and print out the schema using the ``Redis.schema`` property. This
+ will include the content and content_vector classes which are
+ always present in the langchain schema.
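+
+ For instance (illustrative sketch; ``redis`` is assumed to be an already
+ constructed instance of this class):
+
+ .. code-block:: python
+
+ print(redis.schema)                      # dict form of the index schema
+ redis.write_schema("redis_schema.yaml")  # dump the schema to a yaml file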
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Redis
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ redis, keys = Redis.from_texts_return_keys(
+ texts,
+ embeddings,
+ redis_url="redis://localhost:6379"
+ )
+
+ Args:
+ texts (List[str]): List of texts to add to the vectorstore.
+ embedding (Embeddings): Embeddings to use for the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadata
+ dicts to add to the vectorstore. Defaults to None.
+ index_name (Optional[str], optional): Optional name of the index to
+ create or add to. Defaults to None.
+ index_schema (Optional[Union[Dict[str, str], str, os.PathLike]], optional):
+ Optional fields to index within the metadata. Overrides generated
+ schema. Defaults to None.
+ vector_schema (Optional[Dict[str, Union[str, int]]], optional): Optional
+ vector schema to use. Defaults to None.
+ **kwargs (Any): Additional keyword arguments to pass to the Redis client.
+
+ Returns:
+ Tuple[Redis, List[str]]: Tuple of the Redis instance and the keys of
+ the newly created documents.
+
+ Raises:
+ ValueError: If the number of metadatas does not match the number of texts.
+ """
+ try:
+ # TODO use importlib to check if redis is installed
+ import redis # noqa: F401
+
+ from langchain_community.vectorstores.redis.schema import read_schema
+
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+
+ redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
+
+ if "redis_url" in kwargs:
+ kwargs.pop("redis_url")
+
+ # flag to use generated schema
+ if "generate" in kwargs:
+ kwargs.pop("generate")
+
+ # see if the user specified keys
+ keys = None
+ if "keys" in kwargs:
+ keys = kwargs.pop("keys")
+
+ # Name of the search index if not given
+ if not index_name:
+ index_name = uuid.uuid4().hex
+
+ # type check for metadata
+ if metadatas:
+ if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore # noqa: E501
+ raise ValueError("Number of metadatas must match number of texts")
+ if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
+ raise ValueError("Metadatas must be a list of dicts")
+
+ generated_schema = _generate_field_schema(metadatas[0])
+ if index_schema:
+ # read in the schema solely to compare to the generated schema
+ user_schema = read_schema(index_schema) # type: ignore
+
+ # the very rare case where a super user decides to pass the index
+ # schema and a document loader is used that has metadata which
+ # we need to map into fields.
+ if user_schema != generated_schema:
+ logger.warning(
+ "`index_schema` does not match generated metadata schema.\n"
+ + "If you meant to manually override the schema, please "
+ + "ignore this message.\n"
+ + f"index_schema: {user_schema}\n"
+ + f"generated_schema: {generated_schema}\n"
+ )
+ else:
+ # use the generated schema
+ index_schema = generated_schema
+
+ # Create instance
+ # init the class -- if Redis is unavailable, will throw exception
+ instance = cls(
+ redis_url,
+ index_name,
+ embedding,
+ index_schema=index_schema,
+ vector_schema=vector_schema,
+ **kwargs,
+ )
+
+ # Add data to Redis
+ keys = instance.add_texts(texts, metadatas, keys=keys)
+ return instance, keys
+
+ @classmethod
+ def from_texts(
+ cls: Type[Redis],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ index_name: Optional[str] = None,
+ index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] = None,
+ vector_schema: Optional[Dict[str, Union[str, int]]] = None,
+ **kwargs: Any,
+ ) -> Redis:
+ """Create a Redis vectorstore from a list of texts.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Creates a new Redis index if it doesn't already exist
+ 3. Adds the documents to the newly created Redis index.
+
+ This method will generate a schema based on the metadata passed in
+ if the `index_schema` is not defined. If the `index_schema` is defined,
+ it will compare against the generated schema and warn if there are
+ differences. If you are purposefully defining the schema for the
+ metadata, then you can ignore that warning.
+
+ To examine the schema options, initialize an instance of this class
+ and print out the schema using the ``Redis.schema`` property. This
+ will include the content and content_vector classes which are
+ always present in the langchain schema.
+
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Redis
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ redis = Redis.from_texts(
+ texts,
+ embeddings,
+ redis_url="redis://username:password@localhost:6379"
+ )
+
+ Args:
+ texts (List[str]): List of texts to add to the vectorstore.
+ embedding (Embeddings): Embedding model class (i.e. OpenAIEmbeddings)
+ for embedding queries.
+ metadatas (Optional[List[dict]], optional): Optional list of metadata dicts
+ to add to the vectorstore. Defaults to None.
+ index_name (Optional[str], optional): Optional name of the index to create
+ or add to. Defaults to None.
+ index_schema (Optional[Union[Dict[str, str], str, os.PathLike]], optional):
+ Optional fields to index within the metadata. Overrides generated
+ schema. Defaults to None.
+ vector_schema (Optional[Dict[str, Union[str, int]]], optional): Optional
+ vector schema to use. Defaults to None.
+ **kwargs (Any): Additional keyword arguments to pass to the Redis client.
+
+ Returns:
+ Redis: Redis VectorStore instance.
+
+ Raises:
+ ValueError: If the number of metadatas does not match the number of texts.
+ ImportError: If the redis python package is not installed.
+ """
+ instance, _ = cls.from_texts_return_keys(
+ texts,
+ embedding,
+ metadatas=metadatas,
+ index_name=index_name,
+ index_schema=index_schema,
+ vector_schema=vector_schema,
+ **kwargs,
+ )
+ return instance
+
+ @classmethod
+ def from_existing_index(
+ cls,
+ embedding: Embeddings,
+ index_name: str,
+ schema: Union[Dict[str, str], str, os.PathLike],
+ key_prefix: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Redis:
+ """Connect to an existing Redis index.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Redis
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+
+ # must pass in schema and key_prefix from another index
+ existing_rds = Redis.from_existing_index(
+ embeddings,
+ index_name="my-index",
+ schema=rds.schema, # schema dumped from another index
+ key_prefix=rds.key_prefix, # key prefix from another index
+ redis_url="redis://username:password@localhost:6379",
+ )
+
+ Args:
+ embedding (Embeddings): Embedding model class (i.e. OpenAIEmbeddings)
+ for embedding queries.
+ index_name (str): Name of the index to connect to.
+ schema (Union[Dict[str, str], str, os.PathLike]): Schema of the index
+ and the vector schema. Can be a dict, or path to yaml file.
+ key_prefix (Optional[str]): Prefix to use for all keys in Redis associated
+ with this index.
+ **kwargs (Any): Additional keyword arguments to pass to the Redis client.
+
+ Returns:
+ Redis: Redis VectorStore instance.
+
+ Raises:
+ ValueError: If the index does not exist.
+ ImportError: If the redis python package is not installed.
+ """
+ redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
+ # We need to first remove redis_url from kwargs,
+ # otherwise passing it to Redis will result in an error.
+ if "redis_url" in kwargs:
+ kwargs.pop("redis_url")
+
+ # Create instance
+ # init the class -- if Redis is unavailable, will throw exception
+ instance = cls(
+ redis_url,
+ index_name,
+ embedding,
+ index_schema=schema,
+ key_prefix=key_prefix,
+ **kwargs,
+ )
+
+ # Check for existence of the declared index
+ if not check_index_exists(instance.client, index_name):
+ # Will only raise if the running Redis server does not
+ # have a record of this particular index
+ raise ValueError(
+ f"Redis failed to connect: Index {index_name} does not exist."
+ )
+
+ return instance
+
+ @property
+ def schema(self) -> Dict[str, List[Any]]:
+ """Return the schema of the index."""
+ return self._schema.as_dict()
+
+ def write_schema(self, path: Union[str, os.PathLike]) -> None:
+ """Write the schema to a yaml file."""
+ with open(path, "w+") as f:
+ yaml.dump(self.schema, f)
+
+ @staticmethod
+ def delete(
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> bool:
+ """
+ Delete a Redis entry.
+
+ Args:
+ ids: List of ids (keys in redis) to delete.
+ redis_url: Redis connection url. This should be passed in the kwargs
+ or set as an environment variable: REDIS_URL.
+
+ Returns:
+ bool: Whether or not the deletions were successful.
+
+ Raises:
+ ValueError: If the redis python package is not installed.
+ ValueError: If the ids (keys in redis) are not provided
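+
+ Example (illustrative sketch; assumes ``keys`` were returned by ``add_texts``
+ or ``from_texts_return_keys``):
+
+ .. code-block:: python
+
+ Redis.delete(ids=keys, redis_url="redis://localhost:6379")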
+ """
+ redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
+
+ if ids is None:
+ raise ValueError("'ids' (keys)() were not provided.")
+
+ try:
+ import redis # noqa: F401
+ except ImportError:
+ raise ValueError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+ try:
+ # We need to first remove redis_url from kwargs,
+ # otherwise passing it to Redis will result in an error.
+ if "redis_url" in kwargs:
+ kwargs.pop("redis_url")
+ client = get_client(redis_url=redis_url, **kwargs)
+ except ValueError as e:
+ raise ValueError(f"Your redis connected error: {e}")
+ # Check if index exists
+ try:
+ client.delete(*ids)
+ logger.info("Entries deleted")
+ return True
+ except: # noqa: E722
+            # the provided ids (keys) do not exist
+ return False
+
+ @staticmethod
+ def drop_index(
+ index_name: str,
+ delete_documents: bool,
+ **kwargs: Any,
+ ) -> bool:
+ """
+ Drop a Redis search index.
+
+ Args:
+ index_name (str): Name of the index to drop.
+            delete_documents (bool): Whether to drop the associated documents.
+            redis_url: Redis connection url. This should be passed in the kwargs
+                or set as an environment variable: REDIS_URL.
+
+ Returns:
+ bool: Whether or not the drop was successful.
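+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: drops a hypothetical index named "my-index"
+                # and its documents; adjust the connection URL as needed.
+                Redis.drop_index(
+                    index_name="my-index",
+                    delete_documents=True,
+                    redis_url="redis://localhost:6379",
+                )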
+ """
+ redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
+ try:
+ import redis # noqa: F401
+ except ImportError:
+ raise ValueError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+ try:
+ # We need to first remove redis_url from kwargs,
+ # otherwise passing it to Redis will result in an error.
+ if "redis_url" in kwargs:
+ kwargs.pop("redis_url")
+ client = get_client(redis_url=redis_url, **kwargs)
+ except ValueError as e:
+ raise ValueError(f"Your redis connected error: {e}")
+ # Check if index exists
+ try:
+ client.ft(index_name).dropindex(delete_documents)
+ logger.info("Drop index")
+ return True
+ except: # noqa: E722
+            # Index does not exist
+ return False
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ embeddings: Optional[List[List[float]]] = None,
+ batch_size: int = 1000,
+ clean_metadata: bool = True,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add more texts to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ Defaults to None.
+ embeddings (Optional[List[List[float]]], optional): Optional pre-generated
+ embeddings. Defaults to None.
+ keys (List[str]) or ids (List[str]): Identifiers of entries.
+ Defaults to None.
+ batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
+
+ Returns:
+ List[str]: List of ids added to the vectorstore
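+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``rds`` is assumed to be an existing
+                # Redis vectorstore instance.
+                ids = rds.add_texts(
+                    ["foo", "bar"],
+                    metadatas=[{"source": "a"}, {"source": "b"}],
+                )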
+ """
+ ids = []
+
+ # Get keys or ids from kwargs
+ # Other vectorstores use ids
+ keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
+
+ # type check for metadata
+ if metadatas:
+ if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore # noqa: E501
+ raise ValueError("Number of metadatas must match number of texts")
+ if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
+ raise ValueError("Metadatas must be a list of dicts")
+
+ embeddings = embeddings or self._embeddings.embed_documents(list(texts))
+ self._create_index_if_not_exist(dim=len(embeddings[0]))
+
+ # Write data to redis
+ pipeline = self.client.pipeline(transaction=False)
+ for i, text in enumerate(texts):
+ # Use provided values by default or fallback
+ key = keys_or_ids[i] if keys_or_ids else str(uuid.uuid4().hex)
+ if not key.startswith(self.key_prefix + ":"):
+ key = self.key_prefix + ":" + key
+ metadata = metadatas[i] if metadatas else {}
+ metadata = _prepare_metadata(metadata) if clean_metadata else metadata
+ pipeline.hset(
+ key,
+ mapping={
+ self._schema.content_key: text,
+ self._schema.content_vector_key: _array_to_buffer(
+ embeddings[i], self._schema.vector_dtype
+ ),
+ **metadata,
+ },
+ )
+ ids.append(key)
+
+ # Write batch
+ if i % batch_size == 0:
+ pipeline.execute()
+
+ # Cleanup final batch
+ pipeline.execute()
+ return ids
+
+    def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
+        """Return a RedisVectorStoreRetriever initialized from this vectorstore."""
+ tags = kwargs.pop("tags", None) or []
+ tags.extend(self._get_retriever_tags())
+ return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
+
+ @deprecated("0.0.272", alternative="similarity_search(distance_threshold=0.1)")
+ def similarity_search_limit_score(
+ self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
+ ) -> List[Document]:
+ """
+ Returns the most similar indexed documents to the query text within the
+ score_threshold range.
+
+ Deprecated: Use similarity_search with distance_threshold instead.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+            score_threshold (float): The maximum vector *distance* allowed
+                for a document to be considered a match. Defaults to 0.2.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text
+ including the match score for each document.
+
+ Note:
+ If there are no documents that satisfy the score_threshold value,
+ an empty list is returned.
+
+ """
+ return self.similarity_search(
+ query, k=k, distance_threshold=score_threshold, **kwargs
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[RedisFilterExpression] = None,
+ return_metadata: bool = True,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Run similarity search with **vector distance**.
+
+ The "scores" returned from this function are the raw vector
+ distances from the query vector. For similarity scores, use
+ ``similarity_search_with_relevance_scores``.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+ filter (RedisFilterExpression, optional): Optional metadata filter.
+ Defaults to None.
+ return_metadata (bool, optional): Whether to return metadata.
+ Defaults to True.
+
+ Returns:
+ List[Tuple[Document, float]]: A list of documents that are
+ most similar to the query with the distance for each document.
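+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``rds`` is assumed to be an existing
+                # Redis vectorstore instance.
+                docs_with_distances = rds.similarity_search_with_score(
+                    "foo", k=2, return_metadata=True
+                )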
+ """
+ try:
+ import redis
+
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+
+ if "score_threshold" in kwargs:
+ logger.warning(
+ "score_threshold is deprecated. Use distance_threshold instead."
+ + "score_threshold should only be used in "
+ + "similarity_search_with_relevance_scores."
+ + "score_threshold will be removed in a future release.",
+ )
+
+ query_embedding = self._embeddings.embed_query(query)
+
+ redis_query, params_dict = self._prepare_query(
+ query_embedding,
+ k=k,
+ filter=filter,
+ with_metadata=return_metadata,
+ with_distance=True,
+ **kwargs,
+ )
+
+ # Perform vector search
+ # ignore type because redis-py is wrong about bytes
+ try:
+ results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore # noqa: E501
+ except redis.exceptions.ResponseError as e:
+ # split error message and see if it starts with "Syntax"
+ if str(e).split(" ")[0] == "Syntax":
+ raise ValueError(
+ "Query failed with syntax error. "
+ + "This is likely due to malformation of "
+ + "filter, vector, or query argument"
+ ) from e
+ raise e
+
+ # Prepare document results
+ docs_with_scores: List[Tuple[Document, float]] = []
+ for result in results.docs:
+ metadata = {}
+ if return_metadata:
+ metadata = {"id": result.id}
+ metadata.update(self._collect_metadata(result))
+
+ doc = Document(page_content=result.content, metadata=metadata)
+ distance = self._calculate_fp_distance(result.distance)
+ docs_with_scores.append((doc, distance))
+
+ return docs_with_scores
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[RedisFilterExpression] = None,
+ return_metadata: bool = True,
+ distance_threshold: Optional[float] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+ filter (RedisFilterExpression, optional): Optional metadata filter.
+ Defaults to None.
+ return_metadata (bool, optional): Whether to return metadata.
+ Defaults to True.
+ distance_threshold (Optional[float], optional): Maximum vector distance
+ between selected documents and the query vector. Defaults to None.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query
+ text.
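+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``rds`` is assumed to be an existing
+                # Redis vectorstore instance.
+                docs = rds.similarity_search("foo", k=4, distance_threshold=0.2)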
+ """
+ query_embedding = self._embeddings.embed_query(query)
+ return self.similarity_search_by_vector(
+ query_embedding,
+ k=k,
+ filter=filter,
+ return_metadata=return_metadata,
+ distance_threshold=distance_threshold,
+ **kwargs,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[RedisFilterExpression] = None,
+ return_metadata: bool = True,
+ distance_threshold: Optional[float] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search between a query vector and the indexed vectors.
+
+ Args:
+ embedding (List[float]): The query vector for which to find similar
+ documents.
+ k (int): The number of documents to return. Default is 4.
+ filter (RedisFilterExpression, optional): Optional metadata filter.
+ Defaults to None.
+ return_metadata (bool, optional): Whether to return metadata.
+ Defaults to True.
+ distance_threshold (Optional[float], optional): Maximum vector distance
+ between selected documents and the query vector. Defaults to None.
+
+ Returns:
+            List[Document]: A list of documents that are most similar to the query
+                vector.
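+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``rds`` and ``embeddings`` are assumed to
+                # already exist; the query is embedded first, then searched by vector.
+                query_vector = embeddings.embed_query("foo")
+                docs = rds.similarity_search_by_vector(query_vector, k=4)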
+ """
+ try:
+ import redis
+
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+
+ if "score_threshold" in kwargs:
+ logger.warning(
+ "score_threshold is deprecated. Use distance_threshold instead."
+ + "score_threshold should only be used in "
+ + "similarity_search_with_relevance_scores."
+ + "score_threshold will be removed in a future release.",
+ )
+
+ redis_query, params_dict = self._prepare_query(
+ embedding,
+ k=k,
+ filter=filter,
+ distance_threshold=distance_threshold,
+ with_metadata=return_metadata,
+ with_distance=False,
+ )
+
+ # Perform vector search
+ # ignore type because redis-py is wrong about bytes
+ try:
+ results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore # noqa: E501
+ except redis.exceptions.ResponseError as e:
+ # split error message and see if it starts with "Syntax"
+ if str(e).split(" ")[0] == "Syntax":
+ raise ValueError(
+ "Query failed with syntax error. "
+ + "This is likely due to malformation of "
+ + "filter, vector, or query argument"
+ ) from e
+ raise e
+
+ # Prepare document results
+ docs = []
+ for result in results.docs:
+ metadata = {}
+ if return_metadata:
+ metadata = {"id": result.id}
+ metadata.update(self._collect_metadata(result))
+
+ content_key = self._schema.content_key
+ docs.append(
+ Document(page_content=getattr(result, content_key), metadata=metadata)
+ )
+ return docs
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[RedisFilterExpression] = None,
+ return_metadata: bool = True,
+ distance_threshold: Optional[float] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query (str): Text to look up documents similar to.
+ k (int): Number of Documents to return. Defaults to 4.
+ fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult (float): Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ filter (RedisFilterExpression, optional): Optional metadata filter.
+ Defaults to None.
+ return_metadata (bool, optional): Whether to return metadata.
+ Defaults to True.
+ distance_threshold (Optional[float], optional): Maximum vector distance
+ between selected documents and the query vector. Defaults to None.
+
+ Returns:
+ List[Document]: A list of Documents selected by maximal marginal relevance.
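+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``rds`` is assumed to be an existing
+                # Redis vectorstore instance.
+                docs = rds.max_marginal_relevance_search(
+                    "foo", k=4, fetch_k=20, lambda_mult=0.5
+                )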
+ """
+ # Embed the query
+ query_embedding = self._embeddings.embed_query(query)
+
+ # Fetch the initial documents
+ prefetch_docs = self.similarity_search_by_vector(
+ query_embedding,
+ k=fetch_k,
+ filter=filter,
+ return_metadata=return_metadata,
+ distance_threshold=distance_threshold,
+ **kwargs,
+ )
+ prefetch_ids = [doc.metadata["id"] for doc in prefetch_docs]
+
+ # Get the embeddings for the fetched documents
+ prefetch_embeddings = [
+ _buffer_to_array(
+ cast(
+ bytes,
+ self.client.hget(prefetch_id, self._schema.content_vector_key),
+ ),
+ dtype=self._schema.vector_dtype,
+ )
+ for prefetch_id in prefetch_ids
+ ]
+
+ # Select documents using maximal marginal relevance
+ selected_indices = maximal_marginal_relevance(
+ np.array(query_embedding), prefetch_embeddings, lambda_mult=lambda_mult, k=k
+ )
+ selected_docs = [prefetch_docs[i] for i in selected_indices]
+
+ return selected_docs
+
+ def _collect_metadata(self, result: "Document") -> Dict[str, Any]:
+ """Collect metadata from Redis.
+
+        This method ensures that the collected metadata matches the index
+        schema that was passed to this class by the user or generated by
+        this class.
+
+ Args:
+ result (Document): redis.commands.search.Document object returned
+ from Redis.
+
+ Returns:
+ Dict[str, Any]: Collected metadata.
+ """
+ # new metadata dict as modified by this method
+ meta = {}
+ for key in self._schema.metadata_keys:
+ try:
+ meta[key] = getattr(result, key)
+ except AttributeError:
+ # warning about attribute missing
+ logger.warning(
+ f"Metadata key {key} not found in metadata. "
+ + "Setting to None. \n"
+ + "Metadata fields defined for this instance: "
+ + f"{self._schema.metadata_keys}"
+ )
+ meta[key] = None
+ return meta
+
+ def _prepare_query(
+ self,
+ query_embedding: List[float],
+ k: int = 4,
+ filter: Optional[RedisFilterExpression] = None,
+ distance_threshold: Optional[float] = None,
+ with_metadata: bool = True,
+ with_distance: bool = False,
+ ) -> Tuple["Query", Dict[str, Any]]:
+ # Creates Redis query
+ params_dict: Dict[str, Union[str, bytes, float]] = {
+ "vector": _array_to_buffer(query_embedding, self._schema.vector_dtype),
+ }
+
+ # prepare return fields including score
+ return_fields = [self._schema.content_key]
+ if with_distance:
+ return_fields.append("distance")
+ if with_metadata:
+ return_fields.extend(self._schema.metadata_keys)
+
+ if distance_threshold:
+ params_dict["distance_threshold"] = distance_threshold
+ return (
+ self._prepare_range_query(
+ k, filter=filter, return_fields=return_fields
+ ),
+ params_dict,
+ )
+ return (
+ self._prepare_vector_query(k, filter=filter, return_fields=return_fields),
+ params_dict,
+ )
+
+ def _prepare_range_query(
+ self,
+ k: int,
+ filter: Optional[RedisFilterExpression] = None,
+ return_fields: Optional[List[str]] = None,
+ ) -> "Query":
+ try:
+ from redis.commands.search.query import Query
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+ return_fields = return_fields or []
+ vector_key = self._schema.content_vector_key
+ base_query = f"@{vector_key}:[VECTOR_RANGE $distance_threshold $vector]"
+
+ if filter:
+ base_query = "(" + base_query + " " + str(filter) + ")"
+
+ query_string = base_query + "=>{$yield_distance_as: distance}"
+
+ return (
+ Query(query_string)
+ .return_fields(*return_fields)
+ .sort_by("distance")
+ .paging(0, k)
+ .dialect(2)
+ )
+
+ def _prepare_vector_query(
+ self,
+ k: int,
+ filter: Optional[RedisFilterExpression] = None,
+ return_fields: Optional[List[str]] = None,
+ ) -> "Query":
+ """Prepare query for vector search.
+
+ Args:
+ k: Number of results to return.
+ filter: Optional metadata filter.
+
+ Returns:
+ query: Query object.
+ """
+ try:
+ from redis.commands.search.query import Query
+ except ImportError as e:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ ) from e
+ return_fields = return_fields or []
+ query_prefix = "*"
+ if filter:
+ query_prefix = f"{str(filter)}"
+ vector_key = self._schema.content_vector_key
+ base_query = f"({query_prefix})=>[KNN {k} @{vector_key} $vector AS distance]"
+
+ query = (
+ Query(base_query)
+ .return_fields(*return_fields)
+ .sort_by("distance")
+ .paging(0, k)
+ .dialect(2)
+ )
+ return query
+
+ def _get_schema_with_defaults(
+ self,
+ index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] = None,
+ vector_schema: Optional[Dict[str, Union[str, int]]] = None,
+ ) -> "RedisModel":
+ # should only be called after init of Redis (so Import handled)
+ from langchain_community.vectorstores.redis.schema import (
+ RedisModel,
+ read_schema,
+ )
+
+ schema = RedisModel()
+ # read in schema (yaml file or dict) and
+ # pass to the Pydantic validators
+ if index_schema:
+ schema_values = read_schema(index_schema) # type: ignore
+ schema = RedisModel(**schema_values)
+
+ # ensure user did not exclude the content field
+ # no modifications if content field found
+ schema.add_content_field()
+
+ # if no content_vector field, add vector field to schema
+ # this makes adding a vector field to the schema optional when
+ # the user just wants additional metadata
+ try:
+ # see if user overrode the content vector
+ schema.content_vector
+ # if user overrode the content vector, check if they
+ # also passed vector schema. This won't be used since
+ # the index schema overrode the content vector
+ if vector_schema:
+ logger.warning(
+ "`vector_schema` is ignored since content_vector is "
+ + "overridden in `index_schema`."
+ )
+
+ # user did not override content vector
+ except ValueError:
+ # set default vector schema and update with user provided schema
+ # if the user provided any
+ vector_field = self.DEFAULT_VECTOR_SCHEMA.copy()
+ if vector_schema:
+ vector_field.update(vector_schema)
+
+ # add the vector field either way
+ schema.add_vector_field(vector_field)
+ return schema
+
+ def _create_index_if_not_exist(self, dim: int = 1536) -> None:
+ try:
+ from redis.commands.search.indexDefinition import ( # type: ignore
+ IndexDefinition,
+ IndexType,
+ )
+
+ except ImportError:
+ raise ImportError(
+ "Could not import redis python package. "
+ "Please install it with `pip install redis`."
+ )
+
+ # Set vector dimension
+ # can't obtain beforehand because we don't
+ # know which embedding model is being used.
+ self._schema.content_vector.dims = dim
+
+ # Check if index exists
+ if not check_index_exists(self.client, self.index_name):
+ # Create Redis Index
+ self.client.ft(self.index_name).create_index(
+ fields=self._schema.get_fields(),
+ definition=IndexDefinition(
+ prefix=[self.key_prefix], index_type=IndexType.HASH
+ ),
+ )
+
+ def _calculate_fp_distance(self, distance: str) -> float:
+ """Calculate the distance based on the vector datatype
+
+ Two datatypes supported:
+ - FLOAT32
+ - FLOAT64
+
+        If it's FLOAT32, round the distance to 4 decimal places;
+        otherwise, round to 7 decimal places.
+ """
+ if self._schema.content_vector.datatype == "FLOAT32":
+ return round(float(distance), 4)
+ return round(float(distance), 7)
+
+ def _check_deprecated_kwargs(self, kwargs: Mapping[str, Any]) -> None:
+ """Check for deprecated kwargs."""
+
+ deprecated_kwargs = {
+ "redis_host": "redis_url",
+ "redis_port": "redis_url",
+ "redis_password": "redis_url",
+ "content_key": "index_schema",
+ "vector_key": "vector_schema",
+ "distance_metric": "vector_schema",
+ }
+ for key, value in kwargs.items():
+ if key in deprecated_kwargs:
+ raise ValueError(
+ f"Keyword argument '{key}' is deprecated. "
+ f"Please use '{deprecated_kwargs[key]}' instead."
+ )
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ if self.relevance_score_fn:
+ return self.relevance_score_fn
+
+ metric_map = {
+ "COSINE": self._cosine_relevance_score_fn,
+ "IP": self._max_inner_product_relevance_score_fn,
+ "L2": self._euclidean_relevance_score_fn,
+ }
+ try:
+ return metric_map[self._schema.content_vector.distance_metric]
+ except KeyError:
+ return _default_relevance_score
+
+
+def _generate_field_schema(data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Generate a schema for the search index in Redis based on the input metadata.
+
+ Given a dictionary of metadata, this function categorizes each metadata
+ field into one of the three categories:
+ - text: The field contains textual data.
+ - numeric: The field contains numeric data (either integer or float).
+    - tag: The field contains a list of tags (strings).
+
+    Args:
+ data (Dict[str, Any]): A dictionary where keys are metadata field names
+ and values are the metadata values.
+
+ Returns:
+ Dict[str, Any]: A dictionary with three keys "text", "numeric", and "tag".
+ Each key maps to a list of fields that belong to that category.
+
+ Raises:
+ ValueError: If a metadata field cannot be categorized into any of
+ the three known types.
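+
+    Example:
+        .. code-block:: python
+
+            # Illustrative sketch of the generated mapping:
+            _generate_field_schema({"title": "foo", "year": 2020, "tags": ["a", "b"]})
+            # -> {"text": [{"name": "title"}],
+            #     "numeric": [{"name": "year"}],
+            #     "tag": [{"name": "tags"}]}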
+ """
+ result: Dict[str, Any] = {
+ "text": [],
+ "numeric": [],
+ "tag": [],
+ }
+
+ for key, value in data.items():
+ # Numeric fields
+ try:
+ int(value)
+ result["numeric"].append({"name": key})
+ continue
+ except (ValueError, TypeError):
+ pass
+
+ # None values are not indexed as of now
+ if value is None:
+ continue
+
+ # if it's a list of strings, we assume it's a tag
+ if isinstance(value, (list, tuple)):
+ if not value or isinstance(value[0], str):
+ result["tag"].append({"name": key})
+ else:
+ name = type(value[0]).__name__
+ raise ValueError(
+ f"List/tuple values should contain strings: '{key}': {name}"
+ )
+ continue
+
+ # Check if value is string before processing further
+ if isinstance(value, str):
+ result["text"].append({"name": key})
+ continue
+
+ # Unable to classify the field value
+ name = type(value).__name__
+ raise ValueError(
+ "Could not generate Redis index field type mapping "
+ + f"for metadata: '{key}': {name}"
+ )
+
+ return result
+
+
+def _prepare_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Prepare metadata for indexing in Redis by sanitizing its values.
+
+ - String, integer, and float values remain unchanged.
+ - None or empty values are replaced with empty strings.
+ - Lists/tuples of strings are joined into a single string with a comma separator.
+
+ Args:
+ metadata (Dict[str, Any]): A dictionary where keys are metadata
+ field names and values are the metadata values.
+
+ Returns:
+ Dict[str, Any]: A sanitized dictionary ready for indexing in Redis.
+
+ Raises:
+ ValueError: If any metadata value is not one of the known
+ types (string, int, float, or list of strings).
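+
+    Example:
+        .. code-block:: python
+
+            # Illustrative sketch of the sanitized output:
+            _prepare_metadata({"title": "foo", "tags": ["a", "b"], "views": None})
+            # -> {"title": "foo", "tags": "a,b", "views": ""}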
+ """
+
+ def raise_error(key: str, value: Any) -> None:
+ raise ValueError(
+ f"Metadata value for key '{key}' must be a string, int, "
+ + f"float, or list of strings. Got {type(value).__name__}"
+ )
+
+ clean_meta: Dict[str, Union[str, float, int]] = {}
+ for key, value in metadata.items():
+ if value is None:
+ clean_meta[key] = ""
+ continue
+
+ # No transformation needed
+ if isinstance(value, (str, int, float)):
+ clean_meta[key] = value
+
+ # if it's a list/tuple of strings, we join it
+ elif isinstance(value, (list, tuple)):
+ if not value or isinstance(value[0], str):
+ clean_meta[key] = REDIS_TAG_SEPARATOR.join(value)
+ else:
+ raise_error(key, value)
+ else:
+ raise_error(key, value)
+ return clean_meta
+
+
+class RedisVectorStoreRetriever(VectorStoreRetriever):
+ """Retriever for Redis VectorStore."""
+
+ vectorstore: Redis
+ """Redis VectorStore."""
+ search_type: str = "similarity"
+ """Type of search to perform. Can be either
+ 'similarity',
+ 'similarity_distance_threshold',
+ 'similarity_score_threshold'
+ """
+
+ search_kwargs: Dict[str, Any] = {
+ "k": 4,
+ "score_threshold": 0.9,
+ # set to None to avoid distance used in score_threshold search
+ "distance_threshold": None,
+ }
+ """Default search kwargs."""
+
+ allowed_search_types = [
+ "similarity",
+ "similarity_distance_threshold",
+ "similarity_score_threshold",
+ "mmr",
+ ]
+ """Allowed search types."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ if self.search_type == "similarity":
+ docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
+ elif self.search_type == "similarity_distance_threshold":
+ if self.search_kwargs["distance_threshold"] is None:
+ raise ValueError(
+ "distance_threshold must be provided for "
+ + "similarity_distance_threshold retriever"
+ )
+ docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
+
+ elif self.search_type == "similarity_score_threshold":
+ docs_and_similarities = (
+ self.vectorstore.similarity_search_with_relevance_scores(
+ query, **self.search_kwargs
+ )
+ )
+ docs = [doc for doc, _ in docs_and_similarities]
+ elif self.search_type == "mmr":
+ docs = self.vectorstore.max_marginal_relevance_search(
+ query, **self.search_kwargs
+ )
+ else:
+ raise ValueError(f"search_type of {self.search_type} not allowed.")
+ return docs
+
+ def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
+ """Add documents to vectorstore."""
+ return self.vectorstore.add_documents(documents, **kwargs)
+
+ async def aadd_documents(
+ self, documents: List[Document], **kwargs: Any
+ ) -> List[str]:
+ """Add documents to vectorstore."""
+ return await self.vectorstore.aadd_documents(documents, **kwargs)
diff --git a/libs/community/langchain_community/vectorstores/redis/constants.py b/libs/community/langchain_community/vectorstores/redis/constants.py
new file mode 100644
index 00000000000..ddbfe4c5847
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/redis/constants.py
@@ -0,0 +1,20 @@
+from typing import Any, Dict, List
+
+import numpy as np
+
+# required modules
+REDIS_REQUIRED_MODULES = [
+ {"name": "search", "ver": 20600},
+ {"name": "searchlight", "ver": 20600},
+]
+
+# distance metrics
+REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]
+
+# supported vector datatypes
+REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
+ "FLOAT32": np.float32,
+ "FLOAT64": np.float64,
+}
+
+REDIS_TAG_SEPARATOR = ","
diff --git a/libs/community/langchain_community/vectorstores/redis/filters.py b/libs/community/langchain_community/vectorstores/redis/filters.py
new file mode 100644
index 00000000000..f40c6473f14
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/redis/filters.py
@@ -0,0 +1,462 @@
+from enum import Enum
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+from langchain_community.utilities.redis import TokenEscaper
+
+# disable mypy error for dunder method overrides
+# mypy: disable-error-code="override"
+
+
+class RedisFilterOperator(Enum):
+ """RedisFilterOperator enumerator is used to create RedisFilterExpressions."""
+
+ EQ = 1
+ NE = 2
+ LT = 3
+ GT = 4
+ LE = 5
+ GE = 6
+ OR = 7
+ AND = 8
+ LIKE = 9
+ IN = 10
+
+
+class RedisFilter:
+ """Collection of RedisFilterFields."""
+
+ @staticmethod
+ def text(field: str) -> "RedisText":
+ return RedisText(field)
+
+ @staticmethod
+ def num(field: str) -> "RedisNum":
+ return RedisNum(field)
+
+ @staticmethod
+ def tag(field: str) -> "RedisTag":
+ return RedisTag(field)
+
+
+class RedisFilterField:
+ """Base class for RedisFilterFields."""
+
+ escaper: "TokenEscaper" = TokenEscaper()
+ OPERATORS: Dict[RedisFilterOperator, str] = {}
+
+ def __init__(self, field: str):
+ self._field = field
+ self._value: Any = None
+ self._operator: RedisFilterOperator = RedisFilterOperator.EQ
+
+ def equals(self, other: "RedisFilterField") -> bool:
+ if not isinstance(other, type(self)):
+ return False
+ return self._field == other._field and self._value == other._value
+
+ def _set_value(
+ self, val: Any, val_type: Tuple[Any], operator: RedisFilterOperator
+ ) -> None:
+ # check that the operator is supported by this class
+ if operator not in self.OPERATORS:
+ raise ValueError(
+ f"Operator {operator} not supported by {self.__class__.__name__}. "
+ + f"Supported operators are {self.OPERATORS.values()}."
+ )
+
+ if not isinstance(val, val_type):
+ raise TypeError(
+ f"Right side argument passed to operator {self.OPERATORS[operator]} "
+ f"with left side "
+ f"argument {self.__class__.__name__} must be of type {val_type}, "
+ f"received value {val}"
+ )
+ self._value = val
+ self._operator = operator
+
+
+def check_operator_misuse(func: Callable) -> Callable:
+ """Decorator to check for misuse of equality operators."""
+
+ @wraps(func)
+ def wrapper(instance: Any, *args: Any, **kwargs: Any) -> Any:
+ # Extracting 'other' from positional arguments or keyword arguments
+ other = kwargs.get("other") if "other" in kwargs else None
+ if not other:
+ for arg in args:
+ if isinstance(arg, type(instance)):
+ other = arg
+ break
+
+ if isinstance(other, type(instance)):
+ raise ValueError(
+ "Equality operators are overridden for FilterExpression creation. Use "
+ ".equals() for equality checks"
+ )
+ return func(instance, *args, **kwargs)
+
+ return wrapper
+
+
+class RedisTag(RedisFilterField):
+ """A RedisFilterField representing a tag in a Redis index."""
+
+ OPERATORS: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: "==",
+ RedisFilterOperator.NE: "!=",
+ RedisFilterOperator.IN: "==",
+ }
+ OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: "@%s:{%s}",
+ RedisFilterOperator.NE: "(-@%s:{%s})",
+ RedisFilterOperator.IN: "@%s:{%s}",
+ }
+ SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))
+
+ def __init__(self, field: str):
+ """Create a RedisTag FilterField.
+
+ Args:
+ field (str): The name of the RedisTag field in the index to be queried
+ against.
+ """
+ super().__init__(field)
+
+ def _set_tag_value(
+ self,
+ other: Union[List[str], Set[str], Tuple[str], str],
+ operator: RedisFilterOperator,
+ ) -> None:
+ if isinstance(other, (list, set, tuple)):
+ try:
+ # "if val" clause removes non-truthy values from list
+ other = [str(val) for val in other if val]
+ except ValueError:
+ raise ValueError("All tags within collection must be strings")
+ # above to catch the "" case
+ elif not other:
+ other = []
+ elif isinstance(other, str):
+ other = [other]
+
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore
+
+ @check_operator_misuse
+ def __eq__(
+ self, other: Union[List[str], Set[str], Tuple[str], str]
+ ) -> "RedisFilterExpression":
+ """Create a RedisTag equality filter expression.
+
+ Args:
+ other (Union[List[str], Set[str], Tuple[str], str]):
+ The tag(s) to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisTag
+ >>> filter = RedisTag("brand") == "nike"
+ """
+ self._set_tag_value(other, RedisFilterOperator.EQ)
+ return RedisFilterExpression(str(self))
+
+ @check_operator_misuse
+ def __ne__(
+ self, other: Union[List[str], Set[str], Tuple[str], str]
+ ) -> "RedisFilterExpression":
+ """Create a RedisTag inequality filter expression.
+
+ Args:
+ other (Union[List[str], Set[str], Tuple[str], str]):
+ The tag(s) to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisTag
+ >>> filter = RedisTag("brand") != "nike"
+ """
+ self._set_tag_value(other, RedisFilterOperator.NE)
+ return RedisFilterExpression(str(self))
+
+ @property
+ def _formatted_tag_value(self) -> str:
+ return "|".join([self.escaper.escape(tag) for tag in self._value])
+
+ def __str__(self) -> str:
+ """Return the query syntax for a RedisTag filter expression."""
+ if not self._value:
+ return "*"
+
+ return self.OPERATOR_MAP[self._operator] % (
+ self._field,
+ self._formatted_tag_value,
+ )
+
+
+class RedisNum(RedisFilterField):
+ """A RedisFilterField representing a numeric field in a Redis index."""
+
+ OPERATORS: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: "==",
+ RedisFilterOperator.NE: "!=",
+ RedisFilterOperator.LT: "<",
+ RedisFilterOperator.GT: ">",
+ RedisFilterOperator.LE: "<=",
+ RedisFilterOperator.GE: ">=",
+ }
+ OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: "@%s:[%s %s]",
+ RedisFilterOperator.NE: "(-@%s:[%s %s])",
+ RedisFilterOperator.GT: "@%s:[(%s +inf]",
+ RedisFilterOperator.LT: "@%s:[-inf (%s]",
+ RedisFilterOperator.GE: "@%s:[%s +inf]",
+ RedisFilterOperator.LE: "@%s:[-inf %s]",
+ }
+ SUPPORTED_VAL_TYPES = (int, float, type(None))
+
+ def __str__(self) -> str:
+ """Return the query syntax for a RedisNum filter expression."""
+ if not self._value:
+ return "*"
+
+ if (
+ self._operator == RedisFilterOperator.EQ
+ or self._operator == RedisFilterOperator.NE
+ ):
+ return self.OPERATOR_MAP[self._operator] % (
+ self._field,
+ self._value,
+ self._value,
+ )
+ else:
+ return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
+
+ @check_operator_misuse
+ def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric equality filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("zipcode") == 90210
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ @check_operator_misuse
+ def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric inequality filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("zipcode") != 90210
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric greater than filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("age") > 18
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric less than filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("age") < 18
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric greater than or equal to filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("age") >= 18
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
+ """Create a Numeric less than or equal to filter expression.
+
+ Args:
+ other (Union[int, float]): The value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisNum
+ >>> filter = RedisNum("age") <= 18
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore
+ return RedisFilterExpression(str(self))
+
+
+class RedisText(RedisFilterField):
+ """A RedisFilterField representing a text field in a Redis index."""
+
+ OPERATORS: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: "==",
+ RedisFilterOperator.NE: "!=",
+ RedisFilterOperator.LIKE: "%",
+ }
+ OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
+ RedisFilterOperator.EQ: '@%s:("%s")',
+ RedisFilterOperator.NE: '(-@%s:"%s")',
+ RedisFilterOperator.LIKE: "@%s:(%s)",
+ }
+ SUPPORTED_VAL_TYPES = (str, type(None))
+
+ @check_operator_misuse
+ def __eq__(self, other: str) -> "RedisFilterExpression":
+ """Create a RedisText equality (exact match) filter expression.
+
+ Args:
+ other (str): The text value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisText
+ >>> filter = RedisText("job") == "engineer"
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ @check_operator_misuse
+ def __ne__(self, other: str) -> "RedisFilterExpression":
+ """Create a RedisText inequality filter expression.
+
+ Args:
+ other (str): The text value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisText
+ >>> filter = RedisText("job") != "engineer"
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __mod__(self, other: str) -> "RedisFilterExpression":
+ """Create a RedisText "LIKE" filter expression.
+
+ Args:
+ other (str): The text value to filter on.
+
+ Example:
+ >>> from langchain_community.vectorstores.redis import RedisText
+ >>> filter = RedisText("job") % "engine*" # suffix wild card match
+ >>> filter = RedisText("job") % "%%engine%%" # fuzzy match w/ LD
+ >>> filter = RedisText("job") % "engineer|doctor" # contains either term
+ >>> filter = RedisText("job") % "engineer doctor" # contains both terms
+ """
+ self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore
+ return RedisFilterExpression(str(self))
+
+ def __str__(self) -> str:
+ """Return the query syntax for a RedisText filter expression."""
+ if not self._value:
+ return "*"
+
+ return self.OPERATOR_MAP[self._operator] % (
+ self._field,
+ self._value,
+ )
+
+
+class RedisFilterExpression:
+ """A logical expression of RedisFilterFields.
+
+ RedisFilterExpressions can be combined using the & and | operators to create
+ complex logical expressions that evaluate to the Redis Query language.
+
+ This presents an interface by which users can create complex queries
+ without having to know the Redis Query language.
+
+ Filter expressions are not initialized directly. Instead they are built
+ by combining RedisFilterFields using the & and | operators.
+
+ Examples:
+
+ >>> from langchain_community.vectorstores.redis import RedisTag, RedisNum
+ >>> brand_is_nike = RedisTag("brand") == "nike"
+ >>> price_is_under_100 = RedisNum("price") < 100
+ >>> filter = brand_is_nike & price_is_under_100
+ >>> print(str(filter))
+    (@brand:{nike} @price:[-inf (100])
+
+ """
+
+ def __init__(
+ self,
+ _filter: Optional[str] = None,
+ operator: Optional[RedisFilterOperator] = None,
+ left: Optional["RedisFilterExpression"] = None,
+ right: Optional["RedisFilterExpression"] = None,
+ ):
+ self._filter = _filter
+ self._operator = operator
+ self._left = left
+ self._right = right
+
+ def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
+ return RedisFilterExpression(
+ operator=RedisFilterOperator.AND, left=self, right=other
+ )
+
+ def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
+ return RedisFilterExpression(
+ operator=RedisFilterOperator.OR, left=self, right=other
+ )
+
+ @staticmethod
+ def format_expression(
+ left: "RedisFilterExpression", right: "RedisFilterExpression", operator_str: str
+ ) -> str:
+ _left, _right = str(left), str(right)
+ if _left == _right == "*":
+ return _left
+ if _left == "*" != _right:
+ return _right
+ if _right == "*" != _left:
+ return _left
+ return f"({_left}{operator_str}{_right})"
+
+ def __str__(self) -> str:
+ # top level check that allows recursive calls to __str__
+ if not self._filter and not self._operator:
+ raise ValueError("Improperly initialized RedisFilterExpression")
+
+ # if there's an operator, combine expressions accordingly
+ if self._operator:
+ if not isinstance(self._left, RedisFilterExpression) or not isinstance(
+ self._right, RedisFilterExpression
+ ):
+ raise TypeError(
+ "Improper combination of filters."
+ "Both left and right should be type FilterExpression"
+ )
+
+ operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
+ return self.format_expression(self._left, self._right, operator_str)
+
+ # check that base case, the filter is set
+ if not self._filter:
+ raise ValueError("Improperly initialized RedisFilterExpression")
+ return self._filter
diff --git a/libs/community/langchain_community/vectorstores/redis/schema.py b/libs/community/langchain_community/vectorstores/redis/schema.py
new file mode 100644
index 00000000000..8269b71f699
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/redis/schema.py
@@ -0,0 +1,308 @@
+from __future__ import annotations
+
+import os
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import yaml
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+from typing_extensions import TYPE_CHECKING, Literal
+
+from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
+
+if TYPE_CHECKING:
+ from redis.commands.search.field import ( # type: ignore
+ NumericField,
+ TagField,
+ TextField,
+ VectorField,
+ )
+
+
+class RedisDistanceMetric(str, Enum):
+ """Distance metrics for Redis vector fields."""
+
+ l2 = "L2"
+ cosine = "COSINE"
+ ip = "IP"
+
+
+class RedisField(BaseModel):
+ """Base class for Redis fields."""
+
+ name: str = Field(...)
+
+
+class TextFieldSchema(RedisField):
+ """Schema for text fields in Redis."""
+
+ weight: float = 1
+ no_stem: bool = False
+ phonetic_matcher: Optional[str] = None
+ withsuffixtrie: bool = False
+ no_index: bool = False
+ sortable: Optional[bool] = False
+
+ def as_field(self) -> TextField:
+ from redis.commands.search.field import TextField # type: ignore
+
+ return TextField(
+ self.name,
+ weight=self.weight,
+ no_stem=self.no_stem,
+ phonetic_matcher=self.phonetic_matcher, # type: ignore
+ sortable=self.sortable,
+ no_index=self.no_index,
+ )
+
+
+class TagFieldSchema(RedisField):
+ """Schema for tag fields in Redis."""
+
+ separator: str = ","
+ case_sensitive: bool = False
+ no_index: bool = False
+ sortable: Optional[bool] = False
+
+ def as_field(self) -> TagField:
+ from redis.commands.search.field import TagField # type: ignore
+
+ return TagField(
+ self.name,
+ separator=self.separator,
+ case_sensitive=self.case_sensitive,
+ sortable=self.sortable,
+ no_index=self.no_index,
+ )
+
+
+class NumericFieldSchema(RedisField):
+ """Schema for numeric fields in Redis."""
+
+ no_index: bool = False
+ sortable: Optional[bool] = False
+
+ def as_field(self) -> NumericField:
+ from redis.commands.search.field import NumericField # type: ignore
+
+ return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
+
+
+class RedisVectorField(RedisField):
+ """Base class for Redis vector fields."""
+
+ dims: int = Field(...)
+ algorithm: object = Field(...)
+ datatype: str = Field(default="FLOAT32")
+ distance_metric: RedisDistanceMetric = Field(default="COSINE")
+ initial_cap: Optional[int] = None
+
+ @validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True)
+ def uppercase_strings(cls, v: str) -> str:
+ return v.upper()
+
+ @validator("datatype", pre=True)
+ def uppercase_and_check_dtype(cls, v: str) -> str:
+ if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
+ raise ValueError(
+ f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
+ )
+ return v.upper()
+
+ def _fields(self) -> Dict[str, Any]:
+ field_data = {
+ "TYPE": self.datatype,
+ "DIM": self.dims,
+ "DISTANCE_METRIC": self.distance_metric,
+ }
+ if self.initial_cap is not None: # Only include it if it's set
+ field_data["INITIAL_CAP"] = self.initial_cap
+ return field_data
+
+
+class FlatVectorField(RedisVectorField):
+ """Schema for flat vector fields in Redis."""
+
+ algorithm: Literal["FLAT"] = "FLAT"
+ block_size: Optional[int] = None
+
+ def as_field(self) -> VectorField:
+ from redis.commands.search.field import VectorField # type: ignore
+
+ field_data = super()._fields()
+ if self.block_size is not None:
+ field_data["BLOCK_SIZE"] = self.block_size
+ return VectorField(self.name, self.algorithm, field_data)
+
+
+class HNSWVectorField(RedisVectorField):
+ """Schema for HNSW vector fields in Redis."""
+
+ algorithm: Literal["HNSW"] = "HNSW"
+ m: int = Field(default=16)
+ ef_construction: int = Field(default=200)
+ ef_runtime: int = Field(default=10)
+ epsilon: float = Field(default=0.01)
+
+ def as_field(self) -> VectorField:
+ from redis.commands.search.field import VectorField # type: ignore
+
+ field_data = super()._fields()
+ field_data.update(
+ {
+ "M": self.m,
+ "EF_CONSTRUCTION": self.ef_construction,
+ "EF_RUNTIME": self.ef_runtime,
+ "EPSILON": self.epsilon,
+ }
+ )
+ return VectorField(self.name, self.algorithm, field_data)
+
+
+class RedisModel(BaseModel):
+ """Schema for Redis index."""
+
+ # always have a content field for text
+ text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
+ tag: Optional[List[TagFieldSchema]] = None
+ numeric: Optional[List[NumericFieldSchema]] = None
+ extra: Optional[List[RedisField]] = None
+
+ # filled by default_vector_schema
+ vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
+ content_key: str = "content"
+ content_vector_key: str = "content_vector"
+
+ def add_content_field(self) -> None:
+ if self.text is None:
+ self.text = []
+ for field in self.text:
+ if field.name == self.content_key:
+ return
+ self.text.append(TextFieldSchema(name=self.content_key))
+
+ def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
+ # catch case where user inputted no vector field spec
+ # in the index schema
+ if self.vector is None:
+ self.vector = []
+
+ # ignore types as pydantic is handling type validation and conversion
+ if vector_field["algorithm"] == "FLAT":
+ self.vector.append(FlatVectorField(**vector_field)) # type: ignore
+ elif vector_field["algorithm"] == "HNSW":
+ self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
+ else:
+ raise ValueError(
+ f"algorithm must be either FLAT or HNSW. Got "
+ f"{vector_field['algorithm']}"
+ )
+
+ def as_dict(self) -> Dict[str, List[Any]]:
+ schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
+ # iter over all class attributes
+ for attr, attr_value in self.__dict__.items():
+ # only non-empty lists
+ if isinstance(attr_value, list) and len(attr_value) > 0:
+ field_values: List[Dict[str, Any]] = []
+ # iterate over all fields in each category (tag, text, etc)
+ for val in attr_value:
+ value: Dict[str, Any] = {}
+ # iterate over values within each field to extract
+ # settings for that field (i.e. name, weight, etc)
+ for field, field_value in val.__dict__.items():
+ # make enums into strings
+ if isinstance(field_value, Enum):
+ value[field] = field_value.value
+ # don't write null values
+ elif field_value is not None:
+ value[field] = field_value
+ field_values.append(value)
+
+ schemas[attr] = field_values
+
+ schema: Dict[str, List[Any]] = {}
+ # only write non-empty lists from defaults
+ for k, v in schemas.items():
+ if len(v) > 0:
+ schema[k] = v
+ return schema
+
+ @property
+ def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
+ if not self.vector:
+ raise ValueError("No vector fields found")
+ for field in self.vector:
+ if field.name == self.content_vector_key:
+ return field
+ raise ValueError("No content_vector field found")
+
+ @property
+ def vector_dtype(self) -> np.dtype:
+ # should only ever be called after pydantic has validated the schema
+ return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]
+
+ @property
+ def is_empty(self) -> bool:
+ return all(
+ field is None for field in [self.tag, self.text, self.numeric, self.vector]
+ )
+
+ def get_fields(self) -> List["RedisField"]:
+ redis_fields: List["RedisField"] = []
+ if self.is_empty:
+ return redis_fields
+
+ for field_name in self.__fields__.keys():
+ if field_name not in ["content_key", "content_vector_key", "extra"]:
+ field_group = getattr(self, field_name)
+ if field_group is not None:
+ for field in field_group:
+ redis_fields.append(field.as_field())
+ return redis_fields
+
+ @property
+ def metadata_keys(self) -> List[str]:
+ keys: List[str] = []
+ if self.is_empty:
+ return keys
+
+ for field_name in self.__fields__.keys():
+ field_group = getattr(self, field_name)
+ if field_group is not None:
+ for field in field_group:
+ # check if it's a metadata field. exclude vector and content key
+ if not isinstance(field, str) and field.name not in [
+ self.content_key,
+ self.content_vector_key,
+ ]:
+ keys.append(field.name)
+ return keys
+
+
+def read_schema(
+ index_schema: Optional[Union[Dict[str, List[Any]], str, os.PathLike]],
+) -> Dict[str, Any]:
+ """Reads in the index schema from a dict or yaml file.
+
+    If a dict is passed in, it is returned unchanged. Otherwise the value is
+    treated as a path to a YAML file, which is read and returned as a dict.
+    """
+ if isinstance(index_schema, dict):
+ return index_schema
+ elif isinstance(index_schema, Path):
+ with open(index_schema, "rb") as f:
+ return yaml.safe_load(f)
+ elif isinstance(index_schema, str):
+ if Path(index_schema).resolve().is_file():
+ with open(index_schema, "rb") as f:
+ return yaml.safe_load(f)
+ else:
+ raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
+ else:
+ raise TypeError(
+ f"index_schema must be a dict, or path to a yaml file "
+ f"Got {type(index_schema)}"
+ )
diff --git a/libs/community/langchain_community/vectorstores/rocksetdb.py b/libs/community/langchain_community/vectorstores/rocksetdb.py
new file mode 100644
index 00000000000..992f53db081
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/rocksetdb.py
@@ -0,0 +1,334 @@
+from __future__ import annotations
+
+import logging
+from enum import Enum
+from typing import Any, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+class Rockset(VectorStore):
+ """`Rockset` vector store.
+
+ To use, you should have the `rockset` python package installed. Note that to use
+ this, the collection being used must already exist in your Rockset instance.
+ You must also ensure you use a Rockset ingest transformation to apply
+ `VECTOR_ENFORCE` on the column being used to store `embedding_key` in the
+ collection.
+ See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details
+
+ Everything below assumes `commons` Rockset workspace.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Rockset
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ import rockset
+
+ # Make sure you use the right host (region) for your Rockset instance
+            # and that your API key has read-write access to your collection.
+
+ rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
+ collection_name = "langchain_demo"
+ embeddings = OpenAIEmbeddings()
+ vectorstore = Rockset(rs, collection_name, embeddings,
+ "description", "description_embedding")
+
+ """
+
+ def __init__(
+ self,
+ client: Any,
+ embeddings: Embeddings,
+ collection_name: str,
+ text_key: str,
+ embedding_key: str,
+ workspace: str = "commons",
+ ):
+ """Initialize with Rockset client.
+ Args:
+ client: Rockset client object
+            collection_name: Rockset collection to insert docs / query.
+            embeddings: LangChain Embeddings object used to generate
+                embeddings for the given text.
+ text_key: column in Rockset collection to use to store the text
+ embedding_key: column in Rockset collection to use to store the embedding.
+ Note: We must apply `VECTOR_ENFORCE()` on this column via
+                Rockset ingest transformation.
+            workspace: Rockset workspace containing the collection.
+                Defaults to "commons".
+
+ """
+ try:
+ from rockset import RocksetClient
+ except ImportError:
+ raise ImportError(
+ "Could not import rockset client python package. "
+ "Please install it with `pip install rockset`."
+ )
+
+ if not isinstance(client, RocksetClient):
+ raise ValueError(
+ f"client should be an instance of rockset.RocksetClient, "
+ f"got {type(client)}"
+ )
+ # TODO: check that `collection_name` exists in rockset. Create if not.
+ self._client = client
+ self._collection_name = collection_name
+ self._embeddings = embeddings
+ self._text_key = text_key
+ self._embedding_key = embedding_key
+ self._workspace = workspace
+
+ try:
+ self._client.set_application("langchain")
+ except AttributeError:
+ # ignore
+ pass
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embeddings
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+ batch_size: Send documents in batches to rockset.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
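+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``vectorstore`` is assumed to be an
+                # existing Rockset vectorstore instance.
+                ids = vectorstore.add_texts(
+                    ["foo", "bar"],
+                    metadatas=[{"source": "a"}, {"source": "b"}],
+                )
+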
+ """
+ batch: list[dict] = []
+ stored_ids = []
+
+ for i, text in enumerate(texts):
+ if len(batch) == batch_size:
+ stored_ids += self._write_documents_to_rockset(batch)
+ batch = []
+ doc = {}
+ if metadatas and len(metadatas) > i:
+ doc = metadatas[i]
+ if ids and len(ids) > i:
+ doc["_id"] = ids[i]
+ doc[self._text_key] = text
+ doc[self._embedding_key] = self._embeddings.embed_query(text)
+ batch.append(doc)
+ if len(batch) > 0:
+ stored_ids += self._write_documents_to_rockset(batch)
+ batch = []
+ return stored_ids
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ client: Any = None,
+ collection_name: str = "",
+ text_key: str = "",
+ embedding_key: str = "",
+ ids: Optional[List[str]] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> Rockset:
+ """Create Rockset wrapper with existing texts.
+ This is intended as a quicker way to get started.
+ """
+
+ # Sanitize inputs
+ assert client is not None, "Rockset Client cannot be None"
+ assert collection_name, "Collection name cannot be empty"
+ assert text_key, "Text key name cannot be empty"
+ assert embedding_key, "Embedding key cannot be empty"
+
+ rockset = cls(client, embedding, collection_name, text_key, embedding_key)
+ rockset.add_texts(texts, metadatas, ids, batch_size)
+ return rockset
+
+ # Rockset supports these vector distance functions.
+ class DistanceFunction(Enum):
+ COSINE_SIM = "COSINE_SIM"
+ EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
+ DOT_PRODUCT = "DOT_PRODUCT"
+
+ # how to sort results for "similarity"
+ def order_by(self) -> str:
+ if self.value == "EUCLIDEAN_DIST":
+ return "ASC"
+ return "DESC"
+
+ def similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with Rockset
+
+ Args:
+ query (str): Text to look up documents similar to.
+ distance_func (DistanceFunction): how to compute distance between two
+ vectors in Rockset.
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): Metadata filters supplied as a
+ SQL `where` condition string. Defaults to None.
+ eg. "price<=70.0 AND brand='Nintendo'"
+
+            NOTE: Do not let end users fill this in directly; always be mindful
+                of SQL injection.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents with their relevance score
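+
+        Example:
+            .. code-block:: python
+
+                # Illustrative sketch: ``vectorstore`` is assumed to be an
+                # existing Rockset vectorstore instance.
+                results = vectorstore.similarity_search_with_relevance_scores(
+                    "foo", k=4, where_str="price <= 70.0"
+                )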
+ """
+ return self.similarity_search_by_vector_with_relevance_scores(
+ self._embeddings.embed_query(query),
+ k,
+ distance_func,
+ where_str,
+ **kwargs,
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Same as `similarity_search_with_relevance_scores` but
+ doesn't return the scores.
+ """
+ return self.similarity_search_by_vector(
+ self._embeddings.embed_query(query),
+ k,
+ distance_func,
+ where_str,
+ **kwargs,
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Accepts a query_embedding (vector), and returns documents with
+ similar embeddings."""
+
+ docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
+ embedding, k, distance_func, where_str, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_by_vector_with_relevance_scores(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Accepts a query_embedding (vector), and returns documents with
+ similar embeddings along with their relevance scores."""
+
+ q_str = self._build_query_sql(embedding, distance_func, k, where_str)
+ try:
+ query_response = self._client.Queries.query(sql={"query": q_str})
+ except Exception as e:
+ logger.error("Exception when querying Rockset: %s\n", e)
+ return []
+ final_result: List[Tuple[Document, float]] = []
+ for document in query_response.results:
+ metadata = {}
+ assert isinstance(
+ document, dict
+ ), "document should be of type `dict[str,Any]`. But found: `{}`".format(
+ type(document)
+ )
+ for k, v in document.items():
+ if k == self._text_key:
+ assert isinstance(v, str), (
+ "page content stored in column `{}` must be of type `str`. "
+ "But found: `{}`"
+ ).format(self._text_key, type(v))
+ page_content = v
+ elif k == "dist":
+ assert isinstance(v, float), (
+ "Computed distance between vectors must of type `float`. "
+ "But found {}"
+ ).format(type(v))
+ score = v
+ elif k not in ["_id", "_event_time", "_meta"]:
+ # These columns are populated by Rockset when documents are
+ # inserted. No need to return them in metadata dict.
+ metadata[k] = v
+ final_result.append(
+ (Document(page_content=page_content, metadata=metadata), score)
+ )
+ return final_result
+
+ # Helper functions
+
+ def _build_query_sql(
+ self,
+ query_embedding: List[float],
+ distance_func: DistanceFunction,
+ k: int = 4,
+ where_str: Optional[str] = None,
+ ) -> str:
+ """Builds Rockset SQL query to query similar vectors to query_vector"""
+
+ q_embedding_str = ",".join(map(str, query_embedding))
+ distance_str = f"""{distance_func.value}({self._embedding_key}, \
+[{q_embedding_str}]) as dist"""
+ where_str = f"WHERE {where_str}\n" if where_str else ""
+ return f"""\
+SELECT * EXCEPT({self._embedding_key}), {distance_str}
+FROM {self._workspace}.{self._collection_name}
+{where_str}\
+ORDER BY dist {distance_func.order_by()}
+LIMIT {str(k)}
+"""
+
+ def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
+ add_doc_res = self._client.Documents.add_documents(
+ collection=self._collection_name, data=batch, workspace=self._workspace
+ )
+ return [doc_status._id for doc_status in add_doc_res.data]
+
+ def delete_texts(self, ids: List[str]) -> None:
+ """Delete a list of docs from the Rockset collection"""
+ try:
+ from rockset.models import DeleteDocumentsRequestData
+ except ImportError:
+ raise ImportError(
+ "Could not import rockset client python package. "
+ "Please install it with `pip install rockset`."
+ )
+
+ self._client.Documents.delete_documents(
+ collection=self._collection_name,
+ data=[DeleteDocumentsRequestData(id=i) for i in ids],
+ workspace=self._workspace,
+ )
diff --git a/libs/community/langchain_community/vectorstores/scann.py b/libs/community/langchain_community/vectorstores/scann.py
new file mode 100644
index 00000000000..11be67dcaac
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/scann.py
@@ -0,0 +1,544 @@
+from __future__ import annotations
+
+import operator
+import pickle
+import uuid
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.docstore.base import AddableMixin, Docstore
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+
+def normalize(x: np.ndarray) -> np.ndarray:
+ """Normalize vectors to unit length."""
+ x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
+ return x
+
+
+def dependable_scann_import() -> Any:
+ """
+ Import `scann` if available, otherwise raise error.
+ """
+ try:
+ import scann
+ except ImportError:
+ raise ImportError(
+ "Could not import scann python package. "
+ "Please install it with `pip install scann` "
+ )
+ return scann
+
+
+class ScaNN(VectorStore):
+ """`ScaNN` vector store.
+
+ To use, you should have the ``scann`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import ScaNN
+
+ db = ScaNN.from_texts(
+ ['foo', 'bar', 'barz', 'qux'],
+ HuggingFaceEmbeddings())
+ db.similarity_search('foo?', k=1)
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ index: Any,
+ docstore: Docstore,
+ index_to_docstore_id: Dict[int, str],
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ normalize_L2: bool = False,
+ distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+ scann_config: Optional[str] = None,
+ ):
+ """Initialize with necessary components."""
+ self.embedding = embedding
+ self.index = index
+ self.docstore = docstore
+ self.index_to_docstore_id = index_to_docstore_id
+ self.distance_strategy = distance_strategy
+ self.override_relevance_score_fn = relevance_score_fn
+ self._normalize_L2 = normalize_L2
+ self._scann_config = scann_config
+
+ def __add(
+ self,
+ texts: Iterable[str],
+ embeddings: Iterable[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ if not isinstance(self.docstore, AddableMixin):
+ raise ValueError(
+ "If trying to add texts, the underlying docstore should support "
+ f"adding items, which {self.docstore} does not"
+ )
+ raise NotImplementedError("Updates are not available in ScaNN, yet.")
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ # Embed and create the documents.
+ embeddings = self.embedding.embed_documents(list(texts))
+ return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
+
+ def add_embeddings(
+ self,
+ text_embeddings: Iterable[Tuple[str, List[float]]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ text_embeddings: Iterable pairs of string and embedding to
+ add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if not isinstance(self.docstore, AddableMixin):
+ raise ValueError(
+ "If trying to add texts, the underlying docstore should support "
+ f"adding items, which {self.docstore} does not"
+ )
+ # Embed and create the documents.
+ texts, embeddings = zip(*text_embeddings)
+
+ return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+
+ raise NotImplementedError("Deletions are not available in ScaNN, yet.")
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+ **kwargs: kwargs to be passed to similarity search. Can include:
+ score_threshold: Optional, a floating point value between 0 and 1
+ used to filter the resulting set of retrieved docs.
+
+ Returns:
+ List of documents most similar to the query text and L2 distance
+ in float for each. Lower score represents more similarity.
+ """
+ vector = np.array([embedding], dtype=np.float32)
+ if self._normalize_L2:
+ vector = normalize(vector)
+ indices, scores = self.index.search_batched(
+ vector, k if filter is None else fetch_k
+ )
+ docs = []
+ for j, i in enumerate(indices[0]):
+ if i == -1:
+ # This happens when not enough docs are returned.
+ continue
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ if filter is not None:
+ filter = {
+ key: [value] if not isinstance(value, list) else value
+ for key, value in filter.items()
+ }
+ if all(doc.metadata.get(key) in value for key, value in filter.items()):
+ docs.append((doc, scores[0][j]))
+ else:
+ docs.append((doc, scores[0][j]))
+
+ score_threshold = kwargs.get("score_threshold")
+ if score_threshold is not None:
+ cmp = (
+ operator.ge
+ if self.distance_strategy
+ in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
+ else operator.le
+ )
+ docs = [
+ (doc, similarity)
+ for doc, similarity in docs
+ if cmp(similarity, score_threshold)
+ ]
+ return docs[:k]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of documents most similar to the query text with
+ L2 distance in float. Lower score represents more similarity.
+ """
+ embedding = self.embedding.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding,
+ k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k, filter=filter, fetch_k=fetch_k, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ normalize_L2: bool = False,
+ **kwargs: Any,
+ ) -> ScaNN:
+ scann = dependable_scann_import()
+ distance_strategy = kwargs.get(
+ "distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
+ )
+ scann_config = kwargs.get("scann_config", None)
+
+ vector = np.array(embeddings, dtype=np.float32)
+ if normalize_L2:
+ vector = normalize(vector)
+ if scann_config is not None:
+ index = scann.scann_ops_pybind.create_searcher(vector, scann_config)
+ else:
+ if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ index = (
+ scann.scann_ops_pybind.builder(vector, 1, "dot_product")
+ .score_brute_force()
+ .build()
+ )
+ else:
+ # Default to L2, currently other metric types not initialized.
+ index = (
+ scann.scann_ops_pybind.builder(vector, 1, "squared_l2")
+ .score_brute_force()
+ .build()
+ )
+ documents = []
+ if ids is None:
+ ids = [str(uuid.uuid4()) for _ in texts]
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ documents.append(Document(page_content=text, metadata=metadata))
+ index_to_id = dict(enumerate(ids))
+
+ if len(index_to_id) != len(documents):
+ raise Exception(
+ f"{len(index_to_id)} ids provided for {len(documents)} documents."
+ " Each document should have an id."
+ )
+
+ docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
+ return cls(
+ embedding,
+ index,
+ docstore,
+ index_to_id,
+ normalize_L2=normalize_L2,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> ScaNN:
+ """Construct ScaNN wrapper from raw documents.
+
+ This is a user friendly interface that:
+ 1. Embeds documents.
+ 2. Creates an in memory docstore
+ 3. Initializes the ScaNN database
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ScaNN
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ scann = ScaNN.from_texts(texts, embeddings)
+ """
+ embeddings = embedding.embed_documents(texts)
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> ScaNN:
+ """Construct ScaNN wrapper from raw documents.
+
+ This is a user friendly interface that:
+ 1. Embeds documents.
+ 2. Creates an in memory docstore
+ 3. Initializes the ScaNN database
+
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import ScaNN
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ scann = ScaNN.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ **kwargs,
+ )
+
+ def save_local(self, folder_path: str, index_name: str = "index") -> None:
+ """Save ScaNN index, docstore, and index_to_docstore_id to disk.
+
+ Args:
+ folder_path: folder path to save index, docstore,
+ and index_to_docstore_id to.
+ """
+ path = Path(folder_path)
+ scann_path = path / "{index_name}.scann".format(index_name=index_name)
+ scann_path.mkdir(exist_ok=True, parents=True)
+
+ # save index separately since it is not picklable
+ self.index.serialize(str(scann_path))
+
+ # save docstore and index_to_docstore_id
+ with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
+ pickle.dump((self.docstore, self.index_to_docstore_id), f)
+
+ @classmethod
+ def load_local(
+ cls,
+ folder_path: str,
+ embedding: Embeddings,
+ index_name: str = "index",
+ **kwargs: Any,
+ ) -> ScaNN:
+ """Load ScaNN index, docstore, and index_to_docstore_id from disk.
+
+ Args:
+ folder_path: folder path to load index, docstore,
+ and index_to_docstore_id from.
+ embedding: Embeddings to use when generating queries.
+ index_name: Name of the index files to load.
+ """
+ path = Path(folder_path)
+ scann_path = path / "{index_name}.scann".format(index_name=index_name)
+ scann_path.mkdir(exist_ok=True, parents=True)
+ # load index separately since it is not picklable
+ scann = dependable_scann_import()
+ index = scann.scann_ops_pybind.load_searcher(str(scann_path))
+
+ # load docstore and index_to_docstore_id
+ with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
+ docstore, index_to_docstore_id = pickle.load(f)
+ return cls(embedding, index, docstore, index_to_docstore_id, **kwargs)
+
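+ # A minimal persistence round trip, assuming an already populated store
+ # (the path and variable names below are illustrative):
+ #
+ #   db.save_local("/tmp/scann_index")
+ #   db2 = ScaNN.load_local("/tmp/scann_index", embedding)
+ #   db2.similarity_search("foo", k=1)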
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided in
+ # vectorstore constructor
+ if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ # Default behavior is to use euclidean distance relevancy
+ return self._euclidean_relevance_score_fn
+ else:
+ raise ValueError(
+ "Unknown distance strategy, must be cosine, max_inner_product,"
+ " or euclidean"
+ )
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores on a scale from 0 to 1."""
+ # Pop score threshold so that only relevancy scores, not raw scores, are
+ # filtered.
+ score_threshold = kwargs.pop("score_threshold", None)
+ relevance_score_fn = self._select_relevance_score_fn()
+ if relevance_score_fn is None:
+ raise ValueError(
+ "normalize_score_fn must be provided to"
+ " ScaNN constructor to normalize scores"
+ )
+ docs_and_scores = self.similarity_search_with_score(
+ query,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ docs_and_rel_scores = [
+ (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
+ ]
+ if score_threshold is not None:
+ docs_and_rel_scores = [
+ (doc, similarity)
+ for doc, similarity in docs_and_rel_scores
+ if similarity >= score_threshold
+ ]
+ return docs_and_rel_scores
diff --git a/libs/community/langchain_community/vectorstores/semadb.py b/libs/community/langchain_community/vectorstores/semadb.py
new file mode 100644
index 00000000000..130d957277d
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/semadb.py
@@ -0,0 +1,272 @@
+from typing import Any, Iterable, List, Optional, Tuple
+from uuid import uuid4
+
+import numpy as np
+import requests
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+
+class SemaDB(VectorStore):
+ """`SemaDB` vector store.
+
+ This vector store is a wrapper around the SemaDB database.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import SemaDB
+
+ db = SemaDB('mycollection', 768, embeddings, DistanceStrategy.COSINE)
+
+ """
+
+ HOST = "semadb.p.rapidapi.com"
+ BASE_URL = "https://" + HOST
+
+ def __init__(
+ self,
+ collection_name: str,
+ vector_size: int,
+ embedding: Embeddings,
+ distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+ api_key: str = "",
+ ):
+ """Initialise the SemaDB vector store."""
+ self.collection_name = collection_name
+ self.vector_size = vector_size
+ self.api_key = api_key or get_from_env("api_key", "SEMADB_API_KEY")
+ self._embedding = embedding
+ self.distance_strategy = distance_strategy
+
+ @property
+ def headers(self) -> dict:
+ """Return the common headers."""
+ return {
+ "content-type": "application/json",
+ "X-RapidAPI-Key": self.api_key,
+ "X-RapidAPI-Host": SemaDB.HOST,
+ }
+
+ def _get_internal_distance_strategy(self) -> str:
+ """Return the internal distance strategy."""
+ if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ return "euclidean"
+ elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ raise ValueError("Max inner product is not supported by SemaDB")
+ elif self.distance_strategy == DistanceStrategy.DOT_PRODUCT:
+ return "dot"
+ elif self.distance_strategy == DistanceStrategy.JACCARD:
+ raise ValueError("Max inner product is not supported by SemaDB")
+ elif self.distance_strategy == DistanceStrategy.COSINE:
+ return "cosine"
+ else:
+ raise ValueError(f"Unknown distance strategy {self.distance_strategy}")
+
+ def create_collection(self) -> bool:
+ """Creates the corresponding collection in SemaDB."""
+ payload = {
+ "id": self.collection_name,
+ "vectorSize": self.vector_size,
+ "distanceMetric": self._get_internal_distance_strategy(),
+ }
+ response = requests.post(
+ SemaDB.BASE_URL + "/collections",
+ json=payload,
+ headers=self.headers,
+ )
+ return response.status_code == 200
+
+ def delete_collection(self) -> bool:
+ """Deletes the corresponding collection in SemaDB."""
+ response = requests.delete(
+ SemaDB.BASE_URL + f"/collections/{self.collection_name}",
+ headers=self.headers,
+ )
+ return response.status_code == 200
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ batch_size: int = 1000,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts to the vector store."""
+ if not isinstance(texts, list):
+ texts = list(texts)
+ embeddings = self._embedding.embed_documents(texts)
+ # Check dimensions
+ if len(embeddings[0]) != self.vector_size:
+ raise ValueError(
+ f"Embedding size mismatch {len(embeddings[0])} != {self.vector_size}"
+ )
+ # Normalise if needed
+ if self.distance_strategy == DistanceStrategy.COSINE:
+ embed_matrix = np.array(embeddings)
+ embed_matrix = embed_matrix / np.linalg.norm(
+ embed_matrix, axis=1, keepdims=True
+ )
+ embeddings = embed_matrix.tolist()
+ # Create points
+ ids: List[str] = []
+ points = []
+ if metadatas is not None:
+ for text, embedding, metadata in zip(texts, embeddings, metadatas):
+ new_id = str(uuid4())
+ ids.append(new_id)
+ points.append(
+ {
+ "id": new_id,
+ "vector": embedding,
+ "metadata": {**metadata, **{"text": text}},
+ }
+ )
+ else:
+ for text, embedding in zip(texts, embeddings):
+ new_id = str(uuid4())
+ ids.append(new_id)
+ points.append(
+ {
+ "id": new_id,
+ "vector": embedding,
+ "metadata": {"text": text},
+ }
+ )
+ # Insert points in batches
+ for i in range(0, len(points), batch_size):
+ batch = points[i : i + batch_size]
+ response = requests.post(
+ SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
+ json={"points": batch},
+ headers=self.headers,
+ )
+ if response.status_code != 200:
+ print("HERE--", batch)
+ raise ValueError(f"Error adding points: {response.text}")
+ failed_ranges = response.json()["failedRanges"]
+ if len(failed_ranges) > 0:
+ raise ValueError(f"Error adding points: {failed_ranges}")
+ # Return ids
+ return ids
+
+ @property
+ def embeddings(self) -> Embeddings:
+ """Return the embeddings."""
+ return self._embedding
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+ payload = {
+ "ids": ids,
+ }
+ response = requests.delete(
+ SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
+ json=payload,
+ headers=self.headers,
+ )
+ return response.status_code == 200 and len(response.json()["failedPoints"]) == 0
+
+ def _search_points(self, embedding: List[float], k: int = 4) -> List[dict]:
+ """Search points."""
+ # Normalise if needed
+ if self.distance_strategy == DistanceStrategy.COSINE:
+ vec = np.array(embedding)
+ vec = vec / np.linalg.norm(vec)
+ embedding = vec.tolist()
+ # Perform search request
+ payload = {
+ "vector": embedding,
+ "limit": k,
+ }
+ response = requests.post(
+ SemaDB.BASE_URL + f"/collections/{self.collection_name}/points/search",
+ json=payload,
+ headers=self.headers,
+ )
+ if response.status_code != 200:
+ raise ValueError(f"Error searching: {response.text}")
+ return response.json()["points"]
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+ query_embedding = self._embedding.embed_query(query)
+ return self.similarity_search_by_vector(query_embedding, k=k)
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Run similarity search with distance."""
+ query_embedding = self._embedding.embed_query(query)
+ points = self._search_points(query_embedding, k=k)
+ return [
+ (
+ Document(page_content=p["metadata"]["text"], metadata=p["metadata"]),
+ p["distance"],
+ )
+ for p in points
+ ]
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ points = self._search_points(embedding, k=k)
+ return [
+ Document(page_content=p["metadata"]["text"], metadata=p["metadata"])
+ for p in points
+ ]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = "",
+ vector_size: int = 0,
+ api_key: str = "",
+ distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+ **kwargs: Any,
+ ) -> "SemaDB":
+ """Return VectorStore initialized from texts and embeddings."""
+ if not collection_name:
+ raise ValueError("Collection name must be provided")
+ if not vector_size:
+ raise ValueError("Vector size must be provided")
+ if not api_key:
+ raise ValueError("API key must be provided")
+ semadb = cls(
+ collection_name,
+ vector_size,
+ embedding,
+ distance_strategy=distance_strategy,
+ api_key=api_key,
+ )
+ if not semadb.create_collection():
+ raise ValueError("Error creating collection")
+ semadb.add_texts(texts, metadatas=metadatas)
+ return semadb
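+# A minimal end-to-end sketch. The API key, collection name and vector size
+# below are placeholders; vector_size must match the embedding model in use.
+#
+#   db = SemaDB.from_texts(
+#       texts,
+#       embedding,
+#       collection_name="mycollection",
+#       vector_size=768,
+#       api_key="my-rapidapi-key",  # or set SEMADB_API_KEY in the environment
+#       distance_strategy=DistanceStrategy.COSINE,
+#   )
+#   docs = db.similarity_search("query", k=2)
+#   db.delete_collection()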
diff --git a/libs/community/langchain_community/vectorstores/singlestoredb.py b/libs/community/langchain_community/vectorstores/singlestoredb.py
new file mode 100644
index 00000000000..c33858e9657
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/singlestoredb.py
@@ -0,0 +1,448 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+from sqlalchemy.pool import QueuePool
+
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.DOT_PRODUCT
+
+ORDERING_DIRECTIVE: dict = {
+ DistanceStrategy.EUCLIDEAN_DISTANCE: "",
+ DistanceStrategy.DOT_PRODUCT: "DESC",
+}
+
+
+class SingleStoreDB(VectorStore):
+ """`SingleStore DB` vector store.
+
+ The prerequisite for using this class is the installation of the ``singlestoredb``
+ Python package.
+
+ The SingleStoreDB vectorstore can be created by providing an embedding function and
+ the relevant parameters for the database connection, connection pool, and
+ optionally, the names of the table and the fields to use.
+ """
+
+ def _get_connection(self: SingleStoreDB) -> Any:
+ try:
+ import singlestoredb as s2
+ except ImportError:
+ raise ImportError(
+ "Could not import singlestoredb python package. "
+ "Please install it with `pip install singlestoredb`."
+ )
+ return s2.connect(**self.connection_kwargs)
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ *,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ table_name: str = "embeddings",
+ content_field: str = "content",
+ metadata_field: str = "metadata",
+ vector_field: str = "vector",
+ pool_size: int = 5,
+ max_overflow: int = 10,
+ timeout: float = 30,
+ **kwargs: Any,
+ ):
+ """Initialize with necessary components.
+
+ Args:
+ embedding (Embeddings): A text embedding model.
+
+ distance_strategy (DistanceStrategy, optional):
+ Determines the strategy employed for calculating
+ the distance between vectors in the embedding space.
+ Defaults to DOT_PRODUCT.
+ Available options are:
+ - DOT_PRODUCT: Computes the scalar product of two vectors.
+ This is the default behavior
+ - EUCLIDEAN_DISTANCE: Computes the Euclidean distance between
+ two vectors. This metric considers the geometric distance in
+ the vector space, and might be more suitable for embeddings
+ that rely on spatial relationships.
+
+ table_name (str, optional): Specifies the name of the table in use.
+ Defaults to "embeddings".
+ content_field (str, optional): Specifies the field to store the content.
+ Defaults to "content".
+ metadata_field (str, optional): Specifies the field to store metadata.
+ Defaults to "metadata".
+ vector_field (str, optional): Specifies the field to store the vector.
+ Defaults to "vector".
+
+ Following arguments pertain to the connection pool:
+
+ pool_size (int, optional): Determines the number of active connections in
+ the pool. Defaults to 5.
+ max_overflow (int, optional): Determines the maximum number of connections
+ allowed beyond the pool_size. Defaults to 10.
+ timeout (float, optional): Specifies the maximum wait time in seconds for
+ establishing a connection. Defaults to 30.
+
+ Following arguments pertain to the database connection:
+
+ host (str, optional): Specifies the hostname, IP address, or URL for the
+ database connection. The default scheme is "mysql".
+ user (str, optional): Database username.
+ password (str, optional): Database password.
+ port (int, optional): Database port. Defaults to 3306 for non-HTTP
+ connections, 80 for HTTP connections, and 443 for HTTPS connections.
+ database (str, optional): Database name.
+
+ Additional optional arguments provide further customization over the
+ database connection:
+
+ pure_python (bool, optional): Toggles the connector mode. If True,
+ operates in pure Python mode.
+ local_infile (bool, optional): Allows local file uploads.
+ charset (str, optional): Specifies the character set for string values.
+ ssl_key (str, optional): Specifies the path of the file containing the SSL
+ key.
+ ssl_cert (str, optional): Specifies the path of the file containing the SSL
+ certificate.
+ ssl_ca (str, optional): Specifies the path of the file containing the SSL
+ certificate authority.
+ ssl_cipher (str, optional): Sets the SSL cipher list.
+ ssl_disabled (bool, optional): Disables SSL usage.
+ ssl_verify_cert (bool, optional): Verifies the server's certificate.
+ Automatically enabled if ``ssl_ca`` is specified.
+ ssl_verify_identity (bool, optional): Verifies the server's identity.
+ conv (dict[int, Callable], optional): A dictionary of data conversion
+ functions.
+ credential_type (str, optional): Specifies the type of authentication to
+ use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
+ autocommit (bool, optional): Enables autocommits.
+ results_type (str, optional): Determines the structure of the query results:
+ tuples, namedtuples, dicts.
+ results_format (str, optional): Deprecated. This option has been renamed to
+ results_type.
+
+ Examples:
+ Basic Usage:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import SingleStoreDB
+
+ vectorstore = SingleStoreDB(
+ OpenAIEmbeddings(),
+ host="https://user:password@127.0.0.1:3306/database"
+ )
+
+ Advanced Usage:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import SingleStoreDB
+
+ vectorstore = SingleStoreDB(
+ OpenAIEmbeddings(),
+ distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
+ host="127.0.0.1",
+ port=3306,
+ user="user",
+ password="password",
+ database="db",
+ table_name="my_custom_table",
+ pool_size=10,
+ timeout=60,
+ )
+
+ Using environment variables:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import SingleStoreDB
+
+ os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
+ vectorstore = SingleStoreDB(OpenAIEmbeddings())
+ """
+
+ self.embedding = embedding
+ self.distance_strategy = distance_strategy
+ self.table_name = self._sanitize_input(table_name)
+ self.content_field = self._sanitize_input(content_field)
+ self.metadata_field = self._sanitize_input(metadata_field)
+ self.vector_field = self._sanitize_input(vector_field)
+
+ # Pass the rest of the kwargs to the connection.
+ self.connection_kwargs = kwargs
+
+ # Add program name and version to connection attributes.
+ if "conn_attrs" not in self.connection_kwargs:
+ self.connection_kwargs["conn_attrs"] = dict()
+
+ self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
+ self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
+
+ # Create connection pool.
+ self.connection_pool = QueuePool(
+ self._get_connection,
+ max_overflow=max_overflow,
+ pool_size=pool_size,
+ timeout=timeout,
+ )
+ self._create_table()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ def _sanitize_input(self, input_str: str) -> str:
+ # Remove characters that are not alphanumeric or underscores
+ return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ return self._max_inner_product_relevance_score_fn
+
+ def _create_table(self: SingleStoreDB) -> None:
+ """Create table if it doesn't exist."""
+ conn = self.connection_pool.connect()
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """CREATE TABLE IF NOT EXISTS {}
+ ({} TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci,
+ {} BLOB, {} JSON);""".format(
+ self.table_name,
+ self.content_field,
+ self.vector_field,
+ self.metadata_field,
+ ),
+ )
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ embeddings: Optional[List[List[float]]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add more texts to the vectorstore.
+
+ Args:
+ texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
+ metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+ Defaults to None.
+ embeddings (Optional[List[List[float]]], optional): Optional pre-generated
+ embeddings. Defaults to None.
+
+ Returns:
+ List[str]: empty list
+ """
+ conn = self.connection_pool.connect()
+ try:
+ cur = conn.cursor()
+ try:
+ # Write data to singlestore db
+ for i, text in enumerate(texts):
+ # Use provided values by default or fallback
+ metadata = metadatas[i] if metadatas else {}
+ embedding = (
+ embeddings[i]
+ if embeddings
+ else self.embedding.embed_documents([text])[0]
+ )
+ cur.execute(
+ "INSERT INTO {} VALUES (%s, JSON_ARRAY_PACK(%s), %s)".format(
+ self.table_name
+ ),
+ (
+ text,
+ "[{}]".format(",".join(map(str, embedding))),
+ json.dumps(metadata),
+ ),
+ )
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+ return []
+
+ def similarity_search(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Returns the most similar indexed documents to the query text.
+
+ Uses the distance strategy configured for the store (dot product by default).
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+ filter (dict): A dictionary of metadata fields and values to filter by.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text.
+
+ Examples:
+ .. code-block:: python
+ from langchain_community.vectorstores import SingleStoreDB
+ from langchain_community.embeddings import OpenAIEmbeddings
+ s2 = SingleStoreDB.from_documents(
+ docs,
+ OpenAIEmbeddings(),
+ host="username:password@localhost:3306/database"
+ )
+ s2.similarity_search("query text", 1,
+ {"metadata_field": "metadata_value"})
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query=query, k=k, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, filter: Optional[dict] = None
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query. Uses cosine similarity.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: A dictionary of metadata fields and values to filter by.
+ Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ # Creates embedding vector from user query
+ embedding = self.embedding.embed_query(query)
+ conn = self.connection_pool.connect()
+ result = []
+ where_clause: str = ""
+ where_clause_values: List[Any] = []
+ if filter:
+ where_clause = "WHERE "
+ arguments = []
+
+ def build_where_clause(
+ where_clause_values: List[Any],
+ sub_filter: dict,
+ prefix_args: Optional[List[str]] = None,
+ ) -> None:
+ prefix_args = prefix_args or []
+ for key in sub_filter.keys():
+ if isinstance(sub_filter[key], dict):
+ build_where_clause(
+ where_clause_values, sub_filter[key], prefix_args + [key]
+ )
+ else:
+ arguments.append(
+ "JSON_EXTRACT_JSON({}, {}) = %s".format(
+ self.metadata_field,
+ ", ".join(["%s"] * (len(prefix_args) + 1)),
+ )
+ )
+ where_clause_values += prefix_args + [key]
+ where_clause_values.append(json.dumps(sub_filter[key]))
+
+ build_where_clause(where_clause_values, filter)
+ where_clause += " AND ".join(arguments)
+
+ try:
+ cur = conn.cursor()
+ try:
+ cur.execute(
+ """SELECT {}, {}, {}({}, JSON_ARRAY_PACK(%s)) as __score
+ FROM {} {} ORDER BY __score {} LIMIT %s""".format(
+ self.content_field,
+ self.metadata_field,
+ self.distance_strategy.name
+ if isinstance(self.distance_strategy, DistanceStrategy)
+ else self.distance_strategy,
+ self.vector_field,
+ self.table_name,
+ where_clause,
+ ORDERING_DIRECTIVE[self.distance_strategy],
+ ),
+ ("[{}]".format(",".join(map(str, embedding))),)
+ + tuple(where_clause_values)
+ + (k,),
+ )
+
+ for row in cur.fetchall():
+ doc = Document(page_content=row[0], metadata=row[1])
+ result.append((doc, float(row[2])))
+ finally:
+ cur.close()
+ finally:
+ conn.close()
+ return result
+
+ @classmethod
+ def from_texts(
+ cls: Type[SingleStoreDB],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ table_name: str = "embeddings",
+ content_field: str = "content",
+ metadata_field: str = "metadata",
+ vector_field: str = "vector",
+ pool_size: int = 5,
+ max_overflow: int = 10,
+ timeout: float = 30,
+ **kwargs: Any,
+ ) -> SingleStoreDB:
+ """Create a SingleStoreDB vectorstore from raw documents.
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Creates a new table for the embeddings in SingleStoreDB.
+ 3. Adds the documents to the newly created table.
+ This is intended to be a quick way to get started.
+ Example:
+ .. code-block:: python
+ from langchain_community.vectorstores import SingleStoreDB
+ from langchain_community.embeddings import OpenAIEmbeddings
+ s2 = SingleStoreDB.from_texts(
+ texts,
+ OpenAIEmbeddings(),
+ host="username:password@localhost:3306/database"
+ )
+ """
+
+ instance = cls(
+ embedding,
+ distance_strategy=distance_strategy,
+ table_name=table_name,
+ content_field=content_field,
+ metadata_field=metadata_field,
+ vector_field=vector_field,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ timeout=timeout,
+ **kwargs,
+ )
+ instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs)
+ return instance
+
+
+# SingleStoreDBRetriever is not needed, but we keep it for backwards compatibility
+SingleStoreDBRetriever = VectorStoreRetriever
diff --git a/libs/community/langchain_community/vectorstores/sklearn.py b/libs/community/langchain_community/vectorstores/sklearn.py
new file mode 100644
index 00000000000..8f9d415e3d6
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/sklearn.py
@@ -0,0 +1,355 @@
+""" Wrapper around scikit-learn NearestNeighbors implementation.
+
+The vector store can be persisted in json, bson or parquet format.
+"""
+
+import json
+import math
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
+from uuid import uuid4
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import guard_import
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+DEFAULT_K = 4 # Number of Documents to return.
+DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
+
+
+class BaseSerializer(ABC):
+ """Base class for serializing data."""
+
+ def __init__(self, persist_path: str) -> None:
+ self.persist_path = persist_path
+
+ @classmethod
+ @abstractmethod
+ def extension(cls) -> str:
+ """The file extension suggested by this serializer (without dot)."""
+
+ @abstractmethod
+ def save(self, data: Any) -> None:
+ """Saves the data to the persist_path"""
+
+ @abstractmethod
+ def load(self) -> Any:
+ """Loads the data from the persist_path"""
+
+
+class JsonSerializer(BaseSerializer):
+ """Serializes data in json using the json package from python standard library."""
+
+ @classmethod
+ def extension(cls) -> str:
+ return "json"
+
+ def save(self, data: Any) -> None:
+ with open(self.persist_path, "w") as fp:
+ json.dump(data, fp)
+
+ def load(self) -> Any:
+ with open(self.persist_path, "r") as fp:
+ return json.load(fp)
+
+
+class BsonSerializer(BaseSerializer):
+ """Serializes data in binary json using the `bson` python package."""
+
+ def __init__(self, persist_path: str) -> None:
+ super().__init__(persist_path)
+ self.bson = guard_import("bson")
+
+ @classmethod
+ def extension(cls) -> str:
+ return "bson"
+
+ def save(self, data: Any) -> None:
+ with open(self.persist_path, "wb") as fp:
+ fp.write(self.bson.dumps(data))
+
+ def load(self) -> Any:
+ with open(self.persist_path, "rb") as fp:
+ return self.bson.loads(fp.read())
+
+
+class ParquetSerializer(BaseSerializer):
+ """Serializes data in `Apache Parquet` format using the `pyarrow` package."""
+
+ def __init__(self, persist_path: str) -> None:
+ super().__init__(persist_path)
+ self.pd = guard_import("pandas")
+ self.pa = guard_import("pyarrow")
+ self.pq = guard_import("pyarrow.parquet")
+
+ @classmethod
+ def extension(cls) -> str:
+ return "parquet"
+
+ def save(self, data: Any) -> None:
+ df = self.pd.DataFrame(data)
+ table = self.pa.Table.from_pandas(df)
+ if os.path.exists(self.persist_path):
+ backup_path = str(self.persist_path) + "-backup"
+ os.rename(self.persist_path, backup_path)
+ try:
+ self.pq.write_table(table, self.persist_path)
+ except Exception as exc:
+ os.rename(backup_path, self.persist_path)
+ raise exc
+ else:
+ os.remove(backup_path)
+ else:
+ self.pq.write_table(table, self.persist_path)
+
+ def load(self) -> Any:
+ table = self.pq.read_table(self.persist_path)
+ df = table.to_pandas()
+ return {col: series.tolist() for col, series in df.items()}
+
+
+SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
+ "json": JsonSerializer,
+ "bson": BsonSerializer,
+ "parquet": ParquetSerializer,
+}
+
+
+class SKLearnVectorStoreException(RuntimeError):
+ """Exception raised by SKLearnVectorStore."""
+
+ pass
+
+
+class SKLearnVectorStore(VectorStore):
+ """Simple in-memory vector store based on the `scikit-learn` library
+ `NearestNeighbors` implementation."""
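+ # A minimal usage sketch (the embedding model and path are assumptions):
+ #
+ #   from langchain_community.embeddings import HuggingFaceEmbeddings
+ #   store = SKLearnVectorStore.from_texts(
+ #       ["foo", "bar"],
+ #       HuggingFaceEmbeddings(),
+ #       persist_path="/tmp/sklearn_store.json",
+ #   )
+ #   store.similarity_search("foo", k=1)
+ #   store.persist()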
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ *,
+ persist_path: Optional[str] = None,
+ serializer: Literal["json", "bson", "parquet"] = "json",
+ metric: str = "cosine",
+ **kwargs: Any,
+ ) -> None:
+ np = guard_import("numpy")
+ sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")
+
+ # non-persistent properties
+ self._np = np
+ self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
+ self._neighbors_fitted = False
+ self._embedding_function = embedding
+ self._persist_path = persist_path
+ self._serializer: Optional[BaseSerializer] = None
+ if self._persist_path is not None:
+ serializer_cls = SERIALIZER_MAP[serializer]
+ self._serializer = serializer_cls(persist_path=self._persist_path)
+
+ # data properties
+ self._embeddings: List[List[float]] = []
+ self._texts: List[str] = []
+ self._metadatas: List[dict] = []
+ self._ids: List[str] = []
+
+ # cache properties
+ self._embeddings_np: Any = np.asarray([])
+
+ if self._persist_path is not None and os.path.isfile(self._persist_path):
+ self._load()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding_function
+
+ def persist(self) -> None:
+ if self._serializer is None:
+ raise SKLearnVectorStoreException(
+ "You must specify a persist_path on creation to persist the "
+ "collection."
+ )
+ data = {
+ "ids": self._ids,
+ "texts": self._texts,
+ "metadatas": self._metadatas,
+ "embeddings": self._embeddings,
+ }
+ self._serializer.save(data)
+
+ def _load(self) -> None:
+ if self._serializer is None:
+ raise SKLearnVectorStoreException(
+ "You must specify a persist_path on creation to load the " "collection."
+ )
+ data = self._serializer.load()
+ self._embeddings = data["embeddings"]
+ self._texts = data["texts"]
+ self._metadatas = data["metadatas"]
+ self._ids = data["ids"]
+ self._update_neighbors()
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ _texts = list(texts)
+ _ids = ids or [str(uuid4()) for _ in _texts]
+ self._texts.extend(_texts)
+ self._embeddings.extend(self._embedding_function.embed_documents(_texts))
+ self._metadatas.extend(metadatas or ([{}] * len(_texts)))
+ self._ids.extend(_ids)
+ self._update_neighbors()
+ return _ids
+
+ def _update_neighbors(self) -> None:
+ if len(self._embeddings) == 0:
+ raise SKLearnVectorStoreException(
+ "No data was added to SKLearnVectorStore."
+ )
+ self._embeddings_np = self._np.asarray(self._embeddings)
+ self._neighbors.fit(self._embeddings_np)
+ self._neighbors_fitted = True
+
+ def _similarity_index_search_with_score(
+ self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
+ ) -> List[Tuple[int, float]]:
+ """Search k embeddings similar to the query embedding. Returns a list of
+ (index, distance) tuples."""
+ if not self._neighbors_fitted:
+ raise SKLearnVectorStoreException(
+ "No data was added to SKLearnVectorStore."
+ )
+ neigh_dists, neigh_idxs = self._neighbors.kneighbors(
+ [query_embedding], n_neighbors=k
+ )
+ return list(zip(neigh_idxs[0], neigh_dists[0]))
+
+ def similarity_search_with_score(
+ self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ query_embedding = self._embedding_function.embed_query(query)
+ indices_dists = self._similarity_index_search_with_score(
+ query_embedding, k=k, **kwargs
+ )
+ return [
+ (
+ Document(
+ page_content=self._texts[idx],
+ metadata={"id": self._ids[idx], **self._metadatas[idx]},
+ ),
+ dist,
+ )
+ for idx, dist in indices_dists
+ ]
+
+ def similarity_search(
+ self, query: str, k: int = DEFAULT_K, **kwargs: Any
+ ) -> List[Document]:
+ docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
+ return [doc for doc, _ in docs_scores]
+
+ def _similarity_search_with_relevance_scores(
+ self, query: str, k: int = DEFAULT_K, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
+ docs, dists = zip(*docs_dists)
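+ # Map raw distances to (0, 1] relevance scores: a distance of 0 maps to a
+ # score of 1 and larger distances decay exponentially toward 0.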
+ scores = [1 / math.exp(dist) for dist in dists]
+ return list(zip(list(docs), scores))
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_K,
+ fetch_k: int = DEFAULT_FETCH_K,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ indices_dists = self._similarity_index_search_with_score(
+ embedding, k=fetch_k, **kwargs
+ )
+ indices, _ = zip(*indices_dists)
+ result_embeddings = self._embeddings_np[indices,]
+ mmr_selected = maximal_marginal_relevance(
+ self._np.array(embedding, dtype=self._np.float32),
+ result_embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ mmr_indices = [indices[i] for i in mmr_selected]
+ return [
+ Document(
+ page_content=self._texts[idx],
+ metadata={"id": self._ids[idx], **self._metadatas[idx]},
+ )
+ for idx in mmr_indices
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = DEFAULT_K,
+ fetch_k: int = DEFAULT_FETCH_K,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if self._embedding_function is None:
+ raise ValueError(
+ "For MMR search, you must specify an embedding function on creation."
+ )
+
+ embedding = self._embedding_function.embed_query(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult=lambda_mult
+ )
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ persist_path: Optional[str] = None,
+ **kwargs: Any,
+ ) -> "SKLearnVectorStore":
+ vs = SKLearnVectorStore(embedding, persist_path=persist_path, **kwargs)
+ vs.add_texts(texts, metadatas=metadatas, ids=ids)
+ return vs
diff --git a/libs/community/langchain_community/vectorstores/sqlitevss.py b/libs/community/langchain_community/vectorstores/sqlitevss.py
new file mode 100644
index 00000000000..60551d02066
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/sqlitevss.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+import json
+import logging
+import warnings
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ import sqlite3
+
+logger = logging.getLogger(__name__)
+
+
+class SQLiteVSS(VectorStore):
+ """Wrapper around SQLite with vss extension as a vector database.
+ To use, you should have the ``sqlite-vss`` python package installed.
+ Example:
+ .. code-block:: python
+ from langchain_community.vectorstores import SQLiteVSS
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ ...
+ """
+
+ def __init__(
+ self,
+ table: str,
+ connection: Optional[sqlite3.Connection],
+ embedding: Embeddings,
+ db_file: str = "vss.db",
+ ):
+ """Initialize with sqlite client with vss extension."""
+ try:
+ import sqlite_vss # noqa # pylint: disable=unused-import
+ except ImportError:
+ raise ImportError(
+ "Could not import sqlite-vss python package. "
+ "Please install it with `pip install sqlite-vss`."
+ )
+
+ if not connection:
+ connection = self.create_connection(db_file)
+
+ if not isinstance(embedding, Embeddings):
+ warnings.warn("embedding input must be an Embeddings object.")
+
+ self._connection = connection
+ self._table = table
+ self._embedding = embedding
+
+ self.create_table_if_not_exists()
+
+ def create_table_if_not_exists(self) -> None:
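+ # The store uses three cooperating pieces: a plain table holding the text,
+ # metadata and raw embedding; a vss0 virtual table that indexes the embeddings
+ # for similarity search; and a trigger that copies each newly inserted
+ # embedding into the virtual table.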
+ self._connection.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self._table}
+ (
+ rowid INTEGER PRIMARY KEY AUTOINCREMENT,
+ text TEXT,
+ metadata BLOB,
+ text_embedding BLOB
+ )
+ ;
+ """
+ )
+ self._connection.execute(
+ f"""
+ CREATE VIRTUAL TABLE IF NOT EXISTS vss_{self._table} USING vss0(
+ text_embedding({self.get_dimensionality()})
+ );
+ """
+ )
+ self._connection.execute(
+ f"""
+ CREATE TRIGGER IF NOT EXISTS embed_text
+ AFTER INSERT ON {self._table}
+ BEGIN
+ INSERT INTO vss_{self._table}(rowid, text_embedding)
+ VALUES (new.rowid, new.text_embedding)
+ ;
+ END;
+ """
+ )
+ self._connection.commit()
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add more texts to the vectorstore index.
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ max_id = self._connection.execute(
+ f"SELECT max(rowid) as rowid FROM {self._table}"
+ ).fetchone()["rowid"]
+ if max_id is None: # no text added yet
+ max_id = 0
+
+ embeds = self._embedding.embed_documents(list(texts))
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+ data_input = [
+ (text, json.dumps(metadata), json.dumps(embed))
+ for text, metadata, embed in zip(texts, metadatas, embeds)
+ ]
+ self._connection.executemany(
+ f"INSERT INTO {self._table}(text, metadata, text_embedding) "
+ f"VALUES (?,?,?)",
+ data_input,
+ )
+ self._connection.commit()
+ # pull the ids of the rows we just inserted
+ results = self._connection.execute(
+ f"SELECT rowid FROM {self._table} WHERE rowid > {max_id}"
+ )
+ return [row["rowid"] for row in results]
+
+ def similarity_search_with_score_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ sql_query = f"""
+ SELECT
+ text,
+ metadata,
+ distance
+ FROM {self._table} e
+ INNER JOIN vss_{self._table} v on v.rowid = e.rowid
+ WHERE vss_search(
+ v.text_embedding,
+ vss_search_params('{json.dumps(embedding)}', {k})
+ )
+ """
+ cursor = self._connection.cursor()
+ cursor.execute(sql_query)
+ results = cursor.fetchall()
+
+ documents = []
+ for row in results:
+ metadata = json.loads(row["metadata"]) or {}
+ doc = Document(page_content=row["text"], metadata=metadata)
+ documents.append((doc, row["distance"]))
+
+ return documents
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+ embedding = self._embedding.embed_query(query)
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return [doc for doc, _ in documents]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query."""
+ embedding = self._embedding.embed_query(query)
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return documents
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return [doc for doc, _ in documents]
+
+ @classmethod
+ def from_texts(
+ cls: Type[SQLiteVSS],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ table: str = "langchain",
+ db_file: str = "vss.db",
+ **kwargs: Any,
+ ) -> SQLiteVSS:
+ """Return VectorStore initialized from texts and embeddings."""
+ connection = cls.create_connection(db_file)
+ vss = cls(
+ table=table, connection=connection, db_file=db_file, embedding=embedding
+ )
+ vss.add_texts(texts=texts, metadatas=metadatas)
+ return vss
+
+ @staticmethod
+ def create_connection(db_file: str) -> sqlite3.Connection:
+ import sqlite3
+
+ import sqlite_vss
+
+ connection = sqlite3.connect(db_file)
+ connection.row_factory = sqlite3.Row
+ connection.enable_load_extension(True)
+ sqlite_vss.load(connection)
+ connection.enable_load_extension(False)
+ return connection
+
+ def get_dimensionality(self) -> int:
+ """
+        Embed a dummy text to figure out how many dimensions
+        this embedding function returns. Needed for the virtual table DDL.
+ """
+ dummy_text = "This is a dummy text"
+ dummy_embedding = self._embedding.embed_query(dummy_text)
+ return len(dummy_embedding)
diff --git a/libs/community/langchain_community/vectorstores/starrocks.py b/libs/community/langchain_community/vectorstores/starrocks.py
new file mode 100644
index 00000000000..568daf0d4a4
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/starrocks.py
@@ -0,0 +1,482 @@
+from __future__ import annotations
+
+import json
+import logging
+from hashlib import sha1
+from threading import Thread
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseSettings
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger()
+DEBUG = False
+
+
+def has_mul_sub_str(s: str, *args: Any) -> bool:
+ """
+    Check if a string contains all of the given substrings.
+ Args:
+ s: The string to check
+ *args: The substrings to check for in the string
+
+ Returns:
+ bool: True if all substrings are present in the string, False otherwise
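+
+    Example (illustrative values only):
+        .. code-block:: python
+
+            has_mul_sub_str("StarRocks vector store", "Star", "vector")  # True
+            has_mul_sub_str("StarRocks vector store", "Star", "faiss")   # False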
+ """
+ for a in args:
+ if a not in s:
+ return False
+ return True
+
+
+def debug_output(s: Any) -> None:
+ """
+ Print a debug message if DEBUG is True.
+ Args:
+ s: The message to print
+ """
+ if DEBUG:
+ print(s)
+
+
+def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
+ """
+ Get a named result from a query.
+ Args:
+ connection: The connection to the database
+ query: The query to execute
+
+ Returns:
+ List[dict[str, Any]]: The result of the query
+ """
+ cursor = connection.cursor()
+ cursor.execute(query)
+ columns = cursor.description
+ result = []
+ for value in cursor.fetchall():
+ r = {}
+ for idx, datum in enumerate(value):
+ k = columns[idx][0]
+ r[k] = datum
+ result.append(r)
+ debug_output(result)
+ cursor.close()
+ return result
+
+
+class StarRocksSettings(BaseSettings):
+ """StarRocks client configuration.
+
+    Attributes:
+        host (str) : Hostname of the StarRocks backend to connect to.
+                     Defaults to 'localhost'.
+        port (int) : Port to connect to (MySQL protocol). Defaults to 9030.
+        username (str) : Username to log in with. Defaults to 'root'.
+        password (str) : Password to log in with. Defaults to ''.
+        database (str) : Database name to find the table. Defaults to 'default'.
+        table (str) : Table name to operate on.
+                      Defaults to 'langchain'.
+
+        column_map (Dict) : Column map to project column names onto langchain
+                            semantics. Must have the keys `id`, `document`,
+                            `embedding` and `metadata`. For example:
+ .. code-block:: python
+
+ {
+ 'id': 'text_id',
+ 'embedding': 'text_embedding',
+ 'document': 'text_plain',
+ 'metadata': 'metadata_dictionary_in_json',
+ }
+
+ Defaults to identity map.
+ """
+
+ host: str = "localhost"
+ port: int = 9030
+ username: str = "root"
+ password: str = ""
+
+ column_map: Dict[str, str] = {
+ "id": "id",
+ "document": "document",
+ "embedding": "embedding",
+ "metadata": "metadata",
+ }
+
+ database: str = "default"
+ table: str = "langchain"
+
+ def __getitem__(self, item: str) -> Any:
+ return getattr(self, item)
+
+ class Config:
+ env_file = ".env"
+ env_prefix = "starrocks_"
+ env_file_encoding = "utf-8"
+
+
+class StarRocks(VectorStore):
+ """`StarRocks` vector store.
+
+    You need the `pymysql` python package and a valid account
+    to connect to StarRocks.
+
+    Right now StarRocks only implements the `cosine_similarity` function to
+    compute the distance between two vectors, and there is no vector index yet,
+    so every search iterates over all vectors and computes the distance.
+
+ For more information, please visit
+ [StarRocks official site](https://www.starrocks.io/)
+ [StarRocks github](https://github.com/StarRocks/starrocks)
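+
+    A minimal usage sketch (assumes `pymysql` is installed and a StarRocks
+    instance is reachable with the default connection settings):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings.openai import OpenAIEmbeddings
+        from langchain_community.vectorstores.starrocks import (
+            StarRocks,
+            StarRocksSettings,
+        )
+
+        embeddings = OpenAIEmbeddings()
+        settings = StarRocksSettings(table="langchain")
+        vectorstore = StarRocks(embedding=embeddings, config=settings)
+        vectorstore.add_texts(["hello StarRocks"])
+        docs = vectorstore.similarity_search("hello", k=1)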
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ config: Optional[StarRocksSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+        """StarRocks wrapper for LangChain.
+
+        Args:
+            embedding (Embeddings): Text embedding model to use.
+            config (StarRocksSettings): Configuration for the StarRocks client.
+ """
+ try:
+ import pymysql # type: ignore[import]
+ except ImportError:
+ raise ImportError(
+ "Could not import pymysql python package. "
+ "Please install it with `pip install pymysql`."
+ )
+ try:
+ from tqdm import tqdm
+
+ self.pgbar = tqdm
+ except ImportError:
+            # Fall back to a no-op progress bar if tqdm is not installed
+ self.pgbar = lambda x, **kwargs: x
+ super().__init__()
+ if config is not None:
+ self.config = config
+ else:
+ self.config = StarRocksSettings()
+ assert self.config
+ assert self.config.host and self.config.port
+ assert self.config.column_map and self.config.database and self.config.table
+ for k in ["id", "embedding", "document", "metadata"]:
+ assert k in self.config.column_map
+
+ # initialize the schema
+ dim = len(embedding.embed_query("test"))
+
+ self.schema = f"""\
+CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
+ {self.config.column_map['id']} string,
+ {self.config.column_map['document']} string,
+    {self.config.column_map['embedding']} array<float>,
+ {self.config.column_map['metadata']} string
+) ENGINE = OLAP PRIMARY KEY(id) DISTRIBUTED BY HASH(id) \
+ PROPERTIES ("replication_num" = "1")\
+"""
+ self.dim = dim
+ self.BS = "\\"
+ self.must_escape = ("\\", "'")
+ self.embedding_function = embedding
+ self.dist_order = "DESC"
+ debug_output(self.config)
+
+ # Create a connection to StarRocks
+ self.connection = pymysql.connect(
+ host=self.config.host,
+ port=self.config.port,
+ user=self.config.username,
+ password=self.config.password,
+ database=self.config.database,
+ **kwargs,
+ )
+
+ debug_output(self.schema)
+ get_named_result(self.connection, self.schema)
+
+ def escape_str(self, value: str) -> str:
+ return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
+ ks = ",".join(column_names)
+ embed_tuple_index = tuple(column_names).index(
+ self.config.column_map["embedding"]
+ )
+ _data = []
+ for n in transac:
+ n = ",".join(
+ [
+ f"'{self.escape_str(str(_n))}'"
+ if idx != embed_tuple_index
+ else f"array{str(_n)}"
+ for (idx, _n) in enumerate(n)
+ ]
+ )
+ _data.append(f"({n})")
+ i_str = f"""
+ INSERT INTO
+ {self.config.database}.{self.config.table}({ks})
+ VALUES
+ {','.join(_data)}
+ """
+ return i_str
+
+ def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
+ _insert_query = self._build_insert_sql(transac, column_names)
+ debug_output(_insert_query)
+ get_named_result(self.connection, _insert_query)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ batch_size: int = 32,
+ ids: Optional[Iterable[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Insert more texts through the embeddings and add to the VectorStore.
+
+ Args:
+ texts: Iterable of strings to add to the VectorStore.
+ ids: Optional list of ids to associate with the texts.
+ batch_size: Batch size of insertion
+            metadatas: Optional list of metadatas associated with the texts
+
+ Returns:
+ List of ids from adding the texts into the VectorStore.
+
+ """
+ # Embed and create the documents
+ ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
+ colmap_ = self.config.column_map
+ transac = []
+ column_names = {
+ colmap_["id"]: ids,
+ colmap_["document"]: texts,
+ colmap_["embedding"]: self.embedding_function.embed_documents(list(texts)),
+ }
+ metadatas = metadatas or [{} for _ in texts]
+ column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
+ assert len(set(colmap_) - set(column_names)) >= 0
+ keys, values = zip(*column_names.items())
+ try:
+ t = None
+ for v in self.pgbar(
+ zip(*values), desc="Inserting data...", total=len(metadatas)
+ ):
+ assert (
+ len(v[keys.index(self.config.column_map["embedding"])]) == self.dim
+ )
+ transac.append(v)
+ if len(transac) == batch_size:
+ if t:
+ t.join()
+ t = Thread(target=self._insert, args=[transac, keys])
+ t.start()
+ transac = []
+ if len(transac) > 0:
+ if t:
+ t.join()
+ self._insert(transac, keys)
+ return [i for i in ids]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ config: Optional[StarRocksSettings] = None,
+ text_ids: Optional[Iterable[str]] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> StarRocks:
+ """Create StarRocks wrapper with existing texts
+
+ Args:
+ embedding_function (Embeddings): Function to extract text embedding
+ texts (Iterable[str]): List or tuple of strings to be added
+ config (StarRocksSettings, Optional): StarRocks configuration
+ text_ids (Optional[Iterable], optional): IDs for the texts.
+ Defaults to None.
+            batch_size (int, optional): Batch size when transmitting data to StarRocks.
+                Defaults to 32.
+            metadatas (List[dict], optional): Metadatas associated with the texts.
+                Defaults to None.
+ Returns:
+ StarRocks Index
+ """
+ ctx = cls(embedding, config, **kwargs)
+ ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
+ return ctx
+
+ def __repr__(self) -> str:
+        """Text representation for the StarRocks vector store, printing the backend,
+            username and table schema. Easy to use with `str(StarRocks())`.
+
+ Returns:
+ repr: string to show connection info and data schema
+ """
+ _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
+ _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
+ _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
+ width = 25
+ fields = 3
+ _repr += "-" * (width * fields + 1) + "\n"
+ columns = ["name", "type", "key"]
+ _repr += f"|\033[94m{columns[0]:24s}\033[0m|\033[96m{columns[1]:24s}"
+ _repr += f"\033[0m|\033[96m{columns[2]:24s}\033[0m|\n"
+ _repr += "-" * (width * fields + 1) + "\n"
+ q_str = f"DESC {self.config.database}.{self.config.table}"
+ debug_output(q_str)
+ rs = get_named_result(self.connection, q_str)
+ for r in rs:
+ _repr += f"|\033[94m{r['Field']:24s}\033[0m|\033[96m{r['Type']:24s}"
+ _repr += f"\033[0m|\033[96m{r['Key']:24s}\033[0m|\n"
+ _repr += "-" * (width * fields + 1) + "\n"
+ return _repr
+
+ def _build_query_sql(
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
+ ) -> str:
+ q_emb_str = ",".join(map(str, q_emb))
+ if where_str:
+ where_str = f"WHERE {where_str}"
+ else:
+ where_str = ""
+
+ q_str = f"""
+ SELECT {self.config.column_map['document']},
+ {self.config.column_map['metadata']},
+ cosine_similarity_norm(array[{q_emb_str}],
+ {self.config.column_map['embedding']}) as dist
+ FROM {self.config.database}.{self.config.table}
+ {where_str}
+ ORDER BY dist {self.dist_order}
+ LIMIT {topk}
+ """
+
+ debug_output(q_str)
+ return q_str
+
+ def similarity_search(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Perform a similarity search with StarRocks
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+                        NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadatas, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name for it is `metadata`.
+
+ Returns:
+ List[Document]: List of Documents
+ """
+ return self.similarity_search_by_vector(
+ self.embedding_function.embed_query(query), k, where_str, **kwargs
+ )
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ where_str: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search with StarRocks by vectors
+
+ Args:
+            embedding (List[float]): query embedding vector
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+                        NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadatas, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name for it is `metadata`.
+
+ Returns:
+            List[Document]: List of Documents
+ """
+ q_str = self._build_query_sql(embedding, k, where_str)
+ try:
+ return [
+ Document(
+ page_content=r[self.config.column_map["document"]],
+ metadata=json.loads(r[self.config.column_map["metadata"]]),
+ )
+ for r in get_named_result(self.connection, q_str)
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def similarity_search_with_relevance_scores(
+ self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with StarRocks
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+ where_str (Optional[str], optional): where condition string.
+ Defaults to None.
+
+                        NOTE: Please do not let end users fill this in, and always be aware
+ of SQL injection. When dealing with metadatas, remember to
+ use `{self.metadata_column}.attribute` instead of `attribute`
+ alone. The default name for it is `metadata`.
+
+ Returns:
+            List[Tuple[Document, float]]: List of (Document, similarity) pairs
+ """
+ q_str = self._build_query_sql(
+ self.embedding_function.embed_query(query), k, where_str
+ )
+ try:
+ return [
+ (
+ Document(
+ page_content=r[self.config.column_map["document"]],
+ metadata=json.loads(r[self.config.column_map["metadata"]]),
+ ),
+ r["dist"],
+ )
+ for r in get_named_result(self.connection, q_str)
+ ]
+ except Exception as e:
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
+
+ def drop(self) -> None:
+ """
+        Helper function: drop the backing table
+ """
+ get_named_result(
+ self.connection,
+ f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}",
+ )
+
+ @property
+ def metadata_column(self) -> str:
+ return self.config.column_map["metadata"]
diff --git a/libs/community/langchain_community/vectorstores/supabase.py b/libs/community/langchain_community/vectorstores/supabase.py
new file mode 100644
index 00000000000..b3f84c65eb9
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/supabase.py
@@ -0,0 +1,466 @@
+from __future__ import annotations
+
+import uuid
+from itertools import repeat
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ import supabase
+
+
+class SupabaseVectorStore(VectorStore):
+ """`Supabase Postgres` vector store.
+
+ It assumes you have the `pgvector`
+ extension installed and a `match_documents` (or similar) function. For more details:
+ https://integrations.langchain.com/vectorstores?integration_name=SupabaseVectorStore
+
+ You can implement your own `match_documents` function in order to limit the search
+ space to a subset of documents based on your own authorization or business logic.
+
+ Note that the Supabase Python client does not yet support async operations.
+
+ If you'd like to use `max_marginal_relevance_search`, please review the instructions
+ below on modifying the `match_documents` function to return matched embeddings.
+
+
+ Examples:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_core.documents import Document
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from supabase.client import create_client
+
+ docs = [
+ Document(page_content="foo", metadata={"id": 1}),
+ ]
+ embeddings = OpenAIEmbeddings()
+ supabase_client = create_client("my_supabase_url", "my_supabase_key")
+ vector_store = SupabaseVectorStore.from_documents(
+ docs,
+ embeddings,
+ client=supabase_client,
+ table_name="documents",
+ query_name="match_documents",
+ chunk_size=500,
+ )
+
+ To load from an existing table:
+
+ .. code-block:: python
+
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from supabase.client import create_client
+
+
+ embeddings = OpenAIEmbeddings()
+ supabase_client = create_client("my_supabase_url", "my_supabase_key")
+ vector_store = SupabaseVectorStore(
+ client=supabase_client,
+ embedding=embeddings,
+ table_name="documents",
+ query_name="match_documents",
+ )
+
+ """
+
+ def __init__(
+ self,
+ client: supabase.client.Client,
+ embedding: Embeddings,
+ table_name: str,
+ chunk_size: int = 500,
+ query_name: Union[str, None] = None,
+ ) -> None:
+ """Initialize with supabase client."""
+ try:
+ import supabase # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import supabase python package. "
+ "Please install it with `pip install supabase`."
+ )
+
+ self._client = client
+ self._embedding: Embeddings = embedding
+ self.table_name = table_name or "documents"
+ self.query_name = query_name or "match_documents"
+ self.chunk_size = chunk_size or 500
+ # According to the SupabaseVectorStore JS implementation, the best chunk size
+        # is 500. For large datasets it can be too large, so it is configurable.
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ docs = self._texts_to_documents(texts, metadatas)
+
+ vectors = self._embedding.embed_documents(list(texts))
+ return self.add_vectors(vectors, docs, ids)
+
+ @classmethod
+ def from_texts(
+ cls: Type["SupabaseVectorStore"],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ client: Optional[supabase.client.Client] = None,
+ table_name: Optional[str] = "documents",
+ query_name: Union[str, None] = "match_documents",
+ chunk_size: int = 500,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> "SupabaseVectorStore":
+ """Return VectorStore initialized from texts and embeddings."""
+
+ if not client:
+ raise ValueError("Supabase client is required.")
+
+ if not table_name:
+ raise ValueError("Supabase document table_name is required.")
+
+ embeddings = embedding.embed_documents(texts)
+ ids = [str(uuid.uuid4()) for _ in texts]
+ docs = cls._texts_to_documents(texts, metadatas)
+ cls._add_vectors(client, table_name, embeddings, docs, ids, chunk_size)
+
+ return cls(
+ client=client,
+ embedding=embedding,
+ table_name=table_name,
+ query_name=query_name,
+ chunk_size=chunk_size,
+ )
+
+ def add_vectors(
+ self,
+ vectors: List[List[float]],
+ documents: List[Document],
+ ids: List[str],
+ ) -> List[str]:
+ return self._add_vectors(
+ self._client, self.table_name, vectors, documents, ids, self.chunk_size
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ vector = self._embedding.embed_query(query)
+ return self.similarity_search_by_vector(vector, k=k, filter=filter, **kwargs)
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ result = self.similarity_search_by_vector_with_relevance_scores(
+ embedding, k=k, filter=filter, **kwargs
+ )
+
+ documents = [doc for doc, _ in result]
+
+ return documents
+
+ def similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ vector = self._embedding.embed_query(query)
+ return self.similarity_search_by_vector_with_relevance_scores(
+ vector, k=k, filter=filter
+ )
+
+ def match_args(
+ self, query: List[float], filter: Optional[Dict[str, Any]]
+ ) -> Dict[str, Any]:
+ ret: Dict[str, Any] = dict(query_embedding=query)
+ if filter:
+ ret["filter"] = filter
+ return ret
+
+ def similarity_search_by_vector_with_relevance_scores(
+ self,
+ query: List[float],
+ k: int,
+ filter: Optional[Dict[str, Any]] = None,
+ postgrest_filter: Optional[str] = None,
+ ) -> List[Tuple[Document, float]]:
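+        """Run the match function with a raw embedding and return scored matches.
+
+        Args:
+            query: Embedding vector to match against the stored embeddings.
+            k: Number of results to return (sent as the `limit` parameter).
+            filter: Optional metadata filter forwarded to the match function.
+            postgrest_filter: Optional raw PostgREST filter expression that is
+                ANDed onto the request.
+
+        Returns:
+            List of (Document, similarity) tuples for rows with non-empty content.
+        """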
+ match_documents_params = self.match_args(query, filter)
+ query_builder = self._client.rpc(self.query_name, match_documents_params)
+
+ if postgrest_filter:
+ query_builder.params = query_builder.params.set(
+ "and", f"({postgrest_filter})"
+ )
+
+ query_builder.params = query_builder.params.set("limit", k)
+
+ res = query_builder.execute()
+
+ match_result = [
+ (
+ Document(
+ metadata=search.get("metadata", {}), # type: ignore
+ page_content=search.get("content", ""),
+ ),
+ search.get("similarity", 0.0),
+ )
+ for search in res.data
+ if search.get("content")
+ ]
+
+ return match_result
+
+ def similarity_search_by_vector_returning_embeddings(
+ self,
+ query: List[float],
+ k: int,
+ filter: Optional[Dict[str, Any]] = None,
+ postgrest_filter: Optional[str] = None,
+ ) -> List[Tuple[Document, float, np.ndarray[np.float32, Any]]]:
+ match_documents_params = self.match_args(query, filter)
+ query_builder = self._client.rpc(self.query_name, match_documents_params)
+
+ if postgrest_filter:
+ query_builder.params = query_builder.params.set(
+ "and", f"({postgrest_filter})"
+ )
+
+ query_builder.params = query_builder.params.set("limit", k)
+
+ res = query_builder.execute()
+
+ match_result = [
+ (
+ Document(
+ metadata=search.get("metadata", {}), # type: ignore
+ page_content=search.get("content", ""),
+ ),
+ search.get("similarity", 0.0),
+                # Supabase returns a vector type as its string representation (!).
+                # This is a hack to convert the string to a numpy array.
+ np.fromstring(
+ search.get("embedding", "").strip("[]"), np.float32, sep=","
+ ),
+ )
+ for search in res.data
+ if search.get("content")
+ ]
+
+ return match_result
+
+ @staticmethod
+ def _texts_to_documents(
+ texts: Iterable[str],
+ metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
+ ) -> List[Document]:
+ """Return list of Documents from list of texts and metadatas."""
+ if metadatas is None:
+ metadatas = repeat({})
+
+ docs = [
+ Document(page_content=text, metadata=metadata)
+ for text, metadata in zip(texts, metadatas)
+ ]
+
+ return docs
+
+ @staticmethod
+ def _add_vectors(
+ client: supabase.client.Client,
+ table_name: str,
+ vectors: List[List[float]],
+ documents: List[Document],
+ ids: List[str],
+ chunk_size: int,
+ ) -> List[str]:
+ """Add vectors to Supabase table."""
+
+ rows: List[Dict[str, Any]] = [
+ {
+ "id": ids[idx],
+ "content": documents[idx].page_content,
+ "embedding": embedding,
+ "metadata": documents[idx].metadata, # type: ignore
+ }
+ for idx, embedding in enumerate(vectors)
+ ]
+
+ id_list: List[str] = []
+ for i in range(0, len(rows), chunk_size):
+ chunk = rows[i : i + chunk_size]
+
+ result = client.from_(table_name).upsert(chunk).execute() # type: ignore
+
+ if len(result.data) == 0:
+ raise Exception("Error inserting: No rows added")
+
+ # VectorStore.add_vectors returns ids as strings
+ ids = [str(i.get("id")) for i in result.data if i.get("id")]
+
+ id_list.extend(ids)
+
+ return id_list
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ result = self.similarity_search_by_vector_returning_embeddings(
+ embedding, fetch_k
+ )
+
+ matched_documents = [doc_tuple[0] for doc_tuple in result]
+ matched_embeddings = [doc_tuple[2] for doc_tuple in result]
+
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ matched_embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+
+ filtered_documents = [matched_documents[i] for i in mmr_selected]
+
+ return filtered_documents
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+
+ `max_marginal_relevance_search` requires that `query_name` returns matched
+ embeddings alongside the match documents. The following function
+ demonstrates how to do this:
+
+ ```sql
+ CREATE FUNCTION match_documents_embeddings(query_embedding vector(1536),
+ match_count int)
+ RETURNS TABLE(
+ id uuid,
+ content text,
+ metadata jsonb,
+ embedding vector(1536),
+ similarity float)
+ LANGUAGE plpgsql
+ AS $$
+ # variable_conflict use_column
+ BEGIN
+ RETURN query
+ SELECT
+ id,
+ content,
+ metadata,
+ embedding,
+ 1 -(docstore.embedding <=> query_embedding) AS similarity
+ FROM
+ docstore
+ ORDER BY
+ docstore.embedding <=> query_embedding
+ LIMIT match_count;
+ END;
+ $$;
+ ```
+ """
+ embedding = self._embedding.embed_query(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding, k, fetch_k, lambda_mult=lambda_mult
+ )
+ return docs
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ rows: List[Dict[str, Any]] = [
+ {
+ "id": id,
+ }
+ for id in ids
+ ]
+
+ # TODO: Check if this can be done in bulk
+ for row in rows:
+ self._client.from_(self.table_name).delete().eq("id", row["id"]).execute()
diff --git a/libs/community/langchain_community/vectorstores/tair.py b/libs/community/langchain_community/vectorstores/tair.py
new file mode 100644
index 00000000000..4ff8fb8a35f
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/tair.py
@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+import json
+import logging
+import uuid
+from typing import Any, Iterable, List, Optional, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+def _uuid_key() -> str:
+ return uuid.uuid4().hex
+
+
+class Tair(VectorStore):
+ """`Tair` vector store."""
+
+ def __init__(
+ self,
+ embedding_function: Embeddings,
+ url: str,
+ index_name: str,
+ content_key: str = "content",
+ metadata_key: str = "metadata",
+ search_params: Optional[dict] = None,
+ **kwargs: Any,
+ ):
+ self.embedding_function = embedding_function
+ self.index_name = index_name
+ try:
+ from tair import Tair as TairClient
+ except ImportError:
+ raise ImportError(
+ "Could not import tair python package. "
+ "Please install it with `pip install tair`."
+ )
+ try:
+ # connect to tair from url
+ client = TairClient.from_url(url, **kwargs)
+ except ValueError as e:
+ raise ValueError(f"Tair failed to connect: {e}")
+
+ self.client = client
+ self.content_key = content_key
+ self.metadata_key = metadata_key
+ self.search_params = search_params
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_function
+
+ def create_index_if_not_exist(
+ self,
+ dim: int,
+ distance_type: str,
+ index_type: str,
+ data_type: str,
+ **kwargs: Any,
+ ) -> bool:
+ index = self.client.tvs_get_index(self.index_name)
+ if index is not None:
+ logger.info("Index already exists")
+ return False
+ self.client.tvs_create_index(
+ self.index_name,
+ dim,
+ distance_type,
+ index_type,
+ data_type,
+ **kwargs,
+ )
+ return True
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts data to an existing index."""
+ ids = []
+ keys = kwargs.get("keys", None)
+ use_hybrid_search = False
+ index = self.client.tvs_get_index(self.index_name)
+ if index is not None and index.get("lexical_algorithm") == "bm25":
+ use_hybrid_search = True
+ # Write data to tair
+ pipeline = self.client.pipeline(transaction=False)
+ embeddings = self.embedding_function.embed_documents(list(texts))
+ for i, text in enumerate(texts):
+ # Use provided key otherwise use default key
+ key = keys[i] if keys else _uuid_key()
+ metadata = metadatas[i] if metadatas else {}
+ if use_hybrid_search:
+                # Tair uses the TEXT attribute for hybrid search
+ pipeline.tvs_hset(
+ self.index_name,
+ key,
+ embeddings[i],
+ False,
+ **{
+ "TEXT": text,
+ self.content_key: text,
+ self.metadata_key: json.dumps(metadata),
+ },
+ )
+ else:
+ pipeline.tvs_hset(
+ self.index_name,
+ key,
+ embeddings[i],
+ False,
+ **{
+ self.content_key: text,
+ self.metadata_key: json.dumps(metadata),
+ },
+ )
+ ids.append(key)
+ pipeline.execute()
+ return ids
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """
+ Returns the most similar indexed documents to the query text.
+
+ Args:
+ query (str): The query text for which to find similar documents.
+ k (int): The number of documents to return. Default is 4.
+
+ Returns:
+ List[Document]: A list of documents that are most similar to the query text.
+ """
+ # Creates embedding vector from user query
+ embedding = self.embedding_function.embed_query(query)
+
+ keys_and_scores = self.client.tvs_knnsearch(
+ self.index_name, k, embedding, False, None, **kwargs
+ )
+
+ pipeline = self.client.pipeline(transaction=False)
+ for key, _ in keys_and_scores:
+ pipeline.tvs_hmget(
+ self.index_name, key, self.metadata_key, self.content_key
+ )
+ docs = pipeline.execute()
+
+ return [
+ Document(
+ page_content=d[1],
+ metadata=json.loads(d[0]),
+ )
+ for d in docs
+ ]
+
+ @classmethod
+ def from_texts(
+ cls: Type[Tair],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ index_name: str = "langchain",
+ content_key: str = "content",
+ metadata_key: str = "metadata",
+ **kwargs: Any,
+ ) -> Tair:
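+        """Create a Tair vector store from raw texts.
+
+        The connection URL is read from the `tair_url` keyword argument or,
+        if absent, from the `TAIR_URL` environment variable. Index options
+        (`distance_type`, `index_type`, `data_type`, `index_params`,
+        `search_params`, `keys`) may also be passed as keyword arguments.
+        """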
+ try:
+ from tair import tairvector
+ except ImportError:
+ raise ValueError(
+ "Could not import tair python package. "
+ "Please install it with `pip install tair`."
+ )
+ url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
+ if "tair_url" in kwargs:
+ kwargs.pop("tair_url")
+
+ distance_type = tairvector.DistanceMetric.InnerProduct
+ if "distance_type" in kwargs:
+ distance_type = kwargs.pop("distance_type")
+ index_type = tairvector.IndexType.HNSW
+ if "index_type" in kwargs:
+ index_type = kwargs.pop("index_type")
+ data_type = tairvector.DataType.Float32
+ if "data_type" in kwargs:
+ data_type = kwargs.pop("data_type")
+ index_params = {}
+ if "index_params" in kwargs:
+ index_params = kwargs.pop("index_params")
+ search_params = {}
+ if "search_params" in kwargs:
+ search_params = kwargs.pop("search_params")
+
+ keys = None
+ if "keys" in kwargs:
+ keys = kwargs.pop("keys")
+ try:
+ tair_vector_store = cls(
+ embedding,
+ url,
+ index_name,
+ content_key=content_key,
+ metadata_key=metadata_key,
+ search_params=search_params,
+ **kwargs,
+ )
+ except ValueError as e:
+ raise ValueError(f"tair failed to connect: {e}")
+
+ # Create embeddings for documents
+ embeddings = embedding.embed_documents(texts)
+
+ tair_vector_store.create_index_if_not_exist(
+ len(embeddings[0]),
+ distance_type,
+ index_type,
+ data_type,
+ **index_params,
+ )
+
+ tair_vector_store.add_texts(texts, metadatas, keys=keys)
+ return tair_vector_store
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ index_name: str = "langchain",
+ content_key: str = "content",
+ metadata_key: str = "metadata",
+ **kwargs: Any,
+ ) -> Tair:
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+
+ return cls.from_texts(
+ texts, embedding, metadatas, index_name, content_key, metadata_key, **kwargs
+ )
+
+ @staticmethod
+ def drop_index(
+ index_name: str = "langchain",
+ **kwargs: Any,
+ ) -> bool:
+ """
+ Drop an existing index.
+
+ Args:
+ index_name (str): Name of the index to drop.
+
+ Returns:
+ bool: True if the index is dropped successfully.
+ """
+ try:
+ from tair import Tair as TairClient
+ except ImportError:
+ raise ValueError(
+ "Could not import tair python package. "
+ "Please install it with `pip install tair`."
+ )
+ url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
+
+ try:
+ if "tair_url" in kwargs:
+ kwargs.pop("tair_url")
+ client = TairClient.from_url(url=url, **kwargs)
+ except ValueError as e:
+ raise ValueError(f"Tair connection error: {e}")
+ # delete index
+ ret = client.tvs_del_index(index_name)
+ if ret == 0:
+            # index does not exist
+ logger.info("Index does not exist")
+ return False
+ return True
+
+ @classmethod
+ def from_existing_index(
+ cls,
+ embedding: Embeddings,
+ index_name: str = "langchain",
+ content_key: str = "content",
+ metadata_key: str = "metadata",
+ **kwargs: Any,
+ ) -> Tair:
+ """Connect to an existing Tair index."""
+ url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
+
+ search_params = {}
+ if "search_params" in kwargs:
+ search_params = kwargs.pop("search_params")
+
+ return cls(
+ embedding,
+ url,
+ index_name,
+ content_key=content_key,
+ metadata_key=metadata_key,
+ search_params=search_params,
+ **kwargs,
+ )
diff --git a/libs/community/langchain_community/vectorstores/tencentvectordb.py b/libs/community/langchain_community/vectorstores/tencentvectordb.py
new file mode 100644
index 00000000000..adf67936398
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/tencentvectordb.py
@@ -0,0 +1,392 @@
+"""Wrapper around the Tencent vector database."""
+from __future__ import annotations
+
+import json
+import logging
+import time
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import guard_import
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectionParams:
+ """Tencent vector DB Connection params.
+
+ See the following documentation for details:
+ https://cloud.tencent.com/document/product/1709/95820
+
+    Attributes:
+ url (str) : The access address of the vector database server
+ that the client needs to connect to.
+ key (str): API key for client to access the vector database server,
+ which is used for authentication.
+ username (str) : Account for client to access the vector database server.
+ timeout (int) : Request Timeout.
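+
+    Example (values below are placeholders, not real credentials):
+        .. code-block:: python
+
+            params = ConnectionParams(
+                url="http://example-vdb-endpoint",
+                key="your-api-key",
+                username="root",
+                timeout=20,
+            )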
+ """
+
+ def __init__(self, url: str, key: str, username: str = "root", timeout: int = 10):
+ self.url = url
+ self.key = key
+ self.username = username
+ self.timeout = timeout
+
+
+class IndexParams:
+ """Tencent vector DB Index params.
+
+ See the following documentation for details:
+ https://cloud.tencent.com/document/product/1709/95826
+ """
+
+ def __init__(
+ self,
+ dimension: int,
+ shard: int = 1,
+ replicas: int = 2,
+ index_type: str = "HNSW",
+ metric_type: str = "L2",
+ params: Optional[Dict] = None,
+ ):
+ self.dimension = dimension
+ self.shard = shard
+ self.replicas = replicas
+ self.index_type = index_type
+ self.metric_type = metric_type
+ self.params = params
+
+
+class TencentVectorDB(VectorStore):
+    """Wrapper around the Tencent vector database.
+
+ In order to use this you need to have a database instance.
+ See the following documentation for details:
+ https://cloud.tencent.com/document/product/1709/94951
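+
+    A minimal usage sketch (connection values are placeholders; requires the
+    `tcvectordb` package and a reachable database instance):
+
+    .. code-block:: python
+
+        from langchain_community.embeddings.openai import OpenAIEmbeddings
+        from langchain_community.vectorstores.tencentvectordb import (
+            ConnectionParams,
+            TencentVectorDB,
+        )
+
+        embeddings = OpenAIEmbeddings()
+        conn = ConnectionParams(url="http://example-vdb-endpoint", key="your-api-key")
+        vector_db = TencentVectorDB.from_texts(
+            ["hello Tencent VectorDB"],
+            embedding=embeddings,
+            connection_params=conn,
+        )
+        docs = vector_db.similarity_search("hello", k=1)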
+ """
+
+ field_id: str = "id"
+ field_vector: str = "vector"
+ field_text: str = "text"
+ field_metadata: str = "metadata"
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ connection_params: ConnectionParams,
+ index_params: IndexParams = IndexParams(128),
+ database_name: str = "LangChainDatabase",
+ collection_name: str = "LangChainCollection",
+ drop_old: Optional[bool] = False,
+ ):
+ self.document = guard_import("tcvectordb.model.document")
+ tcvectordb = guard_import("tcvectordb")
+ self.embedding_func = embedding
+ self.index_params = index_params
+ self.vdb_client = tcvectordb.VectorDBClient(
+ url=connection_params.url,
+ username=connection_params.username,
+ key=connection_params.key,
+ timeout=connection_params.timeout,
+ )
+ db_list = self.vdb_client.list_databases()
+ db_exist: bool = False
+ for db in db_list:
+ if database_name == db.database_name:
+ db_exist = True
+ break
+ if db_exist:
+ self.database = self.vdb_client.database(database_name)
+ else:
+ self.database = self.vdb_client.create_database(database_name)
+ try:
+ self.collection = self.database.describe_collection(collection_name)
+ if drop_old:
+ self.database.drop_collection(collection_name)
+ self._create_collection(collection_name)
+ except tcvectordb.exceptions.VectorDBException:
+ self._create_collection(collection_name)
+
+ def _create_collection(self, collection_name: str) -> None:
+ enum = guard_import("tcvectordb.model.enum")
+ vdb_index = guard_import("tcvectordb.model.index")
+ index_type = None
+ for k, v in enum.IndexType.__members__.items():
+ if k == self.index_params.index_type:
+ index_type = v
+ if index_type is None:
+ raise ValueError("unsupported index_type")
+ metric_type = None
+ for k, v in enum.MetricType.__members__.items():
+ if k == self.index_params.metric_type:
+ metric_type = v
+ if metric_type is None:
+ raise ValueError("unsupported metric_type")
+ if self.index_params.params is None:
+ params = vdb_index.HNSWParams(m=16, efconstruction=200)
+ else:
+ params = vdb_index.HNSWParams(
+ m=self.index_params.params.get("M", 16),
+ efconstruction=self.index_params.params.get("efConstruction", 200),
+ )
+ index = vdb_index.Index(
+ vdb_index.FilterIndex(
+ self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
+ ),
+ vdb_index.VectorIndex(
+ self.field_vector,
+ self.index_params.dimension,
+ index_type,
+ metric_type,
+ params,
+ ),
+ vdb_index.FilterIndex(
+ self.field_text, enum.FieldType.String, enum.IndexType.FILTER
+ ),
+ vdb_index.FilterIndex(
+ self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
+ ),
+ )
+ self.collection = self.database.create_collection(
+ name=collection_name,
+ shard=self.index_params.shard,
+ replicas=self.index_params.replicas,
+ description="Collection for LangChain",
+ index=index,
+ )
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding_func
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ connection_params: Optional[ConnectionParams] = None,
+ index_params: Optional[IndexParams] = None,
+ database_name: str = "LangChainDatabase",
+ collection_name: str = "LangChainCollection",
+ drop_old: Optional[bool] = False,
+ **kwargs: Any,
+ ) -> TencentVectorDB:
+ """Create a collection, indexes it with HNSW, and insert data."""
+ if len(texts) == 0:
+ raise ValueError("texts is empty")
+ if connection_params is None:
+ raise ValueError("connection_params is empty")
+ try:
+ embeddings = embedding.embed_documents(texts[0:1])
+ except NotImplementedError:
+ embeddings = [embedding.embed_query(texts[0])]
+ dimension = len(embeddings[0])
+ if index_params is None:
+ index_params = IndexParams(dimension=dimension)
+ else:
+ index_params.dimension = dimension
+ vector_db = cls(
+ embedding=embedding,
+ connection_params=connection_params,
+ index_params=index_params,
+ database_name=database_name,
+ collection_name=collection_name,
+ drop_old=drop_old,
+ )
+ vector_db.add_texts(texts=texts, metadatas=metadatas)
+ return vector_db
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ timeout: Optional[int] = None,
+ batch_size: int = 1000,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Insert text data into TencentVectorDB."""
+ texts = list(texts)
+ try:
+ embeddings = self.embedding_func.embed_documents(texts)
+ except NotImplementedError:
+ embeddings = [self.embedding_func.embed_query(x) for x in texts]
+ if len(embeddings) == 0:
+ logger.debug("Nothing to insert, skipping.")
+ return []
+ pks: list[str] = []
+ total_count = len(embeddings)
+ for start in range(0, total_count, batch_size):
+ # Grab end index
+ docs = []
+ end = min(start + batch_size, total_count)
+ for id in range(start, end, 1):
+ metadata = "{}"
+ if metadatas is not None:
+ metadata = json.dumps(metadatas[id])
+ doc = self.document.Document(
+ id="{}-{}-{}".format(time.time_ns(), hash(texts[id]), id),
+ vector=embeddings[id],
+ text=texts[id],
+ metadata=metadata,
+ )
+ docs.append(doc)
+ pks.append(str(id))
+ self.collection.upsert(docs, timeout)
+ return pks
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a similarity search against the query string."""
+ res = self.similarity_search_with_score(
+ query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return [doc for doc, _ in res]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Perform a search on a query string and return results with score."""
+ # Embed the query text.
+ embedding = self.embedding_func.embed_query(query)
+ res = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return res
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+        """Perform a similarity search against the query vector."""
+ res = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+ )
+ return [doc for doc, _ in res]
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+        """Perform a search on a query vector and return results with scores."""
+ filter = None if expr is None else self.document.Filter(expr)
+ ef = 10 if param is None else param.get("ef", 10)
+ res: List[List[Dict]] = self.collection.search(
+ vectors=[embedding],
+ filter=filter,
+ params=self.document.HNSWSearchParams(ef=ef),
+ retrieve_vector=False,
+ limit=k,
+ timeout=timeout,
+ )
+ # Organize results.
+ ret: List[Tuple[Document, float]] = []
+ if res is None or len(res) == 0:
+ return ret
+ for result in res[0]:
+ meta = result.get(self.field_metadata)
+ if meta is not None:
+ meta = json.loads(meta)
+ doc = Document(page_content=result.get(self.field_text), metadata=meta)
+ pair = (doc, result.get("score", 0.0))
+ ret.append(pair)
+ return ret
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a search and return results that are reordered by MMR."""
+ embedding = self.embedding_func.embed_query(query)
+ return self.max_marginal_relevance_search_by_vector(
+ embedding=embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ param=param,
+ expr=expr,
+ timeout=timeout,
+ **kwargs,
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: list[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ param: Optional[dict] = None,
+ expr: Optional[str] = None,
+ timeout: Optional[int] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Perform a search and return results that are reordered by MMR."""
+ filter = None if expr is None else self.document.Filter(expr)
+ ef = 10 if param is None else param.get("ef", 10)
+ res: List[List[Dict]] = self.collection.search(
+ vectors=[embedding],
+ filter=filter,
+ params=self.document.HNSWSearchParams(ef=ef),
+ retrieve_vector=True,
+ limit=fetch_k,
+ timeout=timeout,
+ )
+ # Organize results.
+ documents = []
+ ordered_result_embeddings = []
+ for result in res[0]:
+ meta = result.get(self.field_metadata)
+ if meta is not None:
+ meta = json.loads(meta)
+ doc = Document(page_content=result.get(self.field_text), metadata=meta)
+ documents.append(doc)
+ ordered_result_embeddings.append(result.get(self.field_vector))
+ # Get the new order of results.
+ new_ordering = maximal_marginal_relevance(
+ np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
+ )
+ # Reorder the values and return.
+ ret = []
+ for x in new_ordering:
+ # Function can return -1 index
+ if x == -1:
+ break
+ else:
+ ret.append(documents[x])
+ return ret
diff --git a/libs/community/langchain_community/vectorstores/tigris.py b/libs/community/langchain_community/vectorstores/tigris.py
new file mode 100644
index 00000000000..96038b7c749
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/tigris.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import itertools
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from tigrisdb import TigrisClient
+ from tigrisdb import VectorStore as TigrisVectorStore
+ from tigrisdb.types.filters import Filter as TigrisFilter
+ from tigrisdb.types.vector import Document as TigrisDocument
+
+
+class Tigris(VectorStore):
+ """`Tigris` vector store."""
+
+ def __init__(self, client: TigrisClient, embeddings: Embeddings, index_name: str):
+ """Initialize Tigris vector store."""
+ try:
+ import tigrisdb # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import tigrisdb python package. "
+ "Please install it with `pip install tigrisdb`"
+ )
+
+ self._embed_fn = embeddings
+ self._vector_store = TigrisVectorStore(client.get_search(), index_name)
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embed_fn
+
+ @property
+ def search_index(self) -> TigrisVectorStore:
+ return self._vector_store
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids for documents.
+ Ids will be autogenerated if not provided.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ docs = self._prep_docs(texts, metadatas, ids)
+ result = self.search_index.add_documents(docs)
+ return [r.id for r in result]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[TigrisFilter] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+ docs_with_scores = self.similarity_search_with_score(query, k, filter)
+ return [doc for doc, _ in docs_with_scores]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[TigrisFilter] = None,
+ ) -> List[Tuple[Document, float]]:
+        """Run similarity search with Tigris with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[TigrisFilter]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List[Tuple[Document, float]]: List of documents most similar to the query
+ text with distance in float.
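+
+        Example (a minimal sketch; assumes `store` was created via
+        `Tigris.from_texts` with a valid client and index):
+
+        .. code-block:: python
+
+            hits = store.similarity_search_with_score("what is tigris?", k=2)
+            for doc, score in hits:
+                print(doc.page_content, score)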
+ """
+ vector = self._embed_fn.embed_query(query)
+ result = self.search_index.similarity_search(
+ vector=vector, k=k, filter_by=filter
+ )
+ docs: List[Tuple[Document, float]] = []
+ for r in result:
+ docs.append(
+ (
+ Document(
+ page_content=r.doc["text"], metadata=r.doc.get("metadata")
+ ),
+ r.score,
+ )
+ )
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ client: Optional[TigrisClient] = None,
+ index_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Tigris:
+ """Return VectorStore initialized from texts and embeddings."""
+ if not index_name:
+ raise ValueError("`index_name` is required")
+
+ if not client:
+ client = TigrisClient()
+ store = cls(client, embedding, index_name)
+ store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+ return store
+
+ def _prep_docs(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]],
+ ids: Optional[List[str]],
+ ) -> List[TigrisDocument]:
+ embeddings: List[List[float]] = self._embed_fn.embed_documents(list(texts))
+ docs: List[TigrisDocument] = []
+ for t, m, e, _id in itertools.zip_longest(
+ texts, metadatas or [], embeddings or [], ids or []
+ ):
+ doc: TigrisDocument = {
+ "text": t,
+ "embeddings": e or [],
+ "metadata": m or {},
+ }
+ if _id:
+ doc["id"] = _id
+ docs.append(doc)
+ return docs
diff --git a/libs/community/langchain_community/vectorstores/tiledb.py b/libs/community/langchain_community/vectorstores/tiledb.py
new file mode 100644
index 00000000000..85d4052a94b
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/tiledb.py
@@ -0,0 +1,789 @@
+"""Wrapper around TileDB vector database."""
+from __future__ import annotations
+
+import pickle
+import random
+import sys
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+INDEX_METRICS = frozenset(["euclidean"])
+DEFAULT_METRIC = "euclidean"
+DOCUMENTS_ARRAY_NAME = "documents"
+VECTOR_INDEX_NAME = "vectors"
+MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
+MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
+MAX_FLOAT = sys.float_info.max
+
+
+def dependable_tiledb_import() -> Any:
+ """Import tiledb-vector-search if available, otherwise raise error."""
+ try:
+        import tiledb
+ import tiledb.vector_search as tiledb_vs
+ except ImportError:
+ raise ValueError(
+ "Could not import tiledb-vector-search python package. "
+ "Please install it with `conda install -c tiledb tiledb-vector-search` "
+ "or `pip install tiledb-vector-search`"
+ )
+ return tiledb_vs, tiledb
+
+
+def get_vector_index_uri_from_group(group: Any) -> str:
+ return group[VECTOR_INDEX_NAME].uri
+
+
+def get_documents_array_uri_from_group(group: Any) -> str:
+ return group[DOCUMENTS_ARRAY_NAME].uri
+
+
+def get_vector_index_uri(uri: str) -> str:
+ return f"{uri}/{VECTOR_INDEX_NAME}"
+
+
+def get_documents_array_uri(uri: str) -> str:
+ return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
+
+
+class TileDB(VectorStore):
+ """Wrapper around TileDB vector database.
+
+ To use, you should have the ``tiledb-vector-search`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+            from langchain_community.vectorstores import TileDB
+ embeddings = OpenAIEmbeddings()
+ db = TileDB(embeddings, index_uri, metric)
+
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ index_uri: str,
+ metric: str,
+ *,
+ vector_index_uri: str = "",
+ docs_array_uri: str = "",
+ config: Optional[Mapping[str, Any]] = None,
+ timestamp: Any = None,
+ **kwargs: Any,
+ ):
+ """Initialize with necessary components."""
+ self.embedding = embedding
+ self.embedding_function = embedding.embed_query
+ self.index_uri = index_uri
+ self.metric = metric
+ self.config = config
+
+ tiledb_vs, tiledb = dependable_tiledb_import()
+ with tiledb.scope_ctx(ctx_or_config=config):
+ index_group = tiledb.Group(self.index_uri, "r")
+ self.vector_index_uri = (
+ vector_index_uri
+ if vector_index_uri != ""
+ else get_vector_index_uri_from_group(index_group)
+ )
+ self.docs_array_uri = (
+ docs_array_uri
+ if docs_array_uri != ""
+ else get_documents_array_uri_from_group(index_group)
+ )
+ index_group.close()
+ group = tiledb.Group(self.vector_index_uri, "r")
+ self.index_type = group.meta.get("index_type")
+ group.close()
+ self.timestamp = timestamp
+ if self.index_type == "FLAT":
+ self.vector_index = tiledb_vs.flat_index.FlatIndex(
+ uri=self.vector_index_uri,
+ config=self.config,
+ timestamp=self.timestamp,
+ **kwargs,
+ )
+ elif self.index_type == "IVF_FLAT":
+ self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
+ uri=self.vector_index_uri,
+ config=self.config,
+ timestamp=self.timestamp,
+ **kwargs,
+ )
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self.embedding
+
+ def process_index_results(
+ self,
+ ids: List[int],
+ scores: List[float],
+ *,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ score_threshold: float = MAX_FLOAT,
+ ) -> List[Tuple[Document, float]]:
+ """Turns TileDB results into a list of documents and scores.
+
+ Args:
+ ids: List of indices of the documents in the index.
+ scores: List of distances of the documents in the index.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
+ score_threshold: Optional, a floating point value to filter the
+ resulting set of retrieved docs
+ Returns:
+ List of Documents and scores.
+ """
+ tiledb_vs, tiledb = dependable_tiledb_import()
+ docs = []
+ docs_array = tiledb.open(
+ self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
+ )
+ for idx, score in zip(ids, scores):
+ if idx == 0 and score == 0:
+ continue
+ if idx == MAX_UINT64 and score == MAX_FLOAT_32:
+ continue
+ doc = docs_array[idx]
+ if doc is None or len(doc["text"]) == 0:
+ raise ValueError(f"Could not find document for id {idx}, got {doc}")
+ pickled_metadata = doc.get("metadata")
+ result_doc = Document(page_content=str(doc["text"][0]))
+ if pickled_metadata is not None:
+ metadata = pickle.loads(
+ np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
+ )
+ result_doc.metadata = metadata
+ if filter is not None:
+ filter = {
+ key: [value] if not isinstance(value, list) else value
+ for key, value in filter.items()
+ }
+ if all(
+ result_doc.metadata.get(key) in value
+ for key, value in filter.items()
+ ):
+ docs.append((result_doc, score))
+ else:
+ docs.append((result_doc, score))
+ docs_array.close()
+ docs = [(doc, score) for doc, score in docs if score <= score_threshold]
+ return docs[:k]
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ embedding: Embedding vector to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+ **kwargs: kwargs to be passed to similarity search. Can include:
+ nprobe: Optional, number of partitions to check if using IVF_FLAT index
+ score_threshold: Optional, a floating point value to filter the
+ resulting set of retrieved docs
+
+ Returns:
+ List of documents most similar to the query text and distance
+ in float for each. Lower score represents more similarity.
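+
+        An illustrative sketch (assumes ``db`` is an existing ``TileDB`` store;
+        ``nprobe`` only applies to "IVF_FLAT" indexes):
+            .. code-block:: python
+
+                docs_and_scores = db.similarity_search_with_score_by_vector(
+                    embedding, k=4, nprobe=10, score_threshold=1.5
+                )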
+ """
+ if "score_threshold" in kwargs:
+ score_threshold = kwargs.pop("score_threshold")
+ else:
+ score_threshold = MAX_FLOAT
+ d, i = self.vector_index.query(
+ np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
+ k=k if filter is None else fetch_k,
+ **kwargs,
+ )
+ return self.process_index_results(
+ ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ *,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of documents most similar to the query text with
+ Distance as float. Lower score represents more similarity.
+ """
+ embedding = self.embedding_function(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the embedding.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding,
+ k=k,
+ filter=filter,
+ fetch_k=fetch_k,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Dict[str, Any]] = None,
+ fetch_k: int = 20,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+ fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
+ Defaults to 20.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ *,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs and their similarity scores selected using the maximal marginal
+ relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents and similarity scores selected by maximal marginal
+ relevance and score for each.
+ """
+ if "score_threshold" in kwargs:
+ score_threshold = kwargs.pop("score_threshold")
+ else:
+ score_threshold = MAX_FLOAT
+ scores, indices = self.vector_index.query(
+ np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
+ k=fetch_k if filter is None else fetch_k * 2,
+ **kwargs,
+ )
+ results = self.process_index_results(
+ ids=indices[0],
+ scores=scores[0],
+ filter=filter,
+ k=fetch_k if filter is None else fetch_k * 2,
+ score_threshold=score_threshold,
+ )
+ embeddings = [
+ self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
+ ]
+ mmr_selected = maximal_marginal_relevance(
+ np.array([embedding], dtype=np.float32),
+ embeddings,
+ k=k,
+ lambda_mult=lambda_mult,
+ )
+ docs_and_scores = []
+ for i in mmr_selected:
+ docs_and_scores.append(results[i])
+ return docs_and_scores
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ filter: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch before filtering (if needed) to
+ pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ Returns:
+ List of Documents selected by maximal marginal relevance.
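+
+        An illustrative sketch (assumes ``db`` is an existing ``TileDB`` store):
+            .. code-block:: python
+
+                docs = db.max_marginal_relevance_search("my query", k=4, fetch_k=20)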
+ """
+ embedding = self.embedding_function(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ embedding,
+ k=k,
+ fetch_k=fetch_k,
+ lambda_mult=lambda_mult,
+ filter=filter,
+ **kwargs,
+ )
+ return docs
+
+ @classmethod
+ def create(
+ cls,
+ index_uri: str,
+ index_type: str,
+ dimensions: int,
+ vector_type: np.dtype,
+ *,
+ metadatas: bool = True,
+ config: Optional[Mapping[str, Any]] = None,
+ ) -> None:
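+        """Create a TileDB group with the arrays that back a vector index.
+
+        Args:
+            index_uri: URI of the TileDB group to create.
+            index_type: Vector index type ("FLAT" or "IVF_FLAT").
+            dimensions: Dimensionality of the vectors to be indexed.
+            vector_type: Numpy dtype of the input vectors.
+            metadatas: Whether to create a ``metadata`` attribute on the
+                documents array. Defaults to True.
+            config: Optional, TileDB config.
+        """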
+ tiledb_vs, tiledb = dependable_tiledb_import()
+ with tiledb.scope_ctx(ctx_or_config=config):
+            try:
+                tiledb.group_create(index_uri)
+            except tiledb.TileDBError as err:
+                raise ValueError(
+                    f"Could not create group at {index_uri}; it may already exist."
+                ) from err
+ group = tiledb.Group(index_uri, "w")
+ vector_index_uri = get_vector_index_uri(group.uri)
+ docs_uri = get_documents_array_uri(group.uri)
+ if index_type == "FLAT":
+ tiledb_vs.flat_index.create(
+ uri=vector_index_uri,
+ dimensions=dimensions,
+ vector_type=vector_type,
+ config=config,
+ )
+ elif index_type == "IVF_FLAT":
+ tiledb_vs.ivf_flat_index.create(
+ uri=vector_index_uri,
+ dimensions=dimensions,
+ vector_type=vector_type,
+ config=config,
+ )
+ group.add(vector_index_uri, name=VECTOR_INDEX_NAME)
+
+ # Create TileDB array to store Documents
+ # TODO add a Document store API to tiledb-vector-search to allow storing
+ # different types of objects and metadata in a more generic way.
+ dim = tiledb.Dim(
+ name="id",
+ domain=(0, MAX_UINT64 - 1),
+ dtype=np.dtype(np.uint64),
+ )
+ dom = tiledb.Domain(dim)
+
+ text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
+ attrs = [text_attr]
+ if metadatas:
+ metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
+ attrs.append(metadata_attr)
+ schema = tiledb.ArraySchema(
+ domain=dom,
+ sparse=True,
+ allows_duplicates=False,
+ attrs=attrs,
+ )
+ tiledb.Array.create(docs_uri, schema)
+ group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
+ group.close()
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ index_uri: str,
+ *,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ metric: str = DEFAULT_METRIC,
+ index_type: str = "FLAT",
+ config: Optional[Mapping[str, Any]] = None,
+ index_timestamp: int = 0,
+ **kwargs: Any,
+ ) -> TileDB:
+ if metric not in INDEX_METRICS:
+ raise ValueError(
+ (
+ f"Unsupported distance metric: {metric}. "
+ f"Expected one of {list(INDEX_METRICS)}"
+ )
+ )
+ tiledb_vs, tiledb = dependable_tiledb_import()
+ input_vectors = np.array(embeddings).astype(np.float32)
+ cls.create(
+ index_uri=index_uri,
+ index_type=index_type,
+ dimensions=input_vectors.shape[1],
+ vector_type=input_vectors.dtype,
+ metadatas=metadatas is not None,
+ config=config,
+ )
+ with tiledb.scope_ctx(ctx_or_config=config):
+ if not embeddings:
+ raise ValueError("embeddings must be provided to build a TileDB index")
+
+ vector_index_uri = get_vector_index_uri(index_uri)
+ docs_uri = get_documents_array_uri(index_uri)
+ if ids is None:
+ ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
+ external_ids = np.array(ids).astype(np.uint64)
+
+ tiledb_vs.ingestion.ingest(
+ index_type=index_type,
+ index_uri=vector_index_uri,
+ input_vectors=input_vectors,
+ external_ids=external_ids,
+ index_timestamp=index_timestamp if index_timestamp != 0 else None,
+ config=config,
+ **kwargs,
+ )
+ with tiledb.open(docs_uri, "w") as A:
+ if external_ids is None:
+ external_ids = np.zeros(len(texts), dtype=np.uint64)
+                    for i in range(len(texts)):
+                        external_ids[i] = i
+ data = {}
+ data["text"] = np.array(texts)
+ if metadatas is not None:
+ metadata_attr = np.empty([len(metadatas)], dtype=object)
+                    for i, metadata in enumerate(metadatas):
+                        metadata_attr[i] = np.frombuffer(
+                            pickle.dumps(metadata), dtype=np.uint8
+                        )
+ data["metadata"] = metadata_attr
+
+ A[external_ids] = data
+ return cls(
+ embedding=embedding,
+ index_uri=index_uri,
+ metric=metric,
+ config=config,
+ **kwargs,
+ )
+
+ def delete(
+ self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
+ ) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ timestamp: Optional timestamp to delete with.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+        external_ids = np.array(ids).astype(np.uint64)
+ self.vector_index.delete_batch(
+ external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
+ )
+ return True
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ timestamp: int = 0,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional ids of each text object.
+ timestamp: Optional timestamp to write new texts with.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ tiledb_vs, tiledb = dependable_tiledb_import()
+        texts = list(texts)
+        embeddings = self.embedding.embed_documents(texts)
+ if ids is None:
+ ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
+
+ external_ids = np.array(ids).astype(np.uint64)
+ vectors = np.empty((len(embeddings)), dtype="O")
+ for i in range(len(embeddings)):
+ vectors[i] = np.array(embeddings[i], dtype=np.float32)
+ self.vector_index.update_batch(
+ vectors=vectors,
+ external_ids=external_ids,
+ timestamp=timestamp if timestamp != 0 else None,
+ )
+
+ docs = {}
+ docs["text"] = np.array(texts)
+ if metadatas is not None:
+ metadata_attr = np.empty([len(metadatas)], dtype=object)
+            for i, metadata in enumerate(metadatas):
+                metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
+ docs["metadata"] = metadata_attr
+
+ docs_array = tiledb.open(
+ self.docs_array_uri,
+ "w",
+ timestamp=timestamp if timestamp != 0 else None,
+ config=self.config,
+ )
+ docs_array[external_ids] = docs
+ docs_array.close()
+ return ids
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ metric: str = DEFAULT_METRIC,
+ index_uri: str = "/tmp/tiledb_array",
+ index_type: str = "FLAT",
+ config: Optional[Mapping[str, Any]] = None,
+ index_timestamp: int = 0,
+ **kwargs: Any,
+ ) -> TileDB:
+ """Construct a TileDB index from raw documents.
+
+ Args:
+ texts: List of documents to index.
+ embedding: Embedding function to use.
+ metadatas: List of metadata dictionaries to associate with documents.
+ ids: Optional ids of each text object.
+ metric: Metric to use for indexing. Defaults to "euclidean".
+ index_uri: The URI to write the TileDB arrays
+            index_type: Optional, Vector index type ("FLAT", "IVF_FLAT")
+ config: Optional, TileDB config
+ index_timestamp: Optional, timestamp to write new texts with.
+
+ Example:
+ .. code-block:: python
+
+                from langchain_community.vectorstores import TileDB
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ index = TileDB.from_texts(texts, embeddings)
+ """
+        embeddings = embedding.embed_documents(texts)
+ return cls.__from(
+ texts=texts,
+ embeddings=embeddings,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ metric=metric,
+ index_uri=index_uri,
+ index_type=index_type,
+ config=config,
+ index_timestamp=index_timestamp,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ index_uri: str,
+ *,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ metric: str = DEFAULT_METRIC,
+ index_type: str = "FLAT",
+ config: Optional[Mapping[str, Any]] = None,
+ index_timestamp: int = 0,
+ **kwargs: Any,
+ ) -> TileDB:
+ """Construct TileDB index from embeddings.
+
+ Args:
+ text_embeddings: List of tuples of (text, embedding)
+ embedding: Embedding function to use.
+ index_uri: The URI to write the TileDB arrays
+ metadatas: List of metadata dictionaries to associate with documents.
+ metric: Optional, Metric to use for indexing. Defaults to "euclidean".
+            index_type: Optional, Vector index type ("FLAT", "IVF_FLAT")
+ config: Optional, TileDB config
+ index_timestamp: Optional, timestamp to write new texts with.
+
+ Example:
+ .. code-block:: python
+
+                from langchain_community.vectorstores import TileDB
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts=texts,
+ embeddings=embeddings,
+ embedding=embedding,
+ metadatas=metadatas,
+ ids=ids,
+ metric=metric,
+ index_uri=index_uri,
+ index_type=index_type,
+ config=config,
+ index_timestamp=index_timestamp,
+ **kwargs,
+ )
+
+ @classmethod
+ def load(
+ cls,
+ index_uri: str,
+ embedding: Embeddings,
+ *,
+ metric: str = DEFAULT_METRIC,
+ config: Optional[Mapping[str, Any]] = None,
+ timestamp: Any = None,
+ **kwargs: Any,
+ ) -> TileDB:
+ """Load a TileDB index from a URI.
+
+ Args:
+ index_uri: The URI of the TileDB vector index.
+ embedding: Embeddings to use when generating queries.
+ metric: Optional, Metric to use for indexing. Defaults to "euclidean".
+ config: Optional, TileDB config
+ timestamp: Optional, timestamp to use for opening the arrays.
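+
+        Example (a minimal sketch; assumes an index previously built with
+        ``TileDB.from_texts`` and any ``Embeddings`` implementation):
+            .. code-block:: python
+
+                from langchain_community.embeddings import HuggingFaceEmbeddings
+                from langchain_community.vectorstores import TileDB
+
+                embeddings = HuggingFaceEmbeddings()
+                db = TileDB.load(index_uri="/tmp/tiledb_array", embedding=embeddings)
+                docs = db.similarity_search("my query", k=4)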
+ """
+ return cls(
+ embedding=embedding,
+ index_uri=index_uri,
+ metric=metric,
+ config=config,
+ timestamp=timestamp,
+ **kwargs,
+ )
+
+ def consolidate_updates(self, **kwargs: Any) -> None:
+ self.vector_index = self.vector_index.consolidate_updates(**kwargs)
diff --git a/libs/community/langchain_community/vectorstores/timescalevector.py b/libs/community/langchain_community/vectorstores/timescalevector.py
new file mode 100644
index 00000000000..7d3201510ad
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/timescalevector.py
@@ -0,0 +1,883 @@
+"""VectorStore wrapper around a Postgres-TimescaleVector database."""
+from __future__ import annotations
+
+import enum
+import logging
+import uuid
+from datetime import timedelta
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_dict_or_env
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+if TYPE_CHECKING:
+ from timescale_vector import Predicates
+
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+
+ADA_TOKEN_COUNT = 1536
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_store"
+
+
+class TimescaleVector(VectorStore):
+ """VectorStore implementation using the timescale vector client to store vectors
+ in Postgres.
+
+ To use, you should have the ``timescale_vector`` python package installed.
+
+ Args:
+ service_url: Service url on timescale cloud.
+ embedding: Any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ collection_name: The name of the collection to use. (default: langchain_store)
+ This will become the table name used for the collection.
+ distance_strategy: The distance strategy to use. (default: COSINE)
+ pre_delete_collection: If True, will delete the collection if it exists.
+ (default: False). Useful for testing.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import TimescaleVector
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ SERVICE_URL = "postgres://tsdbadmin:@.tsdb.cloud.timescale.com:/tsdb?sslmode=require"
+ COLLECTION_NAME = "state_of_the_union_test"
+ embeddings = OpenAIEmbeddings()
+            vectorstore = TimescaleVector.from_documents(
+ embedding=embeddings,
+ documents=docs,
+ collection_name=COLLECTION_NAME,
+ service_url=SERVICE_URL,
+ )
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ service_url: str,
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ num_dimensions: int = ADA_TOKEN_COUNT,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ time_partition_interval: Optional[timedelta] = None,
+ **kwargs: Any,
+ ) -> None:
+ try:
+ from timescale_vector import client
+ except ImportError:
+ raise ImportError(
+ "Could not import timescale_vector python package. "
+ "Please install it with `pip install timescale-vector`."
+ )
+
+ self.service_url = service_url
+ self.embedding = embedding
+ self.collection_name = collection_name
+ self.num_dimensions = num_dimensions
+ self._distance_strategy = distance_strategy
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.override_relevance_score_fn = relevance_score_fn
+ self._time_partition_interval = time_partition_interval
+ self.sync_client = client.Sync(
+ self.service_url,
+ self.collection_name,
+ self.num_dimensions,
+ self._distance_strategy.value.lower(),
+ time_partition_interval=self._time_partition_interval,
+ **kwargs,
+ )
+ self.async_client = client.Async(
+ self.service_url,
+ self.collection_name,
+ self.num_dimensions,
+ self._distance_strategy.value.lower(),
+ time_partition_interval=self._time_partition_interval,
+ **kwargs,
+ )
+ self.__post_init__()
+
+ def __post_init__(
+ self,
+ ) -> None:
+ """
+ Initialize the store.
+ """
+ self.sync_client.create_tables()
+ if self.pre_delete_collection:
+ self.sync_client.delete_all()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self.embedding
+
+ def drop_tables(self) -> None:
+ self.sync_client.drop_table()
+
+ @classmethod
+ def __from(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ service_url: Optional[str] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ num_dimensions = len(embeddings[0])
+
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ if service_url is None:
+ service_url = cls.get_service_url(kwargs)
+
+ store = cls(
+ service_url=service_url,
+ num_dimensions=num_dimensions,
+ collection_name=collection_name,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ store.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ @classmethod
+ async def __afrom(
+ cls,
+ texts: List[str],
+ embeddings: List[List[float]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ service_url: Optional[str] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ num_dimensions = len(embeddings[0])
+
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ if service_url is None:
+ service_url = cls.get_service_url(kwargs)
+
+ store = cls(
+ service_url=service_url,
+ num_dimensions=num_dimensions,
+ collection_name=collection_name,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ await store.aadd_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ return store
+
+ def add_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+ metadatas: List of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ records = list(zip(ids, metadatas, texts, embeddings))
+ self.sync_client.upsert(records)
+
+ return ids
+
+ async def aadd_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+ metadatas: List of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ if ids is None:
+ ids = [str(uuid.uuid1()) for _ in texts]
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ records = list(zip(ids, metadatas, texts, embeddings))
+ await self.async_client.upsert(records)
+
+ return ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ embeddings = self.embedding.embed_documents(list(texts))
+ return self.add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ async def aadd_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ embeddings = self.embedding.embed_documents(list(texts))
+ return await self.aadd_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+
+ def _embed_query(self, query: str) -> Optional[List[float]]:
+ # an empty query should not be embedded
+ if query is None or query == "" or query.isspace():
+ return None
+ else:
+ return self.embedding.embed_query(query)
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with TimescaleVector with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query.
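+
+        An illustrative sketch (``start_date``/``end_date`` kwargs, when given,
+        are forwarded to the underlying client as a UUID time-range filter):
+            .. code-block:: python
+
+                from datetime import datetime
+
+                docs = vectorstore.similarity_search(
+                    "my query",
+                    k=4,
+                    start_date=datetime(2023, 1, 1),
+                    end_date=datetime(2023, 6, 30),
+                )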
+ """
+ embedding = self._embed_query(query)
+ return self.similarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ predicates=predicates,
+ **kwargs,
+ )
+
+ async def asimilarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Run similarity search with TimescaleVector with distance.
+
+ Args:
+ query (str): Query text to search for.
+ k (int): Number of results to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ embedding = self._embed_query(query)
+ return await self.asimilarity_search_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ predicates=predicates,
+ **kwargs,
+ )
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedding = self._embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ predicates=predicates,
+ **kwargs,
+ )
+ return docs
+
+ async def asimilarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+
+ embedding = self._embed_query(query)
+ return await self.asimilarity_search_with_score_by_vector(
+ embedding=embedding,
+ k=k,
+ filter=filter,
+ predicates=predicates,
+ **kwargs,
+ )
+
+ def date_to_range_filter(self, **kwargs: Any) -> Any:
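+        """Build a ``UUIDTimeRange`` from start/end date kwargs, if any."""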
+ constructor_args = {
+ key: kwargs[key]
+ for key in [
+ "start_date",
+ "end_date",
+ "time_delta",
+ "start_inclusive",
+ "end_inclusive",
+ ]
+ if key in kwargs
+ }
+        if not constructor_args:
+ return None
+
+ try:
+ from timescale_vector import client
+ except ImportError:
+ raise ImportError(
+ "Could not import timescale_vector python package. "
+ "Please install it with `pip install timescale-vector`."
+ )
+ return client.UUIDTimeRange(**constructor_args)
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: Optional[List[float]],
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ try:
+ from timescale_vector import client
+ except ImportError:
+ raise ImportError(
+ "Could not import timescale_vector python package. "
+ "Please install it with `pip install timescale-vector`."
+ )
+
+ results = self.sync_client.search(
+ embedding,
+ limit=k,
+ filter=filter,
+ predicates=predicates,
+ uuid_time_filter=self.date_to_range_filter(**kwargs),
+ )
+
+ docs = [
+ (
+ Document(
+ page_content=result[client.SEARCH_RESULT_CONTENTS_IDX],
+ metadata=result[client.SEARCH_RESULT_METADATA_IDX],
+ ),
+ result[client.SEARCH_RESULT_DISTANCE_IDX],
+ )
+ for result in results
+ ]
+ return docs
+
+ async def asimilarity_search_with_score_by_vector(
+ self,
+ embedding: Optional[List[float]],
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ try:
+ from timescale_vector import client
+ except ImportError:
+ raise ImportError(
+ "Could not import timescale_vector python package. "
+ "Please install it with `pip install timescale-vector`."
+ )
+
+ results = await self.async_client.search(
+ embedding,
+ limit=k,
+ filter=filter,
+ predicates=predicates,
+ uuid_time_filter=self.date_to_range_filter(**kwargs),
+ )
+
+ docs = [
+ (
+ Document(
+ page_content=result[client.SEARCH_RESULT_CONTENTS_IDX],
+ metadata=result[client.SEARCH_RESULT_METADATA_IDX],
+ ),
+ result[client.SEARCH_RESULT_DISTANCE_IDX],
+ )
+ for result in results
+ ]
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: Optional[List[float]],
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter, predicates=predicates, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ async def asimilarity_search_by_vector(
+ self,
+ embedding: Optional[List[float]],
+ k: int = 4,
+ filter: Optional[Union[dict, list]] = None,
+ predicates: Optional[Predicates] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = await self.asimilarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter, predicates=predicates, **kwargs
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls: Type[TimescaleVector],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ """
+ Return VectorStore initialized from texts and embeddings.
+        A Postgres connection string is required: either pass it as the
+        ``service_url`` parameter or set the TIMESCALE_SERVICE_URL
+        environment variable.
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ async def afrom_texts(
+ cls: Type[TimescaleVector],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ """
+ Return VectorStore initialized from texts and embeddings.
+        A Postgres connection string is required: either pass it as the
+        ``service_url`` parameter or set the TIMESCALE_SERVICE_URL
+        environment variable.
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return await cls.__afrom(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ """Construct TimescaleVector wrapper from raw documents and pre-
+ generated embeddings.
+
+ Return VectorStore initialized from documents and embeddings.
+        A Postgres connection string is required: either pass it as the
+        ``service_url`` parameter or set the TIMESCALE_SERVICE_URL
+        environment variable.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import TimescaleVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ tvs = TimescaleVector.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ async def afrom_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ """Construct TimescaleVector wrapper from raw documents and pre-
+ generated embeddings.
+
+ Return VectorStore initialized from documents and embeddings.
+        A Postgres connection string is required: either pass it as the
+        ``service_url`` parameter or set the TIMESCALE_SERVICE_URL
+        environment variable.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import TimescaleVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+                tvs = await TimescaleVector.afrom_embeddings(
+                    text_embedding_pairs, embeddings
+                )
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return await cls.__afrom(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[TimescaleVector],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> TimescaleVector:
+ """
+        Get an instance of an existing TimescaleVector store. This method
+        returns the store without inserting any new embeddings.
+ """
+
+ service_url = cls.get_service_url(kwargs)
+
+ store = cls(
+ service_url=service_url,
+ collection_name=collection_name,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ )
+
+ return store
+
+ @classmethod
+ def get_service_url(cls, kwargs: Dict[str, Any]) -> str:
+ service_url: str = get_from_dict_or_env(
+ data=kwargs,
+ key="service_url",
+ env_key="TIMESCALE_SERVICE_URL",
+ )
+
+ if not service_url:
+ raise ValueError(
+ "Postgres connection string is required"
+ "Either pass it as a parameter"
+ "or set the TIMESCALE_SERVICE_URL environment variable."
+ )
+
+ return service_url
+
+ @classmethod
+ def service_url_from_db_params(
+ cls,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return f"postgresql://{user}:{password}@{host}:{port}/{database}"
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided
+ # in vectorstore constructor
+ if self._distance_strategy == DistanceStrategy.COSINE:
+ return self._cosine_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+ return self._euclidean_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ else:
+ raise ValueError(
+ "No supported normalization function"
+ f" for distance_strategy of {self._distance_strategy}."
+ "Consider providing relevance_score_fn to TimescaleVector constructor."
+ )
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ self.sync_client.delete_by_ids(ids)
+ return True
+
+    # TODO: should this be part of delete()?
+ def delete_by_metadata(
+ self, filter: Union[Dict[str, str], List[Dict[str, str]]], **kwargs: Any
+ ) -> Optional[bool]:
+ """Delete by vector ID or other criteria.
+
+ Args:
+ ids: List of ids to delete.
+ **kwargs: Other keyword arguments that subclasses might use.
+
+ Returns:
+ Optional[bool]: True if deletion is successful,
+ False otherwise, None if not implemented.
+ """
+
+ self.sync_client.delete_by_metadata(filter)
+ return True
+
+ class IndexType(str, enum.Enum):
+ """Enumerator for the supported Index types"""
+
+ TIMESCALE_VECTOR = "tsv"
+ PGVECTOR_IVFFLAT = "ivfflat"
+ PGVECTOR_HNSW = "hnsw"
+
+ DEFAULT_INDEX_TYPE = IndexType.TIMESCALE_VECTOR
+
+ def create_index(
+ self, index_type: Union[IndexType, str] = DEFAULT_INDEX_TYPE, **kwargs: Any
+ ) -> None:
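+        """Create an approximate-nearest-neighbor index on the embedding column.
+
+        A minimal usage sketch (assumes ``store`` is an existing
+        ``TimescaleVector`` instance; extra kwargs are forwarded to the
+        underlying ``timescale_vector`` index class):
+
+        .. code-block:: python
+
+            store.create_index()  # default: the timescale-vector index type
+            store.create_index(index_type="hnsw")
+        """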
+ try:
+ from timescale_vector import client
+ except ImportError:
+ raise ImportError(
+ "Could not import timescale_vector python package. "
+ "Please install it with `pip install timescale-vector`."
+ )
+
+ index_type = (
+ index_type.value if isinstance(index_type, self.IndexType) else index_type
+ )
+ if index_type == self.IndexType.PGVECTOR_IVFFLAT.value:
+ self.sync_client.create_embedding_index(client.IvfflatIndex(**kwargs))
+
+ if index_type == self.IndexType.PGVECTOR_HNSW.value:
+ self.sync_client.create_embedding_index(client.HNSWIndex(**kwargs))
+
+ if index_type == self.IndexType.TIMESCALE_VECTOR.value:
+ self.sync_client.create_embedding_index(
+ client.TimescaleVectorIndex(**kwargs)
+ )
+
+ def drop_index(self) -> None:
+ self.sync_client.drop_embedding_index()
diff --git a/libs/community/langchain_community/vectorstores/typesense.py b/libs/community/langchain_community/vectorstores/typesense.py
new file mode 100644
index 00000000000..662a9185148
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/typesense.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.utils import get_from_env
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from typesense.client import Client
+ from typesense.collection import Collection
+
+
+class Typesense(VectorStore):
+ """`Typesense` vector store.
+
+ To use, you should have the ``typesense`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+            from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import Typesense
+ import typesense
+
+ node = {
+ "host": "localhost", # For Typesense Cloud use xxx.a1.typesense.net
+ "port": "8108", # For Typesense Cloud use 443
+ "protocol": "http" # For Typesense Cloud use https
+ }
+ typesense_client = typesense.Client(
+ {
+ "nodes": [node],
+ "api_key": "",
+ "connection_timeout_seconds": 2
+ }
+ )
+ typesense_collection_name = "langchain-memory"
+
+ embedding = OpenAIEmbeddings()
+ vectorstore = Typesense(
+ typesense_client=typesense_client,
+ embedding=embedding,
+ typesense_collection_name=typesense_collection_name,
+ text_key="text",
+ )
+ """
+
+ def __init__(
+ self,
+ typesense_client: Client,
+ embedding: Embeddings,
+ *,
+ typesense_collection_name: Optional[str] = None,
+ text_key: str = "text",
+ ):
+ """Initialize with Typesense client."""
+ try:
+ from typesense import Client
+ except ImportError:
+ raise ImportError(
+ "Could not import typesense python package. "
+ "Please install it with `pip install typesense`."
+ )
+ if not isinstance(typesense_client, Client):
+ raise ValueError(
+ f"typesense_client should be an instance of typesense.Client, "
+ f"got {type(typesense_client)}"
+ )
+ self._typesense_client = typesense_client
+ self._embedding = embedding
+ self._typesense_collection_name = (
+ typesense_collection_name or f"langchain-{str(uuid.uuid4())}"
+ )
+ self._text_key = text_key
+
+ @property
+ def _collection(self) -> Collection:
+ return self._typesense_client.collections[self._typesense_collection_name]
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def _prep_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]],
+ ids: Optional[List[str]],
+ ) -> List[dict]:
+ """Embed and create the documents"""
+ _ids = ids or (str(uuid.uuid4()) for _ in texts)
+ _metadatas: Iterable[dict] = metadatas or ({} for _ in texts)
+ embedded_texts = self._embedding.embed_documents(list(texts))
+ return [
+ {"id": _id, "vec": vec, f"{self._text_key}": text, "metadata": metadata}
+ for _id, vec, text, metadata in zip(_ids, embedded_texts, texts, _metadatas)
+ ]
+
+ def _create_collection(self, num_dim: int) -> None:
+ fields = [
+ {"name": "vec", "type": "float[]", "num_dim": num_dim},
+ {"name": f"{self._text_key}", "type": "string"},
+ {"name": ".*", "type": "auto"},
+ ]
+ self._typesense_client.collections.create(
+ {"name": self._typesense_collection_name, "fields": fields}
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embedding and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids to associate with the texts.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+
+ """
+ from typesense.exceptions import ObjectNotFound
+
+ docs = self._prep_texts(texts, metadatas, ids)
+ try:
+ self._collection.documents.import_(docs, {"action": "upsert"})
+ except ObjectNotFound:
+ # Create the collection if it doesn't already exist
+ self._create_collection(len(docs[0]["vec"]))
+ self._collection.documents.import_(docs, {"action": "upsert"})
+ return [doc["id"] for doc in docs]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 10,
+ filter: Optional[str] = "",
+ ) -> List[Tuple[Document, float]]:
+ """Return typesense documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 10.
+                Typesense returns a minimum of 10 results.
+ filter: typesense filter_by expression to filter documents on
+
+ Returns:
+ List of Documents most similar to the query and score for each
+ """
+ embedded_query = [str(x) for x in self._embedding.embed_query(query)]
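+        # Typesense expects the query vector inlined into the `vector_query`
+        # string of a multi_search request, e.g. "vec:([0.1,0.2,...], k:10)".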
+ query_obj = {
+ "q": "*",
+ "vector_query": f'vec:([{",".join(embedded_query)}], k:{k})',
+ "filter_by": filter,
+ "collection": self._typesense_collection_name,
+ }
+ docs = []
+ response = self._typesense_client.multi_search.perform(
+ {"searches": [query_obj]}, {}
+ )
+ for hit in response["results"][0]["hits"]:
+ document = hit["document"]
+ metadata = document["metadata"]
+ text = document[self._text_key]
+ score = hit["vector_distance"]
+ docs.append((Document(page_content=text, metadata=metadata), score))
+ return docs
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 10,
+ filter: Optional[str] = "",
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return typesense documents most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 10.
+                Typesense returns a minimum of 10 results.
+ filter: typesense filter_by expression to filter documents on
+
+ Returns:
+            List of Documents most similar to the query.
+ """
+ docs_and_score = self.similarity_search_with_score(query, k=k, filter=filter)
+ return [doc for doc, _ in docs_and_score]
+
+ @classmethod
+ def from_client_params(
+ cls,
+ embedding: Embeddings,
+ *,
+ host: str = "localhost",
+ port: Union[str, int] = "8108",
+ protocol: str = "http",
+ typesense_api_key: Optional[str] = None,
+ connection_timeout_seconds: int = 2,
+ **kwargs: Any,
+ ) -> Typesense:
+ """Initialize Typesense directly from client parameters.
+
+ Example:
+ .. code-block:: python
+
+            from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import Typesense
+
+ # Pass in typesense_api_key as kwarg or set env var "TYPESENSE_API_KEY".
+            vectorstore = Typesense.from_client_params(
+ OpenAIEmbeddings(),
+ host="localhost",
+ port="8108",
+ protocol="http",
+ typesense_collection_name="langchain-memory",
+ )
+ """
+ try:
+ from typesense import Client
+ except ImportError:
+            raise ImportError(
+ "Could not import typesense python package. "
+ "Please install it with `pip install typesense`."
+ )
+
+ node = {
+ "host": host,
+ "port": str(port),
+ "protocol": protocol,
+ }
+ typesense_api_key = typesense_api_key or get_from_env(
+ "typesense_api_key", "TYPESENSE_API_KEY"
+ )
+ client_config = {
+ "nodes": [node],
+ "api_key": typesense_api_key,
+ "connection_timeout_seconds": connection_timeout_seconds,
+ }
+ return cls(Client(client_config), embedding, **kwargs)
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ typesense_client: Optional[Client] = None,
+ typesense_client_params: Optional[dict] = None,
+ typesense_collection_name: Optional[str] = None,
+ text_key: str = "text",
+ **kwargs: Any,
+ ) -> Typesense:
+ """Construct Typesense wrapper from raw text."""
+ if typesense_client:
+ vectorstore = cls(typesense_client, embedding, **kwargs)
+ elif typesense_client_params:
+ vectorstore = cls.from_client_params(
+ embedding, **typesense_client_params, **kwargs
+ )
+ else:
+ raise ValueError(
+ "Must specify one of typesense_client or typesense_client_params."
+ )
+ vectorstore.add_texts(texts, metadatas=metadatas, ids=ids)
+ return vectorstore
diff --git a/libs/community/langchain_community/vectorstores/usearch.py b/libs/community/langchain_community/vectorstores/usearch.py
new file mode 100644
index 00000000000..6ca66e630c8
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/usearch.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.docstore.base import AddableMixin, Docstore
+from langchain_community.docstore.in_memory import InMemoryDocstore
+
+
+def dependable_usearch_import() -> Any:
+ """
+ Import usearch if available, otherwise raise error.
+ """
+ try:
+ import usearch.index
+ except ImportError:
+ raise ImportError(
+ "Could not import usearch python package. "
+ "Please install it with `pip install usearch` "
+ )
+ return usearch.index
+
+
+class USearch(VectorStore):
+ """`USearch` vector store.
+
+ To use, you should have the ``usearch`` python package installed.
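+
+    Example (a minimal sketch; any ``Embeddings`` implementation can be used):
+        .. code-block:: python
+
+            from langchain_community.embeddings import HuggingFaceEmbeddings
+            from langchain_community.vectorstores import USearch
+
+            embeddings = HuggingFaceEmbeddings()
+            db = USearch.from_texts(["foo", "bar", "baz"], embeddings)
+            docs_and_scores = db.similarity_search_with_score("foo", k=2)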
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ index: Any,
+ docstore: Docstore,
+ ids: List[str],
+ ):
+ """Initialize with necessary components."""
+ self.embedding = embedding
+ self.index = index
+ self.docstore = docstore
+ self.ids = ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict]] = None,
+ ids: Optional[np.ndarray] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of unique IDs.
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if not isinstance(self.docstore, AddableMixin):
+ raise ValueError(
+ "If trying to add texts, the underlying docstore should support "
+ f"adding items, which {self.docstore} does not"
+ )
+
+ embeddings = self.embedding.embed_documents(list(texts))
+ documents = []
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ documents.append(Document(page_content=text, metadata=metadata))
+        if ids is None:
+            last_id = int(self.ids[-1]) + 1
+            ids = np.array([str(last_id + id) for id, _ in enumerate(texts)])
+
+ self.index.add(np.array(ids), np.array(embeddings))
+ self.docstore.add(dict(zip(ids, documents)))
+ self.ids.extend(ids)
+ return ids.tolist()
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of documents most similar to the query with distance.
+ """
+ query_embedding = self.embedding.embed_query(query)
+ matches = self.index.search(np.array(query_embedding), k)
+
+ docs_with_scores: List[Tuple[Document, float]] = []
+ for id, score in zip(matches.keys, matches.distances):
+ doc = self.docstore.search(str(id))
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {id}, got {doc}")
+ docs_with_scores.append((doc, score))
+
+ return docs_with_scores
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ query_embedding = self.embedding.embed_query(query)
+ matches = self.index.search(np.array(query_embedding), k)
+
+ docs: List[Document] = []
+ for id in matches.keys:
+ doc = self.docstore.search(str(id))
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {id}, got {doc}")
+ docs.append(doc)
+
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[Dict]] = None,
+ ids: Optional[np.ndarray] = None,
+ metric: str = "cos",
+ **kwargs: Any,
+ ) -> USearch:
+ """Construct USearch wrapper from raw documents.
+        This is a user-friendly interface that:
+            1. Embeds documents.
+            2. Creates an in-memory docstore.
+            3. Initializes the USearch database.
+ This is intended to be a quick way to get started.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import USearch
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ usearch = USearch.from_texts(texts, embeddings)
+ """
+ embeddings = embedding.embed_documents(texts)
+
+ documents: List[Document] = []
+ if ids is None:
+ ids = np.array([str(id) for id, _ in enumerate(texts)])
+ for i, text in enumerate(texts):
+ metadata = metadatas[i] if metadatas else {}
+ documents.append(Document(page_content=text, metadata=metadata))
+
+ docstore = InMemoryDocstore(dict(zip(ids, documents)))
+ usearch = dependable_usearch_import()
+ index = usearch.Index(ndim=len(embeddings[0]), metric=metric)
+ index.add(np.array(ids), np.array(embeddings))
+ return cls(embedding, index, docstore, ids.tolist())
diff --git a/libs/community/langchain_community/vectorstores/utils.py b/libs/community/langchain_community/vectorstores/utils.py
new file mode 100644
index 00000000000..a21fcec2163
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/utils.py
@@ -0,0 +1,74 @@
+"""Utility functions for working with vectors and vectorstores."""
+
+from enum import Enum
+from typing import List, Tuple, Type
+
+import numpy as np
+from langchain_core.documents import Document
+
+from langchain_community.utils.math import cosine_similarity
+
+
+class DistanceStrategy(str, Enum):
+ """Enumerator of the Distance strategies for calculating distances
+ between vectors."""
+
+ EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
+ MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
+ DOT_PRODUCT = "DOT_PRODUCT"
+ JACCARD = "JACCARD"
+ COSINE = "COSINE"
+
+
+def maximal_marginal_relevance(
+ query_embedding: np.ndarray,
+ embedding_list: list,
+ lambda_mult: float = 0.5,
+ k: int = 4,
+) -> List[int]:
+ """Calculate maximal marginal relevance."""
+ if min(k, len(embedding_list)) <= 0:
+ return []
+ if query_embedding.ndim == 1:
+ query_embedding = np.expand_dims(query_embedding, axis=0)
+ similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
+ most_similar = int(np.argmax(similarity_to_query))
+ idxs = [most_similar]
+ selected = np.array([embedding_list[most_similar]])
+ while len(idxs) < min(k, len(embedding_list)):
+ best_score = -np.inf
+ idx_to_add = -1
+ similarity_to_selected = cosine_similarity(embedding_list, selected)
+ for i, query_score in enumerate(similarity_to_query):
+ if i in idxs:
+ continue
+ redundant_score = max(similarity_to_selected[i])
+ equation_score = (
+ lambda_mult * query_score - (1 - lambda_mult) * redundant_score
+ )
+ if equation_score > best_score:
+ best_score = equation_score
+ idx_to_add = i
+ idxs.append(idx_to_add)
+ selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
+ return idxs
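+
+# A minimal sketch of how the selection behaves (toy 2-d vectors, illustrative
+# only): the near-duplicate of the best match is passed over in favour of a
+# more diverse candidate.
+#
+# query = np.array([1.0, 0.5])
+# candidates = [[1.0, 0.0], [0.9, 0.05], [0.0, 1.0]]
+# maximal_marginal_relevance(query, candidates, lambda_mult=0.5, k=2)
+# # -> [1, 2]: index 0 is almost identical to the first pick (index 1), so the
+# #    orthogonal vector at index 2 is selected second despite its lower
+# #    similarity to the query.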
+
+
+def filter_complex_metadata(
+ documents: List[Document],
+ *,
+ allowed_types: Tuple[Type, ...] = (str, bool, int, float),
+) -> List[Document]:
+ """Filter out metadata types that are not supported for a vector store."""
+ updated_documents = []
+ for document in documents:
+ filtered_metadata = {}
+ for key, value in document.metadata.items():
+ if not isinstance(value, allowed_types):
+ continue
+ filtered_metadata[key] = value
+
+ document.metadata = filtered_metadata
+ updated_documents.append(document)
+
+ return updated_documents
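+
+
+# A minimal sketch (illustrative only): nested metadata values are dropped in
+# place, scalar values are kept.
+#
+# doc = Document(page_content="hello", metadata={"source": "a.txt", "tags": ["x", "y"]})
+# filter_complex_metadata([doc])[0].metadata
+# # -> {"source": "a.txt"}  (the list-valued "tags" entry is filtered out)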
diff --git a/libs/community/langchain_community/vectorstores/vald.py b/libs/community/langchain_community/vectorstores/vald.py
new file mode 100644
index 00000000000..6b6abae4e92
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/vald.py
@@ -0,0 +1,419 @@
+"""Wrapper around Vald vector database."""
+from __future__ import annotations
+
+from typing import Any, Iterable, List, Optional, Tuple, Type
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+
+class Vald(VectorStore):
+ """Wrapper around Vald vector database.
+
+ To use, you should have the ``vald-client-python`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Vald
+
+ texts = ['foo', 'bar', 'baz']
+ vald = Vald.from_texts(
+ texts=texts,
+ embedding=HuggingFaceEmbeddings(),
+ host="localhost",
+ port=8080,
+ skip_strict_exist_check=False,
+ )
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ host: str = "localhost",
+ port: int = 8080,
+ grpc_options: Tuple = (
+ ("grpc.keepalive_time_ms", 1000 * 10),
+ ("grpc.keepalive_timeout_ms", 1000 * 10),
+ ),
+ grpc_use_secure: bool = False,
+ grpc_credentials: Optional[Any] = None,
+ ):
+ self._embedding = embedding
+ self.target = host + ":" + str(port)
+ self.grpc_options = grpc_options
+ self.grpc_use_secure = grpc_use_secure
+ self.grpc_credentials = grpc_credentials
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding
+
+ def _get_channel(self) -> Any:
+ try:
+ import grpc
+ except ImportError:
+ raise ValueError(
+ "Could not import grpcio python package. "
+ "Please install it with `pip install grpcio`."
+ )
+ return (
+ grpc.secure_channel(
+ self.target, self.grpc_credentials, options=self.grpc_options
+ )
+ if self.grpc_use_secure
+ else grpc.insecure_channel(self.target, options=self.grpc_options)
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ grpc_metadata: Optional[Any] = None,
+ skip_strict_exist_check: bool = False,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Args:
+ skip_strict_exist_check: Deprecated. This parameter is essentially unused.
+ """
+ try:
+ from vald.v1.payload import payload_pb2
+ from vald.v1.vald import upsert_pb2_grpc
+ except ImportError:
+ raise ValueError(
+ "Could not import vald-client-python python package. "
+ "Please install it with `pip install vald-client-python`."
+ )
+
+ channel = self._get_channel()
+ # Depending on the network quality,
+ # it is necessary to wait for ChannelConnectivity.READY.
+ # _ = grpc.channel_ready_future(channel).result(timeout=10)
+ stub = upsert_pb2_grpc.UpsertStub(channel)
+ cfg = payload_pb2.Upsert.Config(skip_strict_exist_check=skip_strict_exist_check)
+
+ ids = []
+ embs = self._embedding.embed_documents(list(texts))
+ for text, emb in zip(texts, embs):
+ vec = payload_pb2.Object.Vector(id=text, vector=emb)
+ res = stub.Upsert(
+ payload_pb2.Upsert.Request(vector=vec, config=cfg),
+ metadata=grpc_metadata,
+ )
+ ids.append(res.uuid)
+
+ channel.close()
+ return ids
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ skip_strict_exist_check: bool = False,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """
+ Args:
+ skip_strict_exist_check: Deprecated. This parameter is essentially unused.
+ """
+ try:
+ from vald.v1.payload import payload_pb2
+ from vald.v1.vald import remove_pb2_grpc
+ except ImportError:
+ raise ValueError(
+ "Could not import vald-client-python python package. "
+ "Please install it with `pip install vald-client-python`."
+ )
+
+ if ids is None:
+ raise ValueError("No ids provided to delete")
+
+ channel = self._get_channel()
+ # Depending on the network quality,
+ # it is necessary to wait for ChannelConnectivity.READY.
+ # _ = grpc.channel_ready_future(channel).result(timeout=10)
+ stub = remove_pb2_grpc.RemoveStub(channel)
+ cfg = payload_pb2.Remove.Config(skip_strict_exist_check=skip_strict_exist_check)
+
+ for _id in ids:
+ oid = payload_pb2.Object.ID(id=_id)
+ _ = stub.Remove(
+ payload_pb2.Remove.Request(id=oid, config=cfg), metadata=grpc_metadata
+ )
+
+ channel.close()
+ return True
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ docs_and_scores = self.similarity_search_with_score(
+ query, k, radius, epsilon, timeout, grpc_metadata
+ )
+
+ docs = []
+ for doc, _ in docs_and_scores:
+ docs.append(doc)
+
+ return docs
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ emb = self._embedding.embed_query(query)
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ emb, k, radius, epsilon, timeout, grpc_metadata
+ )
+
+ return docs_and_scores
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding, k, radius, epsilon, timeout, grpc_metadata
+ )
+
+ docs = []
+ for doc, _ in docs_and_scores:
+ docs.append(doc)
+
+ return docs
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ try:
+ from vald.v1.payload import payload_pb2
+ from vald.v1.vald import search_pb2_grpc
+ except ImportError:
+ raise ValueError(
+ "Could not import vald-client-python python package. "
+ "Please install it with `pip install vald-client-python`."
+ )
+
+ channel = self._get_channel()
+ # Depending on the network quality,
+ # it is necessary to wait for ChannelConnectivity.READY.
+ # _ = grpc.channel_ready_future(channel).result(timeout=10)
+ stub = search_pb2_grpc.SearchStub(channel)
+ cfg = payload_pb2.Search.Config(
+ num=k, radius=radius, epsilon=epsilon, timeout=timeout
+ )
+
+ res = stub.Search(
+ payload_pb2.Search.Request(vector=embedding, config=cfg),
+ metadata=grpc_metadata,
+ )
+
+ docs_and_scores = []
+ for result in res.results:
+ docs_and_scores.append((Document(page_content=result.id), result.distance))
+
+ channel.close()
+ return docs_and_scores
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ emb = self._embedding.embed_query(query)
+ docs = self.max_marginal_relevance_search_by_vector(
+ emb,
+ k=k,
+ fetch_k=fetch_k,
+ radius=radius,
+ epsilon=epsilon,
+ timeout=timeout,
+ lambda_mult=lambda_mult,
+ grpc_metadata=grpc_metadata,
+ )
+
+ return docs
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ radius: float = -1.0,
+ epsilon: float = 0.01,
+ timeout: int = 3000000000,
+ grpc_metadata: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ try:
+ from vald.v1.payload import payload_pb2
+ from vald.v1.vald import object_pb2_grpc
+ except ImportError:
+ raise ValueError(
+ "Could not import vald-client-python python package. "
+ "Please install it with `pip install vald-client-python`."
+ )
+ channel = self._get_channel()
+ # Depending on the network quality,
+ # it is necessary to wait for ChannelConnectivity.READY.
+ # _ = grpc.channel_ready_future(channel).result(timeout=10)
+ stub = object_pb2_grpc.ObjectStub(channel)
+
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding,
+ fetch_k=fetch_k,
+ radius=radius,
+ epsilon=epsilon,
+ timeout=timeout,
+ grpc_metadata=grpc_metadata,
+ )
+
+ docs = []
+ embs = []
+ for doc, _ in docs_and_scores:
+ vec = stub.GetObject(
+ payload_pb2.Object.VectorRequest(
+ id=payload_pb2.Object.ID(id=doc.page_content)
+ ),
+ metadata=grpc_metadata,
+ )
+ embs.append(vec.vector)
+ docs.append(doc)
+
+ mmr = maximal_marginal_relevance(
+ np.array(embedding),
+ embs,
+ lambda_mult=lambda_mult,
+ k=k,
+ )
+
+ channel.close()
+ return [docs[i] for i in mmr]
+
+ @classmethod
+ def from_texts(
+ cls: Type[Vald],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ host: str = "localhost",
+ port: int = 8080,
+ grpc_options: Tuple = (
+ ("grpc.keepalive_time_ms", 1000 * 10),
+ ("grpc.keepalive_timeout_ms", 1000 * 10),
+ ),
+ grpc_use_secure: bool = False,
+ grpc_credentials: Optional[Any] = None,
+ grpc_metadata: Optional[Any] = None,
+ skip_strict_exist_check: bool = False,
+ **kwargs: Any,
+ ) -> Vald:
+ """
+ Args:
+ skip_strict_exist_check: Deprecated. This parameter is essentially unused.
+ """
+ vald = cls(
+ embedding=embedding,
+ host=host,
+ port=port,
+ grpc_options=grpc_options,
+ grpc_use_secure=grpc_use_secure,
+ grpc_credentials=grpc_credentials,
+ **kwargs,
+ )
+ vald.add_texts(
+ texts=texts,
+ metadatas=metadatas,
+ grpc_metadata=grpc_metadata,
+ skip_strict_exist_check=skip_strict_exist_check,
+ )
+ return vald
+
+
+"""We will support if there are any requests."""
+# async def aadd_texts(
+# self,
+# texts: Iterable[str],
+# metadatas: Optional[List[dict]] = None,
+# **kwargs: Any,
+# ) -> List[str]:
+# pass
+#
+# def _select_relevance_score_fn(self) -> Callable[[float], float]:
+# pass
+#
+# def _similarity_search_with_relevance_scores(
+# self,
+# query: str,
+# k: int = 4,
+# **kwargs: Any,
+# ) -> List[Tuple[Document, float]]:
+# pass
+#
+# def similarity_search_with_relevance_scores(
+# self,
+# query: str,
+# k: int = 4,
+# **kwargs: Any,
+# ) -> List[Tuple[Document, float]]:
+# pass
+#
+# async def amax_marginal_relevance_search_by_vector(
+# self,
+# embedding: List[float],
+# k: int = 4,
+# fetch_k: int = 20,
+# lambda_mult: float = 0.5,
+# **kwargs: Any,
+# ) -> List[Document]:
+# pass
+#
+# @classmethod
+# async def afrom_texts(
+# cls: Type[VST],
+# texts: List[str],
+# embedding: Embeddings,
+# metadatas: Optional[List[dict]] = None,
+# **kwargs: Any,
+# ) -> VST:
+# pass
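+
+
+# A minimal usage sketch (illustrative only; assumes a Vald deployment reachable
+# at localhost:8080, as in the class docstring):
+#
+# from langchain_community.embeddings import HuggingFaceEmbeddings
+# from langchain_community.vectorstores import Vald
+#
+# vald = Vald.from_texts(["foo", "bar", "baz"], HuggingFaceEmbeddings(),
+#                        host="localhost", port=8080)
+# docs = vald.similarity_search("foo", k=2)
+# diverse = vald.max_marginal_relevance_search("foo", k=2, fetch_k=3)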
diff --git a/libs/community/langchain_community/vectorstores/vearch.py b/libs/community/langchain_community/vectorstores/vearch.py
new file mode 100644
index 00000000000..5fac20f7b8a
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/vearch.py
@@ -0,0 +1,577 @@
+from __future__ import annotations
+
+import os
+import time
+import uuid
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ import vearch
+
+DEFAULT_TOPN = 4
+
+
+class Vearch(VectorStore):
+ _DEFAULT_TABLE_NAME = "langchain_vearch"
+ _DEFAULT_CLUSTER_DB_NAME = "cluster_client_db"
+ _DEFAULT_VERSION = 1
+
+ def __init__(
+ self,
+ embedding_function: Embeddings,
+ path_or_url: Optional[str] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ db_name: str = _DEFAULT_CLUSTER_DB_NAME,
+ flag: int = _DEFAULT_VERSION,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize vearch vector store
+ flag 1 for cluster,0 for standalone
+ """
+ try:
+ if flag:
+ import vearch_cluster
+ else:
+ import vearch
+ except ImportError:
+ raise ValueError(
+ "Could not import suitable python package. "
+ "Please install it with `pip install vearch or vearch_cluster`."
+ )
+
+ if flag:
+ if path_or_url is None:
+ raise ValueError("Please input url of cluster")
+ if not db_name:
+ db_name = self._DEFAULT_CLUSTER_DB_NAME
+ db_name += "_"
+ db_name += str(uuid.uuid4()).split("-")[-1]
+ self.using_db_name = db_name
+ self.url = path_or_url
+ self.vearch = vearch_cluster.VearchCluster(path_or_url)
+
+ else:
+ if path_or_url is None:
+ metadata_path = os.getcwd().replace("\\", "/")
+ else:
+ metadata_path = path_or_url
+ if not os.path.isdir(metadata_path):
+ os.makedirs(metadata_path)
+ log_path = os.path.join(metadata_path, "log")
+ if not os.path.isdir(log_path):
+ os.makedirs(log_path)
+ self.vearch = vearch.Engine(metadata_path, log_path)
+ self.using_metapath = metadata_path
+ if not table_name:
+ table_name = self._DEFAULT_TABLE_NAME
+ table_name += "_"
+ table_name += str(uuid.uuid4()).split("-")[-1]
+ self.using_table_name = table_name
+ self.embedding_func = embedding_function
+ self.flag = flag
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self.embedding_func
+
+ @classmethod
+ def from_documents(
+ cls: Type[Vearch],
+ documents: List[Document],
+ embedding: Embeddings,
+ path_or_url: Optional[str] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ db_name: str = _DEFAULT_CLUSTER_DB_NAME,
+ flag: int = _DEFAULT_VERSION,
+ **kwargs: Any,
+ ) -> Vearch:
+ """Return Vearch VectorStore"""
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+
+ return cls.from_texts(
+ texts=texts,
+ embedding=embedding,
+ metadatas=metadatas,
+ path_or_url=path_or_url,
+ table_name=table_name,
+ db_name=db_name,
+ flag=flag,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_texts(
+ cls: Type[Vearch],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ path_or_url: Optional[str] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ db_name: str = _DEFAULT_CLUSTER_DB_NAME,
+ flag: int = _DEFAULT_VERSION,
+ **kwargs: Any,
+ ) -> Vearch:
+ """Return Vearch VectorStore"""
+
+ vearch_db = cls(
+ embedding_function=embedding,
+ embedding=embedding,
+ path_or_url=path_or_url,
+ db_name=db_name,
+ table_name=table_name,
+ flag=flag,
+ )
+ vearch_db.add_texts(texts=texts, metadatas=metadatas)
+ return vearch_db
+
+ def _create_table(
+ self,
+ dim: int = 1024,
+ field_list: List[dict] = [
+ {"field": "text", "type": "str"},
+ {"field": "metadata", "type": "str"},
+ ],
+ ) -> int:
+ """
+ Create a VectorStore table.
+ Args:
+ dim: dimension of the vectors.
+ field_list: the fields you want to store.
+ Return:
+ code: 0 for success, 1 for failure.
+ """
+
+ type_dict = {"int": vearch.dataType.INT, "str": vearch.dataType.STRING}
+ engine_info = {
+ "index_size": 10000,
+ "retrieval_type": "IVFPQ",
+ "retrieval_param": {"ncentroids": 2048, "nsubvector": 32},
+ }
+ fields = [
+ vearch.GammaFieldInfo(fi["field"], type_dict[fi["type"]])
+ for fi in field_list
+ ]
+ vector_field = vearch.GammaVectorInfo(
+ name="text_embedding",
+ type=vearch.dataType.VECTOR,
+ is_index=True,
+ dimension=dim,
+ model_id="",
+ store_type="MemoryOnly",
+ store_param={"cache_size": 10000},
+ has_source=False,
+ )
+ response_code = self.vearch.create_table(
+ engine_info,
+ name=self.using_table_name,
+ fields=fields,
+ vector_field=vector_field,
+ )
+ return response_code
+
+ def _create_space(
+ self,
+ dim: int = 1024,
+ ) -> int:
+ """
+ Create a VectorStore space.
+ Args:
+ dim: dimension of the vectors.
+ Return:
+ code: 0 for failure, 1 for success.
+ """
+ space_config = {
+ "name": self.using_table_name,
+ "partition_num": 1,
+ "replica_num": 1,
+ "engine": {
+ "name": "gamma",
+ "index_size": 1,
+ "retrieval_type": "FLAT",
+ "retrieval_param": {
+ "metric_type": "L2",
+ },
+ },
+ "properties": {
+ "text": {
+ "type": "string",
+ },
+ "metadata": {
+ "type": "string",
+ },
+ "text_embedding": {
+ "type": "vector",
+ "index": True,
+ "dimension": dim,
+ "store_type": "MemoryOnly",
+ },
+ },
+ }
+ response_code = self.vearch.create_space(self.using_db_name, space_config)
+
+ return response_code
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ embeddings = None
+ if self.embedding_func is not None:
+ embeddings = self.embedding_func.embed_documents(list(texts))
+ if embeddings is None:
+ raise ValueError("embeddings is None")
+ if self.flag:
+ dbs_list = self.vearch.list_dbs()
+ if self.using_db_name not in dbs_list:
+ create_db_code = self.vearch.create_db(self.using_db_name)
+ if not create_db_code:
+ raise ValueError("create db failed!!!")
+ space_list = self.vearch.list_spaces(self.using_db_name)
+ if self.using_table_name not in space_list:
+ create_space_code = self._create_space(len(embeddings[0]))
+ if not create_space_code:
+ raise ValueError("create space failed!!!")
+ docid = []
+ if embeddings is not None and metadatas is not None:
+ for text, metadata, embed in zip(texts, metadatas, embeddings):
+ profiles: dict[str, Any] = {}
+ profiles["text"] = text
+ profiles["metadata"] = metadata["source"]
+ embed_np = np.array(embed)
+ profiles["text_embedding"] = {
+ "feature": (embed_np / np.linalg.norm(embed_np)).tolist()
+ }
+ insert_res = self.vearch.insert_one(
+ self.using_db_name, self.using_table_name, profiles
+ )
+ if insert_res["status"] == 200:
+ docid.append(insert_res["_id"])
+ continue
+ else:
+ retry_insert = self.vearch.insert_one(
+ self.using_db_name, self.using_table_name, profiles
+ )
+ docid.append(retry_insert["_id"])
+ continue
+ else:
+ table_path = os.path.join(
+ self.using_metapath, self.using_table_name + ".schema"
+ )
+ if not os.path.exists(table_path):
+ dim = len(embeddings[0])
+ response_code = self._create_table(dim)
+ if response_code:
+ raise ValueError("create table failed!!!")
+ if embeddings is not None and metadatas is not None:
+ doc_items = []
+ for text, metadata, embed in zip(texts, metadatas, embeddings):
+ profiles_v: dict[str, Any] = {}
+ profiles_v["text"] = text
+ profiles_v["metadata"] = metadata["source"]
+ embed_np = np.array(embed)
+ profiles_v["text_embedding"] = embed_np / np.linalg.norm(embed_np)
+ doc_items.append(profiles_v)
+
+ docid = self.vearch.add(doc_items)
+ t_time = 0
+ while len(docid) != len(embeddings):
+ time.sleep(0.5)
+ if t_time > 6:
+ break
+ t_time += 1
+ self.vearch.dump()
+ return docid
+
+ def _load(self) -> None:
+ """
+ load vearch engine for standalone vearch
+ """
+ self.vearch.load()
+
+ @classmethod
+ def load_local(
+ cls,
+ embedding: Embeddings,
+ path_or_url: Optional[str] = None,
+ table_name: str = _DEFAULT_TABLE_NAME,
+ db_name: str = _DEFAULT_CLUSTER_DB_NAME,
+ flag: int = _DEFAULT_VERSION,
+ **kwargs: Any,
+ ) -> Vearch:
+ """Load the local specified table of standalone vearch.
+ Returns:
+ Success or failure of loading the local specified table
+ """
+ if not path_or_url:
+ raise ValueError("No metadata path!!!")
+ if not table_name:
+ raise ValueError("No table name!!!")
+ table_path = os.path.join(path_or_url, table_name + ".schema")
+ if not os.path.exists(table_path):
+ raise ValueError("vearch vectorbase table not exist!!!")
+
+ vearch_db = cls(
+ embedding_function=embedding,
+ path_or_url=path_or_url,
+ table_name=table_name,
+ db_name=db_name,
+ flag=flag,
+ )
+ vearch_db._load()
+ return vearch_db
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = DEFAULT_TOPN,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """
+ Return docs most similar to query.
+
+ """
+ if self.embedding_func is None:
+ raise ValueError("embedding_func is None!!!")
+ embeddings = self.embedding_func.embed_query(query)
+ docs = self.similarity_search_by_vector(embeddings, k)
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = DEFAULT_TOPN,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """The most k similar documents and scores of the specified query.
+ Args:
+ embeddings: embedding vector of the query.
+ k: The k most similar documents to the text query.
+ min_score: the score of similar documents to the text query
+ Returns:
+ The k most similar documents to the specified text query.
+ 0 is dissimilar, 1 is the most similar.
+ """
+ embed = np.array(embedding)
+ if self.flag:
+ query_data = {
+ "query": {
+ "sum": [
+ {
+ "field": "text_embedding",
+ "feature": (embed / np.linalg.norm(embed)).tolist(),
+ }
+ ],
+ },
+ "size": k,
+ "fields": ["text", "metadata"],
+ }
+ query_result = self.vearch.search(
+ self.using_db_name, self.using_table_name, query_data
+ )
+ res = query_result["hits"]["hits"]
+ else:
+ query_data = {
+ "vector": [
+ {
+ "field": "text_embedding",
+ "feature": embed / np.linalg.norm(embed),
+ }
+ ],
+ "fields": [],
+ "is_brute_search": 1,
+ "retrieval_param": {"metric_type": "InnerProduct", "nprobe": 20},
+ "topn": k,
+ }
+ query_result = self.vearch.search(query_data)
+ res = query_result[0]["result_items"]
+ docs = []
+ for item in res:
+ content = ""
+ meta_data = {}
+ if self.flag:
+ item = item["_source"]
+ for item_key in item:
+ if item_key == "text":
+ content = item[item_key]
+ continue
+ if item_key == "metadata":
+ meta_data["source"] = item[item_key]
+ continue
+ docs.append(Document(page_content=content, metadata=meta_data))
+ return docs
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = DEFAULT_TOPN,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """The most k similar documents and scores of the specified query.
+ Args:
+ embeddings: embedding vector of the query.
+ k: The k most similar documents to the text query.
+ min_score: the score of similar documents to the text query
+ Returns:
+ The k most similar documents to the specified text query.
+ 0 is dissimilar, 1 is the most similar.
+ """
+ if self.embedding_func is None:
+ raise ValueError("embedding_func is None!!!")
+ embeddings = self.embedding_func.embed_query(query)
+ embed = np.array(embeddings)
+ if self.flag:
+ query_data = {
+ "query": {
+ "sum": [
+ {
+ "field": "text_embedding",
+ "feature": (embed / np.linalg.norm(embed)).tolist(),
+ }
+ ],
+ },
+ "size": k,
+ "fields": ["text_embedding", "text", "metadata"],
+ }
+ query_result = self.vearch.search(
+ self.using_db_name, self.using_table_name, query_data
+ )
+ res = query_result["hits"]["hits"]
+ else:
+ query_data = {
+ "vector": [
+ {
+ "field": "text_embedding",
+ "feature": embed / np.linalg.norm(embed),
+ }
+ ],
+ "fields": [],
+ "is_brute_search": 1,
+ "retrieval_param": {"metric_type": "InnerProduct", "nprobe": 20},
+ "topn": k,
+ }
+ query_result = self.vearch.search(query_data)
+ res = query_result[0]["result_items"]
+ results: List[Tuple[Document, float]] = []
+ for item in res:
+ content = ""
+ meta_data = {}
+ if self.flag:
+ score = item["_score"]
+ item = item["_source"]
+ for item_key in item:
+ if item_key == "text":
+ content = item[item_key]
+ continue
+ if item_key == "metadata":
+ meta_data["source"] = item[item_key]
+ continue
+ if self.flag != 1 and item_key == "score":
+ score = item[item_key]
+ continue
+ tmp_res = (Document(page_content=content, metadata=meta_data), score)
+ results.append(tmp_res)
+ return results
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ return self.similarity_search_with_score(query, k, **kwargs)
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Optional[bool]:
+ """Delete the documents which have the specified ids.
+
+ Args:
+ ids: The ids of the embedding vectors.
+ **kwargs: Other keyword arguments that subclasses might use.
+ Returns:
+ Optional[bool]: True if deletion is successful.
+ False otherwise, None if not implemented.
+ """
+
+ ret: Optional[bool] = None
+ tmp_res = []
+ if ids is None or len(ids) == 0:
+ return ret
+ for _id in ids:
+ if self.flag:
+ ret = self.vearch.delete(self.using_db_name, self.using_table_name, _id)
+ else:
+ ret = self.vearch.del_doc(_id)
+ tmp_res.append(ret)
+ ret = all(i == 0 for i in tmp_res)
+ return ret
+
+ def get(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Document]:
+ """Return docs according ids.
+
+ Args:
+ ids: The ids of the embedding vectors.
+ Returns:
+ Documents which satisfy the input conditions.
+ """
+
+ results: Dict[str, Document] = {}
+ if ids is None or len(ids) == 0:
+ return results
+ if self.flag:
+ query_data = {"query": {"ids": ids}}
+ docs_detail = self.vearch.mget_by_ids(
+ self.using_db_name, self.using_table_name, query_data
+ )
+ for record in docs_detail:
+ if record["found"] is False:
+ continue
+ content = ""
+ meta_info = {}
+ for field in record["_source"]:
+ if field == "text":
+ content = record["_source"][field]
+ continue
+ elif field == "metadata":
+ meta_info["source"] = record["_source"][field]
+ continue
+ results[record["_id"]] = Document(
+ page_content=content, metadata=meta_info
+ )
+ else:
+ for id in ids:
+ docs_detail = self.vearch.get_doc_by_id(id)
+ if docs_detail == {}:
+ continue
+ content = ""
+ meta_info = {}
+ for field in docs_detail:
+ if field == "text":
+ content = docs_detail[field]
+ continue
+ elif field == "metadata":
+ meta_info["source"] = docs_detail[field]
+ continue
+ results[docs_detail["_id"]] = Document(
+ page_content=content, metadata=meta_info
+ )
+ return results
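+
+
+# A minimal usage sketch (illustrative only): ``flag=0`` selects the standalone
+# engine and persists the index under ``path_or_url``, while ``flag=1`` targets a
+# cluster router URL. ``docs`` and ``embeddings`` are assumed placeholders; note
+# that add_texts expects every document to carry a "source" metadata key.
+#
+# standalone = Vearch.from_documents(
+#     docs, embeddings, path_or_url="/tmp/vearch_demo", table_name="demo", flag=0
+# )
+# cluster = Vearch.from_documents(
+#     docs, embeddings, path_or_url="http://vearch-router:80", db_name="demo_db", flag=1
+# )
+# results = standalone.similarity_search("query", k=4)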
diff --git a/libs/community/langchain_community/vectorstores/vectara.py b/libs/community/langchain_community/vectorstores/vectara.py
new file mode 100644
index 00000000000..5c5ae249d32
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/vectara.py
@@ -0,0 +1,481 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+from hashlib import md5
+from typing import Any, Iterable, List, Optional, Tuple, Type
+
+import requests
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import Field
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class Vectara(VectorStore):
+ """`Vectara API` vector store.
+
+ See (https://vectara.com).
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Vectara
+
+ vectorstore = Vectara(
+ vectara_customer_id=vectara_customer_id,
+ vectara_corpus_id=vectara_corpus_id,
+ vectara_api_key=vectara_api_key
+ )
+ """
+
+ def __init__(
+ self,
+ vectara_customer_id: Optional[str] = None,
+ vectara_corpus_id: Optional[str] = None,
+ vectara_api_key: Optional[str] = None,
+ vectara_api_timeout: int = 120,
+ source: str = "langchain",
+ ):
+ """Initialize with Vectara API."""
+ self._vectara_customer_id = vectara_customer_id or os.environ.get(
+ "VECTARA_CUSTOMER_ID"
+ )
+ self._vectara_corpus_id = vectara_corpus_id or os.environ.get(
+ "VECTARA_CORPUS_ID"
+ )
+ self._vectara_api_key = vectara_api_key or os.environ.get("VECTARA_API_KEY")
+ if (
+ self._vectara_customer_id is None
+ or self._vectara_corpus_id is None
+ or self._vectara_api_key is None
+ ):
+ logger.warning(
+ "Can't find Vectara credentials, customer_id or corpus_id in "
+ "environment."
+ )
+ else:
+ logger.debug(f"Using corpus id {self._vectara_corpus_id}")
+ self._source = source
+
+ self._session = requests.Session() # to reuse connections
+ adapter = requests.adapters.HTTPAdapter(max_retries=3)
+ self._session.mount("http://", adapter)
+ self.vectara_api_timeout = vectara_api_timeout
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return None
+
+ def _get_post_headers(self) -> dict:
+ """Returns headers that should be attached to each post request."""
+ return {
+ "x-api-key": self._vectara_api_key,
+ "customer-id": self._vectara_customer_id,
+ "Content-Type": "application/json",
+ "X-Source": self._source,
+ }
+
+ def _delete_doc(self, doc_id: str) -> bool:
+ """
+ Delete a document from the Vectara corpus.
+
+ Args:
+ doc_id (str): ID of the document to delete.
+
+ Returns:
+ bool: True if deletion was successful, False otherwise.
+ """
+ body = {
+ "customer_id": self._vectara_customer_id,
+ "corpus_id": self._vectara_corpus_id,
+ "document_id": doc_id,
+ }
+ response = self._session.post(
+ "https://api.vectara.io/v1/delete-doc",
+ data=json.dumps(body),
+ verify=True,
+ headers=self._get_post_headers(),
+ timeout=self.vectara_api_timeout,
+ )
+ if response.status_code != 200:
+ logger.error(
+ f"Delete request failed for doc_id = {doc_id} with status code "
+ f"{response.status_code}, reason {response.reason}, text "
+ f"{response.text}"
+ )
+ return False
+ return True
+
+ def _index_doc(self, doc: dict) -> str:
+ request: dict[str, Any] = {}
+ request["customer_id"] = self._vectara_customer_id
+ request["corpus_id"] = self._vectara_corpus_id
+ request["document"] = doc
+
+ response = self._session.post(
+ headers=self._get_post_headers(),
+ url="https://api.vectara.io/v1/index",
+ data=json.dumps(request),
+ timeout=self.vectara_api_timeout,
+ verify=True,
+ )
+
+ status_code = response.status_code
+
+ result = response.json()
+ status_str = result["status"]["code"] if "status" in result else None
+ if status_code == 409 or status_str and (status_str == "ALREADY_EXISTS"):
+ return "E_ALREADY_EXISTS"
+ elif status_str and (status_str == "FORBIDDEN"):
+ return "E_NO_PERMISSIONS"
+ else:
+ return "E_SUCCEEDED"
+
+ def add_files(
+ self,
+ files_list: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Vectara provides a way to add documents directly via its API, where
+ pre-processing and chunking occur internally in an optimal way.
+ This method provides a way to use that API in LangChain.
+
+ Args:
+ files_list: Iterable of strings, each representing a local file path.
+ Files could be text, HTML, PDF, markdown, doc/docx, ppt/pptx, etc.
+ see API docs for full list
+ metadatas: Optional list of metadatas associated with each file
+
+ Returns:
+ List of ids associated with each of the files indexed
+ """
+ doc_ids = []
+ for inx, file in enumerate(files_list):
+ if not os.path.exists(file):
+ logger.error(f"File {file} does not exist, skipping")
+ continue
+ md = metadatas[inx] if metadatas else {}
+ files: dict = {
+ "file": (file, open(file, "rb")),
+ "doc_metadata": json.dumps(md),
+ }
+ headers = self._get_post_headers()
+ headers.pop("Content-Type")
+ response = self._session.post(
+ f"https://api.vectara.io/upload?c={self._vectara_customer_id}&o={self._vectara_corpus_id}&d=True",
+ files=files,
+ verify=True,
+ headers=headers,
+ timeout=self.vectara_api_timeout,
+ )
+
+ if response.status_code == 409:
+ doc_id = response.json()["document"]["documentId"]
+ logger.info(
+ f"File {file} already exists on Vectara (doc_id={doc_id}), skipping"
+ )
+ elif response.status_code == 200:
+ doc_id = response.json()["document"]["documentId"]
+ doc_ids.append(doc_id)
+ else:
+ logger.info(f"Error indexing file {file}: {response.json()}")
+
+ return doc_ids
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ doc_metadata: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ doc_metadata: optional metadata for the document
+
+ This function indexes all the input text strings in the Vectara corpus as a
+ single Vectara document, where each input text is considered a "section" and
+ its metadata is associated with that section.
+ If 'doc_metadata' is provided, it is associated with the Vectara document.
+
+ Returns:
+ A single-element list containing the ID of the document added.
+
+ """
+ doc_hash = md5()
+ for t in texts:
+ doc_hash.update(t.encode())
+ doc_id = doc_hash.hexdigest()
+ if metadatas is None:
+ metadatas = [{} for _ in texts]
+ if doc_metadata:
+ doc_metadata["source"] = "langchain"
+ else:
+ doc_metadata = {"source": "langchain"}
+ doc = {
+ "document_id": doc_id,
+ "metadataJson": json.dumps(doc_metadata),
+ "section": [
+ {"text": text, "metadataJson": json.dumps(md)}
+ for text, md in zip(texts, metadatas)
+ ],
+ }
+
+ success_str = self._index_doc(doc)
+ if success_str == "E_ALREADY_EXISTS":
+ self._delete_doc(doc_id)
+ self._index_doc(doc)
+ elif success_str == "E_NO_PERMISSIONS":
+ print(
+ """No permissions to add document to Vectara.
+ Check your corpus ID, customer ID and API key"""
+ )
+ return [doc_id]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 5,
+ lambda_val: float = 0.025,
+ filter: Optional[str] = None,
+ score_threshold: Optional[float] = None,
+ n_sentence_context: int = 2,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return Vectara documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 5.
+ lambda_val: lexical match parameter for hybrid search.
+ filter: A filter string to apply on metadata. For example, a
+ filter can be "doc.rating > 3.0 and part.lang = 'deu'"; see
+ https://docs.vectara.com/docs/search-apis/sql/filter-overview
+ for more details.
+ score_threshold: minimal score threshold for the result.
+ If defined, results with score less than this value will be
+ filtered out.
+ n_sentence_context: number of sentences before/after the matching segment
+ to add, defaults to 2
+
+ Returns:
+ List of Documents most similar to the query and score for each.
+ """
+ data = json.dumps(
+ {
+ "query": [
+ {
+ "query": query,
+ "start": 0,
+ "num_results": k,
+ "context_config": {
+ "sentences_before": n_sentence_context,
+ "sentences_after": n_sentence_context,
+ },
+ "corpus_key": [
+ {
+ "customer_id": self._vectara_customer_id,
+ "corpus_id": self._vectara_corpus_id,
+ "metadataFilter": filter,
+ "lexical_interpolation_config": {"lambda": lambda_val},
+ }
+ ],
+ }
+ ]
+ }
+ )
+
+ response = self._session.post(
+ headers=self._get_post_headers(),
+ url="https://api.vectara.io/v1/query",
+ data=data,
+ timeout=self.vectara_api_timeout,
+ )
+
+ if response.status_code != 200:
+ logger.error(
+ "Query failed %s",
+ f"(code {response.status_code}, reason {response.reason}, details "
+ f"{response.text})",
+ )
+ return []
+
+ result = response.json()
+ if score_threshold:
+ responses = [
+ r
+ for r in result["responseSet"][0]["response"]
+ if r["score"] > score_threshold
+ ]
+ else:
+ responses = result["responseSet"][0]["response"]
+ documents = result["responseSet"][0]["document"]
+
+ metadatas = []
+ for x in responses:
+ md = {m["name"]: m["value"] for m in x["metadata"]}
+ doc_num = x["documentIndex"]
+ doc_md = {m["name"]: m["value"] for m in documents[doc_num]["metadata"]}
+ md.update(doc_md)
+ metadatas.append(md)
+
+ docs_with_score = [
+ (
+ Document(
+ page_content=x["text"],
+ metadata=md,
+ ),
+ x["score"],
+ )
+ for x, md in zip(responses, metadatas)
+ ]
+
+ return docs_with_score
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 5,
+ lambda_val: float = 0.025,
+ filter: Optional[str] = None,
+ n_sentence_context: int = 2,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return Vectara documents most similar to query, along with scores.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 5.
+ filter: A filter string to apply on metadata. For example, a
+ filter can be "doc.rating > 3.0 and part.lang = 'deu'"; see
+ https://docs.vectara.com/docs/search-apis/sql/filter-overview for more
+ details.
+ n_sentence_context: number of sentences before/after the matching segment
+ to add, defaults to 2
+
+ Returns:
+ List of Documents most similar to the query
+ """
+ docs_and_scores = self.similarity_search_with_score(
+ query,
+ k=k,
+ lambda_val=lambda_val,
+ filter=filter,
+ score_threshold=None,
+ n_sentence_context=n_sentence_context,
+ **kwargs,
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ @classmethod
+ def from_texts(
+ cls: Type[Vectara],
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> Vectara:
+ """Construct Vectara wrapper from raw documents.
+ This is intended to be a quick way to get started.
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Vectara
+ vectara = Vectara.from_texts(
+ texts,
+ vectara_customer_id=customer_id,
+ vectara_corpus_id=corpus_id,
+ vectara_api_key=api_key,
+ )
+ """
+ # Notes:
+ # * Vectara generates its own embeddings, so we ignore the provided
+ # embeddings (required by interface)
+ # * when metadatas[] are provided they are associated with each "part"
+ # in Vectara. doc_metadata can be used to provide additional metadata
+ # for the document itself (applies to all "texts" in this call)
+ doc_metadata = kwargs.pop("doc_metadata", {})
+ vectara = cls(**kwargs)
+ vectara.add_texts(texts, metadatas, doc_metadata=doc_metadata, **kwargs)
+ return vectara
+
+ @classmethod
+ def from_files(
+ cls: Type[Vectara],
+ files: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> Vectara:
+ """Construct Vectara wrapper from raw documents.
+ This is intended to be a quick way to get started.
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Vectara
+ vectara = Vectara.from_files(
+ files_list,
+ vectara_customer_id=customer_id,
+ vectara_corpus_id=corpus_id,
+ vectara_api_key=api_key,
+ )
+ """
+ # Note: Vectara generates its own embeddings, so we ignore the provided
+ # embeddings (required by interface)
+ vectara = cls(**kwargs)
+ vectara.add_files(files, metadatas)
+ return vectara
+
+ def as_retriever(self, **kwargs: Any) -> VectaraRetriever:
+ tags = kwargs.pop("tags", None) or []
+ tags.extend(self._get_retriever_tags())
+ return VectaraRetriever(vectorstore=self, search_kwargs=kwargs, tags=tags)
+
+
+class VectaraRetriever(VectorStoreRetriever):
+ """Retriever class for `Vectara`."""
+
+ vectorstore: Vectara
+ """Vectara vectorstore."""
+ search_kwargs: dict = Field(
+ default_factory=lambda: {
+ "lambda_val": 0.0,
+ "k": 5,
+ "filter": "",
+ "n_sentence_context": "2",
+ }
+ )
+
+ """Search params.
+ k: Number of Documents to return. Defaults to 5.
+ lambda_val: lexical match parameter for hybrid search.
+ filter: A filter string to apply on metadata. For example, a
+ filter can be "doc.rating > 3.0 and part.lang = 'deu'"; see
+ https://docs.vectara.com/docs/search-apis/sql/filter-overview
+ for more details.
+ n_sentence_context: number of sentences before/after the matching segment to add
+ """
+
+ def add_texts(
+ self,
+ texts: List[str],
+ metadatas: Optional[List[dict]] = None,
+ doc_metadata: Optional[dict] = None,
+ ) -> None:
+ """Add text to the Vectara vectorstore.
+
+ Args:
+ texts (List[str]): The text
+ metadatas (List[dict]): Metadata dicts, must line up with existing store
+ """
+ self.vectorstore.add_texts(texts, metadatas, doc_metadata or {})
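+
+
+# A minimal usage sketch (illustrative only; assumes VECTARA_CUSTOMER_ID,
+# VECTARA_CORPUS_ID and VECTARA_API_KEY are set in the environment so the
+# client can authenticate):
+#
+# vectara = Vectara()
+# vectara.add_texts(["hello world", "goodbye world"], doc_metadata={"genre": "demo"})
+# docs_and_scores = vectara.similarity_search_with_score("hello", k=2)
+# retriever = vectara.as_retriever(k=2, lambda_val=0.025)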
diff --git a/libs/community/langchain_community/vectorstores/vespa.py b/libs/community/langchain_community/vectorstores/vespa.py
new file mode 100644
index 00000000000..6fe6585cd86
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/vespa.py
@@ -0,0 +1,267 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+
+
+class VespaStore(VectorStore):
+ """
+ `Vespa` vector store.
+
+ To use, you should have the python client library ``pyvespa`` installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import VespaStore
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ from vespa.application import Vespa
+
+ # Create a vespa client dependent upon your application,
+ # e.g. either connecting to Vespa Cloud or a local deployment
+ # such as Docker. Please refer to the PyVespa documentation on
+ # how to initialize the client.
+
+ vespa_app = Vespa(url="...", port=..., application_package=...)
+
+ # You need to instruct LangChain on which fields to use for embeddings
+ vespa_config = dict(
+ page_content_field="text",
+ embedding_field="embedding",
+ input_field="query_embedding",
+ metadata_fields=["date", "rating", "author"]
+ )
+
+ embedding_function = OpenAIEmbeddings()
+ vectorstore = VespaStore(vespa_app, embedding_function, **vespa_config)
+
+ """
+
+ def __init__(
+ self,
+ app: Any,
+ embedding_function: Optional[Embeddings] = None,
+ page_content_field: Optional[str] = None,
+ embedding_field: Optional[str] = None,
+ input_field: Optional[str] = None,
+ metadata_fields: Optional[List[str]] = None,
+ ) -> None:
+ """
+ Initialize with a PyVespa client.
+ """
+ try:
+ from vespa.application import Vespa
+ except ImportError:
+ raise ImportError(
+ "Could not import Vespa python package. "
+ "Please install it with `pip install pyvespa`."
+ )
+ if not isinstance(app, Vespa):
+ raise ValueError(
+ f"app should be an instance of vespa.application.Vespa, got {type(app)}"
+ )
+
+ self._vespa_app = app
+ self._embedding_function = embedding_function
+ self._page_content_field = page_content_field
+ self._embedding_field = embedding_field
+ self._input_field = input_field
+ self._metadata_fields = metadata_fields
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """
+ Add texts to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ ids: Optional list of ids associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+
+ embeddings = None
+ if self._embedding_function is not None:
+ embeddings = self._embedding_function.embed_documents(list(texts))
+
+ if ids is None:
+ ids = [str(f"{i+1}") for i, _ in enumerate(texts)]
+
+ batch = []
+ for i, text in enumerate(texts):
+ fields: Dict[str, Union[str, List[float]]] = {}
+ if self._page_content_field is not None:
+ fields[self._page_content_field] = text
+ if self._embedding_field is not None and embeddings is not None:
+ fields[self._embedding_field] = embeddings[i]
+ if metadatas is not None and self._metadata_fields is not None:
+ for metadata_field in self._metadata_fields:
+ if metadata_field in metadatas[i]:
+ fields[metadata_field] = metadatas[i][metadata_field]
+ batch.append({"id": ids[i], "fields": fields})
+
+ results = self._vespa_app.feed_batch(batch)
+ for result in results:
+ if not (str(result.status_code).startswith("2")):
+ raise RuntimeError(
+ f"Could not add document to Vespa. "
+ f"Error code: {result.status_code}. "
+ f"Message: {result.json['message']}"
+ )
+ return ids
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ if ids is None:
+ return False
+ batch = [{"id": id} for id in ids]
+ result = self._vespa_app.delete_batch(batch)
+ return sum([0 if r.status_code == 200 else 1 for r in result]) == 0
+
+ def _create_query(
+ self, query_embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> Dict:
+ hits = k
+ doc_embedding_field = self._embedding_field
+ input_embedding_field = self._input_field
+ ranking_function = kwargs["ranking"] if "ranking" in kwargs else "default"
+ filter = kwargs["filter"] if "filter" in kwargs else None
+
+ approximate = kwargs["approximate"] if "approximate" in kwargs else False
+ approximate = "true" if approximate else "false"
+
+ yql = "select * from sources * where "
+ yql += f"{{targetHits: {hits}, approximate: {approximate}}}"
+ yql += f"nearestNeighbor({doc_embedding_field}, {input_embedding_field})"
+ if filter is not None:
+ yql += f" and {filter}"
+
+ query = {
+ "yql": yql,
+ f"input.query({input_embedding_field})": query_embedding,
+ "ranking": ranking_function,
+ "hits": hits,
+ }
+ return query
+
+ def similarity_search_by_vector_with_score(
+ self, query_embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """
+ Performs similarity search from an embedding vector.
+
+ Args:
+ query_embedding: Embedding vector to search for.
+ k: Number of results to return.
+ custom_query: Use this custom query instead of the default query (kwargs)
+ kwargs: other vector store specific parameters
+
+ Returns:
+ List of (Document, score) tuples most similar to the query embedding.
+ """
+ if "custom_query" in kwargs:
+ query = kwargs["custom_query"]
+ else:
+ query = self._create_query(query_embedding, k, **kwargs)
+
+ try:
+ response = self._vespa_app.query(body=query)
+ except Exception as e:
+ raise RuntimeError(
+ f"Could not retrieve data from Vespa: "
+ f"{e.args[0][0]['summary']}. "
+ f"Error: {e.args[0][0]['message']}"
+ )
+ if not str(response.status_code).startswith("2"):
+ raise RuntimeError(
+ f"Could not retrieve data from Vespa. "
+ f"Error code: {response.status_code}. "
+ f"Message: {response.json['message']}"
+ )
+
+ root = response.json["root"]
+ if "errors" in root:
+ import json
+
+ raise RuntimeError(json.dumps(root["errors"]))
+
+ if response is None or response.hits is None:
+ return []
+
+ docs = []
+ for child in response.hits:
+ page_content = child["fields"][self._page_content_field]
+ score = child["relevance"]
+ metadata = {"id": child["id"]}
+ if self._metadata_fields is not None:
+ for field in self._metadata_fields:
+ metadata[field] = child["fields"].get(field)
+ doc = Document(page_content=page_content, metadata=metadata)
+ docs.append((doc, score))
+ return docs
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ results = self.similarity_search_by_vector_with_score(embedding, k, **kwargs)
+ return [r[0] for r in results]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ query_emb = []
+ if self._embedding_function is not None:
+ query_emb = self._embedding_function.embed_query(query)
+ return self.similarity_search_by_vector_with_score(query_emb, k, **kwargs)
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ results = self.similarity_search_with_score(query, k, **kwargs)
+ return [r[0] for r in results]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ raise NotImplementedError("MMR search not implemented")
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ raise NotImplementedError("MMR search by vector not implemented")
+
+ @classmethod
+ def from_texts(
+ cls: Type[VespaStore],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> VespaStore:
+ vespa = cls(embedding_function=embedding, **kwargs)
+ vespa.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+ return vespa
+
+ def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
+ return super().as_retriever(**kwargs)
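+
+
+# A minimal usage sketch (illustrative only), continuing the configuration from
+# the class docstring above; ``texts``, ``metadatas`` and ``my_vespa_query_dict``
+# are assumed placeholders:
+#
+# db = VespaStore.from_texts(
+#     texts, embedding_function, metadatas=metadatas, app=vespa_app, **vespa_config
+# )
+# docs = db.similarity_search("what did the author do?", k=4)
+# # a fully hand-written Vespa query can be supplied via the custom_query kwarg:
+# docs = db.similarity_search("ignored", custom_query=my_vespa_query_dict)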
diff --git a/libs/community/langchain_community/vectorstores/weaviate.py b/libs/community/langchain_community/vectorstores/weaviate.py
new file mode 100644
index 00000000000..119e7bc0d7e
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/weaviate.py
@@ -0,0 +1,528 @@
+from __future__ import annotations
+
+import datetime
+import os
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+)
+from uuid import uuid4
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
+if TYPE_CHECKING:
+ import weaviate
+
+
+def _default_schema(index_name: str) -> Dict:
+ return {
+ "class": index_name,
+ "properties": [
+ {
+ "name": "text",
+ "dataType": ["text"],
+ }
+ ],
+ }
+
+
+def _create_weaviate_client(
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ **kwargs: Any,
+) -> weaviate.Client:
+ try:
+ import weaviate
+ except ImportError:
+ raise ImportError(
+ "Could not import weaviate python package. "
+ "Please install it with `pip install weaviate-client`"
+ )
+ url = url or os.environ.get("WEAVIATE_URL")
+ api_key = api_key or os.environ.get("WEAVIATE_API_KEY")
+ auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None
+ return weaviate.Client(url=url, auth_client_secret=auth, **kwargs)
+
+
+def _default_score_normalizer(val: float) -> float:
+ return 1 - 1 / (1 + np.exp(val))
+
+
+def _json_serializable(value: Any) -> Any:
+ if isinstance(value, datetime.datetime):
+ return value.isoformat()
+ return value
+
+
+class Weaviate(VectorStore):
+ """`Weaviate` vector store.
+
+ To use, you should have the ``weaviate-client`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ import weaviate
+ from langchain_community.vectorstores import Weaviate
+
+ client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
+ weaviate = Weaviate(client, index_name, text_key)
+
+ """
+
+ def __init__(
+ self,
+ client: Any,
+ index_name: str,
+ text_key: str,
+ embedding: Optional[Embeddings] = None,
+ attributes: Optional[List[str]] = None,
+ relevance_score_fn: Optional[
+ Callable[[float], float]
+ ] = _default_score_normalizer,
+ by_text: bool = True,
+ ):
+ """Initialize with Weaviate client."""
+ try:
+ import weaviate
+ except ImportError:
+ raise ImportError(
+ "Could not import weaviate python package. "
+ "Please install it with `pip install weaviate-client`."
+ )
+ if not isinstance(client, weaviate.Client):
+ raise ValueError(
+ f"client should be an instance of weaviate.Client, got {type(client)}"
+ )
+ self._client = client
+ self._index_name = index_name
+ self._embedding = embedding
+ self._text_key = text_key
+ self._query_attrs = [self._text_key]
+ self.relevance_score_fn = relevance_score_fn
+ self._by_text = by_text
+ if attributes is not None:
+ self._query_attrs.extend(attributes)
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ return self._embedding
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ return (
+ self.relevance_score_fn
+ if self.relevance_score_fn
+ else _default_score_normalizer
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Upload texts with metadata (properties) to Weaviate."""
+ from weaviate.util import get_valid_uuid
+
+ ids = []
+ embeddings: Optional[List[List[float]]] = None
+ if self._embedding:
+ if not isinstance(texts, list):
+ texts = list(texts)
+ embeddings = self._embedding.embed_documents(texts)
+
+ with self._client.batch as batch:
+ for i, text in enumerate(texts):
+ data_properties = {self._text_key: text}
+ if metadatas is not None:
+ for key, val in metadatas[i].items():
+ data_properties[key] = _json_serializable(val)
+
+ # Allow for ids (consistent w/ other methods)
+ # or uuids (backwards compatible w/ existing arg).
+ # If the UUID of one of the objects already exists
+ # then the existing object will be replaced by the new object.
+ _id = get_valid_uuid(uuid4())
+ if "uuids" in kwargs:
+ _id = kwargs["uuids"][i]
+ elif "ids" in kwargs:
+ _id = kwargs["ids"][i]
+
+ batch.add_data_object(
+ data_object=data_properties,
+ class_name=self._index_name,
+ uuid=_id,
+ vector=embeddings[i] if embeddings else None,
+ tenant=kwargs.get("tenant"),
+ )
+ ids.append(_id)
+ return ids
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ if self._by_text:
+ return self.similarity_search_by_text(query, k, **kwargs)
+ else:
+ if self._embedding is None:
+ raise ValueError(
+ "_embedding cannot be None for similarity_search when "
+ "_by_text=False"
+ )
+ embedding = self._embedding.embed_query(query)
+ return self.similarity_search_by_vector(embedding, k, **kwargs)
+
+ def similarity_search_by_text(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ content: Dict[str, Any] = {"concepts": [query]}
+ if kwargs.get("search_distance"):
+ content["certainty"] = kwargs.get("search_distance")
+ query_obj = self._client.query.get(self._index_name, self._query_attrs)
+ if kwargs.get("where_filter"):
+ query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ if kwargs.get("tenant"):
+ query_obj = query_obj.with_tenant(kwargs.get("tenant"))
+ if kwargs.get("additional"):
+ query_obj = query_obj.with_additional(kwargs.get("additional"))
+ result = query_obj.with_near_text(content).with_limit(k).do()
+ if "errors" in result:
+ raise ValueError(f"Error during query: {result['errors']}")
+ docs = []
+ for res in result["data"]["Get"][self._index_name]:
+ text = res.pop(self._text_key)
+ docs.append(Document(page_content=text, metadata=res))
+ return docs
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Look up similar documents by embedding vector in Weaviate."""
+ vector = {"vector": embedding}
+ query_obj = self._client.query.get(self._index_name, self._query_attrs)
+ if kwargs.get("where_filter"):
+ query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ if kwargs.get("tenant"):
+ query_obj = query_obj.with_tenant(kwargs.get("tenant"))
+ if kwargs.get("additional"):
+ query_obj = query_obj.with_additional(kwargs.get("additional"))
+ result = query_obj.with_near_vector(vector).with_limit(k).do()
+ if "errors" in result:
+ raise ValueError(f"Error during query: {result['errors']}")
+ docs = []
+ for res in result["data"]["Get"][self._index_name]:
+ text = res.pop(self._text_key)
+ docs.append(Document(page_content=text, metadata=res))
+ return docs
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if self._embedding is not None:
+ embedding = self._embedding.embed_query(query)
+ else:
+ raise ValueError(
+ "max_marginal_relevance_search requires a suitable Embeddings object"
+ )
+
+ return self.max_marginal_relevance_search_by_vector(
+ embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
+ )
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ vector = {"vector": embedding}
+ query_obj = self._client.query.get(self._index_name, self._query_attrs)
+ if kwargs.get("where_filter"):
+ query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ if kwargs.get("tenant"):
+ query_obj = query_obj.with_tenant(kwargs.get("tenant"))
+ results = (
+ query_obj.with_additional("vector")
+ .with_near_vector(vector)
+ .with_limit(fetch_k)
+ .do()
+ )
+
+ payload = results["data"]["Get"][self._index_name]
+ embeddings = [result["_additional"]["vector"] for result in payload]
+ mmr_selected = maximal_marginal_relevance(
+ np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+ )
+
+ docs = []
+ for idx in mmr_selected:
+ text = payload[idx].pop(self._text_key)
+ payload[idx].pop("_additional")
+ meta = payload[idx]
+ docs.append(Document(page_content=text, metadata=meta))
+ return docs
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+        """
+        Return list of documents most similar to the query text,
+        each paired with a similarity score (the dot product of the
+        stored vector and the query embedding).
+        Higher scores represent more similarity.
+        """
+ if self._embedding is None:
+ raise ValueError(
+ "_embedding cannot be None for similarity_search_with_score"
+ )
+ content: Dict[str, Any] = {"concepts": [query]}
+ if kwargs.get("search_distance"):
+ content["certainty"] = kwargs.get("search_distance")
+ query_obj = self._client.query.get(self._index_name, self._query_attrs)
+ if kwargs.get("where_filter"):
+ query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ if kwargs.get("tenant"):
+ query_obj = query_obj.with_tenant(kwargs.get("tenant"))
+
+ embedded_query = self._embedding.embed_query(query)
+ if not self._by_text:
+ vector = {"vector": embedded_query}
+ result = (
+ query_obj.with_near_vector(vector)
+ .with_limit(k)
+ .with_additional("vector")
+ .do()
+ )
+ else:
+ result = (
+ query_obj.with_near_text(content)
+ .with_limit(k)
+ .with_additional("vector")
+ .do()
+ )
+
+ if "errors" in result:
+ raise ValueError(f"Error during query: {result['errors']}")
+
+ docs_and_scores = []
+ for res in result["data"]["Get"][self._index_name]:
+ text = res.pop(self._text_key)
+ score = np.dot(res["_additional"]["vector"], embedded_query)
+ docs_and_scores.append((Document(page_content=text, metadata=res), score))
+ return docs_and_scores
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ *,
+ client: Optional[weaviate.Client] = None,
+ weaviate_url: Optional[str] = None,
+ weaviate_api_key: Optional[str] = None,
+ batch_size: Optional[int] = None,
+ index_name: Optional[str] = None,
+ text_key: str = "text",
+ by_text: bool = False,
+ relevance_score_fn: Optional[
+ Callable[[float], float]
+ ] = _default_score_normalizer,
+ **kwargs: Any,
+ ) -> Weaviate:
+ """Construct Weaviate wrapper from raw documents.
+
+ This is a user-friendly interface that:
+ 1. Embeds documents.
+ 2. Creates a new index for the embeddings in the Weaviate instance.
+ 3. Adds the documents to the newly created Weaviate index.
+
+ This is intended to be a quick way to get started.
+
+ Args:
+ texts: Texts to add to vector store.
+ embedding: Text embedding model to use.
+ metadatas: Metadata associated with each text.
+ client: weaviate.Client to use.
+ weaviate_url: The Weaviate URL. If using Weaviate Cloud Services get it
+ from the ``Details`` tab. Can be passed in as a named param or by
+ setting the environment variable ``WEAVIATE_URL``. Should not be
+ specified if client is provided.
+ weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud
+ Services, get it from ``Details`` tab. Can be passed in as a named param
+ or by setting the environment variable ``WEAVIATE_API_KEY``. Should
+ not be specified if client is provided.
+ batch_size: Size of batch operations.
+ index_name: Index name.
+ text_key: Key to use for uploading/retrieving text to/from vectorstore.
+ by_text: Whether to search by text or by embedding.
+ relevance_score_fn: Function for converting whatever distance function the
+ vector store uses to a relevance score, which is a normalized similarity
+ score (0 means dissimilar, 1 means similar).
+ **kwargs: Additional named parameters to pass to ``Weaviate.__init__()``.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain_community.vectorstores import Weaviate
+
+ embeddings = OpenAIEmbeddings()
+ weaviate = Weaviate.from_texts(
+ texts,
+ embeddings,
+ weaviate_url="http://localhost:8080"
+ )
+ """
+
+ try:
+ from weaviate.util import get_valid_uuid
+ except ImportError as e:
+ raise ImportError(
+ "Could not import weaviate python package. "
+ "Please install it with `pip install weaviate-client`"
+ ) from e
+
+ client = client or _create_weaviate_client(
+ url=weaviate_url,
+ api_key=weaviate_api_key,
+ )
+ if batch_size:
+ client.batch.configure(batch_size=batch_size)
+
+ index_name = index_name or f"LangChain_{uuid4().hex}"
+ schema = _default_schema(index_name)
+ # check whether the index already exists
+ if not client.schema.exists(index_name):
+ client.schema.create_class(schema)
+
+ embeddings = embedding.embed_documents(texts) if embedding else None
+ attributes = list(metadatas[0].keys()) if metadatas else None
+
+ # If the UUID of one of the objects already exists
+ # then the existing object will be replaced by the new object.
+ if "uuids" in kwargs:
+ uuids = kwargs.pop("uuids")
+ else:
+ uuids = [get_valid_uuid(uuid4()) for _ in range(len(texts))]
+
+ with client.batch as batch:
+ for i, text in enumerate(texts):
+ data_properties = {
+ text_key: text,
+ }
+ if metadatas is not None:
+ for key in metadatas[i].keys():
+ data_properties[key] = metadatas[i][key]
+
+ _id = uuids[i]
+
+ # if an embedding strategy is not provided, we let
+ # weaviate create the embedding. Note that this will only
+ # work if weaviate has been installed with a vectorizer module
+ # like text2vec-contextionary for example
+ params = {
+ "uuid": _id,
+ "data_object": data_properties,
+ "class_name": index_name,
+ }
+ if embeddings is not None:
+ params["vector"] = embeddings[i]
+
+ batch.add_data_object(**params)
+
+ batch.flush()
+
+ return cls(
+ client,
+ index_name,
+ text_key,
+ embedding=embedding,
+ attributes=attributes,
+ relevance_score_fn=relevance_score_fn,
+ by_text=by_text,
+ **kwargs,
+ )
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ """
+
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ # TODO: Check if this can be done in bulk
+ for id in ids:
+ self._client.data_object.delete(uuid=id)
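To tie the Weaviate methods above together, here is a minimal usage sketch (editorial, not part of the diff). It assumes a Weaviate instance reachable at `http://localhost:8080`, an `OPENAI_API_KEY` for `OpenAIEmbeddings`, and example `texts`, `metadatas`, and filter values invented purely for illustration.

```python
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Weaviate

# Illustrative data; in practice these come from your own pipeline.
texts = ["LangChain ships many vector store integrations.", "Weaviate supports hybrid search."]
metadatas = [{"source": "docs"}, {"source": "blog"}]

db = Weaviate.from_texts(
    texts,
    OpenAIEmbeddings(),
    metadatas=metadatas,
    weaviate_url="http://localhost:8080",  # placeholder URL
    by_text=False,  # search by embedding rather than by text
)

# Plain similarity search, optionally narrowed by a Weaviate where filter.
where_filter = {"path": ["source"], "operator": "Equal", "valueText": "docs"}
docs = db.similarity_search("vector stores", k=2, where_filter=where_filter)

# MMR search trades off relevance against diversity via lambda_mult.
diverse_docs = db.max_marginal_relevance_search("vector stores", k=2, fetch_k=10, lambda_mult=0.5)
```

With `by_text=False`, `similarity_search` routes through `similarity_search_by_vector`, so the wrapper embeds the query locally instead of relying on a Weaviate vectorizer module.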
diff --git a/libs/community/langchain_community/vectorstores/xata.py b/libs/community/langchain_community/vectorstores/xata.py
new file mode 100644
index 00000000000..bb2c0645369
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/xata.py
@@ -0,0 +1,263 @@
+from __future__ import annotations
+
+import time
+from itertools import repeat
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+
+class XataVectorStore(VectorStore):
+ """`Xata` vector store.
+
+ It assumes you have a Xata database
+ created with the right schema. See the guide at:
+ https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore
+
+ """
+
+ def __init__(
+ self,
+ api_key: str,
+ db_url: str,
+ embedding: Embeddings,
+ table_name: str,
+ ) -> None:
+ """Initialize with Xata client."""
+ try:
+ from xata.client import XataClient # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Could not import xata python package. "
+ "Please install it with `pip install xata`."
+ )
+ self._client = XataClient(api_key=api_key, db_url=db_url)
+ self._embedding: Embeddings = embedding
+ self._table_name = table_name or "vectors"
+
+ @property
+ def embeddings(self) -> Embeddings:
+ return self._embedding
+
+ def add_vectors(
+ self,
+ vectors: List[List[float]],
+ documents: List[Document],
+ ids: Optional[List[str]] = None,
+ ) -> List[str]:
+ return self._add_vectors(vectors, documents, ids)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ ids = ids
+ docs = self._texts_to_documents(texts, metadatas)
+
+ vectors = self._embedding.embed_documents(list(texts))
+ return self.add_vectors(vectors, docs, ids)
+
+ def _add_vectors(
+ self,
+ vectors: List[List[float]],
+ documents: List[Document],
+ ids: Optional[List[str]] = None,
+ ) -> List[str]:
+ """Add vectors to the Xata database."""
+
+ rows: List[Dict[str, Any]] = []
+ for idx, embedding in enumerate(vectors):
+ row = {
+ "content": documents[idx].page_content,
+ "embedding": embedding,
+ }
+ if ids:
+ row["id"] = ids[idx]
+ for key, val in documents[idx].metadata.items():
+ if key not in ["id", "content", "embedding"]:
+ row[key] = val
+ rows.append(row)
+
+ # XXX: I would have liked to use the BulkProcessor here, but it
+ # doesn't return the IDs, which we need here. Manual chunking it is.
+ chunk_size = 1000
+ id_list: List[str] = []
+ for i in range(0, len(rows), chunk_size):
+ chunk = rows[i : i + chunk_size]
+
+ r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
+ if r.status_code != 200:
+ raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
+ id_list.extend(r["recordIDs"])
+ return id_list
+
+ @staticmethod
+ def _texts_to_documents(
+ texts: Iterable[str],
+ metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
+ ) -> List[Document]:
+ """Return list of Documents from list of texts and metadatas."""
+ if metadatas is None:
+ metadatas = repeat({})
+
+ docs = [
+ Document(page_content=text, metadata=metadata)
+ for text, metadata in zip(texts, metadatas)
+ ]
+
+ return docs
+
+ @classmethod
+ def from_texts(
+ cls: Type["XataVectorStore"],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ api_key: Optional[str] = None,
+ db_url: Optional[str] = None,
+ table_name: str = "vectors",
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> "XataVectorStore":
+ """Return VectorStore initialized from texts and embeddings."""
+
+ if not api_key or not db_url:
+ raise ValueError("Xata api_key and db_url must be set.")
+
+ embeddings = embedding.embed_documents(texts)
+ ids = None # Xata will generate them for us
+ docs = cls._texts_to_documents(texts, metadatas)
+
+ vector_db = cls(
+ api_key=api_key,
+ db_url=db_url,
+ embedding=embedding,
+ table_name=table_name,
+ )
+
+ vector_db._add_vectors(embeddings, docs, ids)
+ return vector_db
+
+ def similarity_search(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Document]:
+ """Return docs most similar to query.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+
+ Returns:
+ List of Documents most similar to the query.
+ """
+ docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
+ documents = [d[0] for d in docs_and_scores]
+ return documents
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+        """Run similarity search with Xata and return scores.
+
+        Args:
+            query (str): Query text to search for.
+            k (int): Number of results to return. Defaults to 4.
+            filter (Optional[dict]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List[Tuple[Document, float]]: List of documents most similar to the
+                query text, each with its score as a float.
+        """
+ embedding = self._embedding.embed_query(query)
+ payload = {
+ "queryVector": embedding,
+ "column": "embedding",
+ "size": k,
+ }
+ if filter:
+ payload["filter"] = filter
+ r = self._client.data().vector_search(self._table_name, payload=payload)
+ if r.status_code != 200:
+ raise Exception(f"Error running similarity search: {r.status_code} {r}")
+ hits = r["records"]
+ docs_and_scores = [
+ (
+ Document(
+ page_content=hit["content"],
+ metadata=self._extractMetadata(hit),
+ ),
+ hit["xata"]["score"],
+ )
+ for hit in hits
+ ]
+ return docs_and_scores
+
+ def _extractMetadata(self, record: dict) -> dict:
+ """Extract metadata from a record. Filters out known columns."""
+ metadata = {}
+ for key, val in record.items():
+ if key not in ["id", "content", "embedding", "xata"]:
+ metadata[key] = val
+ return metadata
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ delete_all: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Delete by vector IDs.
+
+ Args:
+ ids: List of ids to delete.
+ delete_all: Delete all records in the table.
+ """
+ if delete_all:
+ self._delete_all()
+ self.wait_for_indexing(ndocs=0)
+ elif ids is not None:
+ chunk_size = 500
+ for i in range(0, len(ids), chunk_size):
+ chunk = ids[i : i + chunk_size]
+ operations = [
+ {"delete": {"table": self._table_name, "id": id}} for id in chunk
+ ]
+ self._client.records().transaction(payload={"operations": operations})
+ else:
+ raise ValueError("Either ids or delete_all must be set.")
+
+ def _delete_all(self) -> None:
+ """Delete all records in the table."""
+ while True:
+ r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
+ if r.status_code != 200:
+ raise Exception(f"Error running query: {r.status_code} {r}")
+ ids = [rec["id"] for rec in r["records"]]
+ if len(ids) == 0:
+ break
+ operations = [
+ {"delete": {"table": self._table_name, "id": id}} for id in ids
+ ]
+ self._client.records().transaction(payload={"operations": operations})
+
+ def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
+ """Wait for the search index to contain a certain number of
+ documents. Useful in tests.
+ """
+ start = time.time()
+ while True:
+ r = self._client.data().search_table(
+ self._table_name, payload={"query": "", "page": {"size": 0}}
+ )
+ if r.status_code != 200:
+ raise Exception(f"Error running search: {r.status_code} {r}")
+ if r["totalCount"] == ndocs:
+ break
+ if time.time() - start > timeout:
+ raise Exception("Timed out waiting for indexing to complete.")
+ time.sleep(0.5)
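A hedged usage sketch for the Xata store above (editorial, not part of the diff). The API key, database URL, and table name are placeholders, and the Xata database is assumed to already have a table with `content` and `embedding` columns as described in the linked guide.

```python
import os

from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores.xata import XataVectorStore

# Placeholder credentials; the target table needs "content" and "embedding" columns.
api_key = os.environ["XATA_API_KEY"]
db_url = "https://my-workspace-abc123.us-east-1.xata.sh/db/my-db"  # hypothetical

store = XataVectorStore.from_texts(
    ["Xata stores embeddings alongside regular columns."],
    OpenAIEmbeddings(),
    api_key=api_key,
    db_url=db_url,
    table_name="vectors",
)

# The search index is eventually consistent; wait_for_indexing is mainly for tests.
store.wait_for_indexing(ndocs=1)
docs_and_scores = store.similarity_search_with_score("embeddings next to my data", k=1)
```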
diff --git a/libs/community/langchain_community/vectorstores/yellowbrick.py b/libs/community/langchain_community/vectorstores/yellowbrick.py
new file mode 100644
index 00000000000..d7f1a159f3e
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/yellowbrick.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+import json
+import logging
+import uuid
+import warnings
+from itertools import repeat
+from typing import (
+ Any,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+from langchain_community.docstore.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class Yellowbrick(VectorStore):
+ """Wrapper around Yellowbrick as a vector database.
+ Example:
+ .. code-block:: python
+ from langchain_community.vectorstores import Yellowbrick
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+ ...
+ """
+
+ def __init__(
+ self,
+ embedding: Embeddings,
+ connection_string: str,
+ table: str,
+ ) -> None:
+ """Initialize with yellowbrick client.
+ Args:
+ embedding: Embedding operator
+ connection_string: Format 'postgres://username:password@host:port/database'
+ table: Table used to store / retrieve embeddings from
+ """
+
+        try:
+            import psycopg2
+        except ImportError:
+            raise ImportError(
+                "Could not import psycopg2 python package. "
+                "Please install it with `pip install psycopg2`."
+            )
+
+ if not isinstance(embedding, Embeddings):
+ warnings.warn("embeddings input must be Embeddings object.")
+
+ self.connection_string = connection_string
+ self._table = table
+ self._embedding = embedding
+ self._connection = psycopg2.connect(connection_string)
+
+ self.__post_init__()
+
+ def __post_init__(
+ self,
+ ) -> None:
+ """Initialize the store."""
+ self.check_database_utf8()
+ self.create_table_if_not_exists()
+
+ def __del__(self) -> None:
+ if self._connection:
+ self._connection.close()
+
+ def create_table_if_not_exists(self) -> None:
+ """
+ Helper function: create table if not exists
+ """
+ from psycopg2 import sql
+
+ cursor = self._connection.cursor()
+ cursor.execute(
+ sql.SQL(
+ "CREATE TABLE IF NOT EXISTS {} ( \
+ id UUID, \
+ embedding_id INTEGER, \
+ text VARCHAR(60000), \
+ metadata VARCHAR(1024), \
+ embedding FLOAT)"
+ ).format(sql.Identifier(self._table))
+ )
+ self._connection.commit()
+ cursor.close()
+
+ def drop(self, table: str) -> None:
+ """
+ Helper function: Drop data
+ """
+ from psycopg2 import sql
+
+ cursor = self._connection.cursor()
+ cursor.execute(sql.SQL("DROP TABLE IF EXISTS {}").format(sql.Identifier(table)))
+ self._connection.commit()
+ cursor.close()
+
+ def check_database_utf8(self) -> bool:
+ """
+ Helper function: Test the database is UTF-8 encoded
+ """
+ cursor = self._connection.cursor()
+ query = "SELECT pg_encoding_to_char(encoding) \
+ FROM pg_database \
+ WHERE datname = current_database();"
+ cursor.execute(query)
+ encoding = cursor.fetchone()[0]
+ cursor.close()
+ if encoding.lower() == "utf8" or encoding.lower() == "utf-8":
+ return True
+ else:
+ raise Exception(
+ f"Database \
+ '{self.connection_string.split('/')[-1]}' encoding is not UTF-8"
+ )
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add more texts to the vectorstore index.
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ from psycopg2 import sql
+
+ texts = list(texts)
+ cursor = self._connection.cursor()
+ embeddings = self._embedding.embed_documents(list(texts))
+ results = []
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+ for id in range(len(embeddings)):
+ doc_uuid = uuid.uuid4()
+ results.append(str(doc_uuid))
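+            # Each document is stored as one row per embedding dimension
+            # (id, embedding_id, text, metadata, embedding), so a single text
+            # yields len(embeddings[id]) rows.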
+ data_input = [
+ (str(id), embedding_id, text, json.dumps(metadata), embedding)
+ for id, embedding_id, text, metadata, embedding in zip(
+ repeat(doc_uuid),
+ range(len(embeddings[id])),
+ repeat(texts[id]),
+ repeat(metadatas[id]),
+ embeddings[id],
+ )
+ ]
+ flattened_input = [val for sublist in data_input for val in sublist]
+ insert_query = sql.SQL(
+ "INSERT INTO {t} \
+ (id, embedding_id, text, metadata, embedding) VALUES {v}"
+ ).format(
+ t=sql.Identifier(self._table),
+ v=(
+ sql.SQL(",").join(
+ [
+ sql.SQL("(%s,%s,%s,%s,%s)")
+ for _ in range(len(embeddings[id]))
+ ]
+ )
+ ),
+ )
+ cursor.execute(insert_query, flattened_input)
+ self._connection.commit()
+ return results
+
+ @classmethod
+ def from_texts(
+ cls: Type[Yellowbrick],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ connection_string: str = "",
+ table: str = "langchain",
+ **kwargs: Any,
+ ) -> Yellowbrick:
+ """Add texts to the vectorstore index.
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ connection_string: URI to Yellowbrick instance
+ embedding: Embedding function
+ table: table to store embeddings
+ kwargs: vectorstore specific parameters
+ """
+ if connection_string is None:
+ raise ValueError("connection_string must be provided")
+ vss = cls(
+ embedding=embedding,
+ connection_string=connection_string,
+ table=table,
+ )
+ vss.add_texts(texts=texts, metadatas=metadatas)
+ return vss
+
+ def similarity_search_with_score_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with Yellowbrick with vector
+
+ Args:
+ embedding (List[float]): query embedding
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+
+        NOTE: Do not let end users fill this in directly, and always be aware
+        of SQL injection.
+
+        Returns:
+            List[Tuple[Document, float]]: List of Documents and scores
+ """
+ from psycopg2 import sql
+
+ cursor = self._connection.cursor()
+ tmp_table = "tmp_" + self._table
+ cursor.execute(
+ sql.SQL(
+ "CREATE TEMPORARY TABLE {} ( \
+ embedding_id INTEGER, embedding FLOAT)"
+ ).format(sql.Identifier(tmp_table))
+ )
+ self._connection.commit()
+
+ data_input = [
+ (embedding_id, embedding)
+ for embedding_id, embedding in zip(range(len(embedding)), embedding)
+ ]
+ flattened_input = [val for sublist in data_input for val in sublist]
+ insert_query = sql.SQL(
+ "INSERT INTO {t} \
+ (embedding_id, embedding) VALUES {v}"
+ ).format(
+ t=sql.Identifier(tmp_table),
+ v=sql.SQL(",").join([sql.SQL("(%s,%s)") for _ in range(len(embedding))]),
+ )
+ cursor.execute(insert_query, flattened_input)
+ self._connection.commit()
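+        # Cosine similarity is computed in SQL: the temporary table holds the
+        # query vector one component per row, joined to the stored per-dimension
+        # embeddings on embedding_id and aggregated per document.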
+ sql_query = sql.SQL(
+ "SELECT text, \
+ metadata, \
+ sum(v1.embedding * v2.embedding) / \
+ ( sqrt(sum(v1.embedding * v1.embedding)) * \
+ sqrt(sum(v2.embedding * v2.embedding))) AS score \
+ FROM {v1} v1 INNER JOIN {v2} v2 \
+ ON v1.embedding_id = v2.embedding_id \
+ GROUP BY v2.id, v2.text, v2.metadata \
+ ORDER BY score DESC \
+ LIMIT %s"
+ ).format(v1=sql.Identifier(tmp_table), v2=sql.Identifier(self._table))
+ cursor.execute(sql_query, (k,))
+ results = cursor.fetchall()
+ self.drop(tmp_table)
+
+ documents: List[Tuple[Document, float]] = []
+ for result in results:
+ metadata = json.loads(result[1]) or {}
+ doc = Document(page_content=result[0], metadata=metadata)
+ documents.append((doc, result[2]))
+
+ cursor.close()
+ return documents
+
+ def similarity_search(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Perform a similarity search with Yellowbrick
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+
+        NOTE: Do not let end users fill this in directly, and always be aware
+        of SQL injection.
+
+ Returns:
+ List[Document]: List of Documents
+ """
+ embedding = self._embedding.embed_query(query)
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return [doc for doc, _ in documents]
+
+ def similarity_search_with_score(
+ self, query: str, k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
+ """Perform a similarity search with Yellowbrick
+
+ Args:
+ query (str): query string
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+
+        NOTE: Do not let end users fill this in directly, and always be aware
+        of SQL injection.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, similarity) pairs
+ """
+ embedding = self._embedding.embed_query(query)
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return documents
+
+ def similarity_search_by_vector(
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
+ """Perform a similarity search with Yellowbrick by vectors
+
+ Args:
+ embedding (List[float]): query embedding
+ k (int, optional): Top K neighbors to retrieve. Defaults to 4.
+
+        NOTE: Do not let end users fill this in directly, and always be aware
+        of SQL injection.
+
+ Returns:
+ List[Document]: List of documents
+ """
+ documents = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k
+ )
+ return [doc for doc, _ in documents]
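A minimal, hedged usage sketch for the Yellowbrick store above (editorial, not part of the diff); the connection string is a placeholder and the target database must be UTF-8 encoded (see `check_database_utf8`).

```python
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Yellowbrick

# Placeholder connection string in the documented format.
connection_string = "postgres://user:password@yb-host:5432/mydb"

store = Yellowbrick.from_texts(
    ["Yellowbrick keeps one row per embedding dimension."],
    OpenAIEmbeddings(),
    connection_string=connection_string,
    table="langchain",
)

# Scores are cosine similarities computed in SQL; higher is more similar.
results = store.similarity_search_with_score("how are embeddings stored?", k=1)
```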
diff --git a/libs/community/langchain_community/vectorstores/zep.py b/libs/community/langchain_community/vectorstores/zep.py
new file mode 100644
index 00000000000..5348dd8ac6e
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/zep.py
@@ -0,0 +1,678 @@
+from __future__ import annotations
+
+import logging
+import warnings
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+if TYPE_CHECKING:
+ from zep_python.document import Document as ZepDocument
+ from zep_python.document import DocumentCollection
+
+
+logger = logging.getLogger()
+
+
+@dataclass
+class CollectionConfig:
+ """Configuration for a `Zep Collection`.
+
+ If the collection does not exist, it will be created.
+
+ Attributes:
+ name (str): The name of the collection.
+ description (Optional[str]): An optional description of the collection.
+ metadata (Optional[Dict[str, Any]]): Optional metadata for the collection.
+ embedding_dimensions (int): The number of dimensions for the embeddings in
+ the collection. This should match the Zep server configuration
+ if auto-embed is true.
+ is_auto_embedded (bool): A flag indicating whether the collection is
+ automatically embedded by Zep.
+ """
+
+ name: str
+ description: Optional[str]
+ metadata: Optional[Dict[str, Any]]
+ embedding_dimensions: int
+ is_auto_embedded: bool
+
+
+class ZepVectorStore(VectorStore):
+ """`Zep` vector store.
+
+ It provides methods for adding texts or documents to the store,
+ searching for similar documents, and deleting documents.
+
+ Search scores are calculated using cosine similarity normalized to [0, 1].
+
+ Args:
+ api_url (str): The URL of the Zep API.
+ collection_name (str): The name of the collection in the Zep store.
+ api_key (Optional[str]): The API key for the Zep API.
+ config (Optional[CollectionConfig]): The configuration for the collection.
+ Required if the collection does not already exist.
+ embedding (Optional[Embeddings]): Optional embedding function to use to
+ embed the texts. Required if the collection is not auto-embedded.
+ """
+
+ def __init__(
+ self,
+ collection_name: str,
+ api_url: str,
+ *,
+ api_key: Optional[str] = None,
+ config: Optional[CollectionConfig] = None,
+ embedding: Optional[Embeddings] = None,
+ ) -> None:
+ super().__init__()
+ if not collection_name:
+ raise ValueError(
+ "collection_name must be specified when using ZepVectorStore."
+ )
+ try:
+ from zep_python import ZepClient
+ except ImportError:
+ raise ImportError(
+ "Could not import zep-python python package. "
+ "Please install it with `pip install zep-python`."
+ )
+ self._client = ZepClient(api_url, api_key=api_key)
+
+ self.collection_name = collection_name
+ # If for some reason the collection name is not the same as the one in the
+ # config, update it.
+ if config and config.name != self.collection_name:
+ config.name = self.collection_name
+
+ self._collection_config = config
+ self._collection = self._load_collection()
+ self._embedding = embedding
+
+ # self.add_texts(texts, metadatas=metadatas, **kwargs)
+
+ @property
+ def embeddings(self) -> Optional[Embeddings]:
+ """Access the query embedding object if available."""
+ return self._embedding
+
+ def _load_collection(self) -> DocumentCollection:
+ """
+ Load the collection from the Zep backend.
+ """
+ from zep_python import NotFoundError
+
+ try:
+ collection = self._client.document.get_collection(self.collection_name)
+ except NotFoundError:
+ logger.info(
+ f"Collection {self.collection_name} not found. Creating new collection."
+ )
+ collection = self._create_collection()
+
+ return collection
+
+ def _create_collection(self) -> DocumentCollection:
+ """
+ Create a new collection in the Zep backend.
+ """
+ if not self._collection_config:
+ raise ValueError(
+ "Collection config must be specified when creating a new collection."
+ )
+ collection = self._client.document.add_collection(
+ **asdict(self._collection_config)
+ )
+ return collection
+
+ def _generate_documents_to_add(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[Any, Any]]] = None,
+ document_ids: Optional[List[str]] = None,
+ ) -> List[ZepDocument]:
+ from zep_python.document import Document as ZepDocument
+
+ embeddings = None
+ if self._collection and self._collection.is_auto_embedded:
+ if self._embedding is not None:
+ warnings.warn(
+ """The collection is set to auto-embed and an embedding
+ function is present. Ignoring the embedding function.""",
+ stacklevel=2,
+ )
+ elif self._embedding is not None:
+ embeddings = self._embedding.embed_documents(list(texts))
+ if self._collection and self._collection.embedding_dimensions != len(
+ embeddings[0]
+ ):
+ raise ValueError(
+ "The embedding dimensions of the collection and the embedding"
+ " function do not match. Collection dimensions:"
+ f" {self._collection.embedding_dimensions}, Embedding dimensions:"
+ f" {len(embeddings[0])}"
+ )
+ else:
+ pass
+
+ documents: List[ZepDocument] = []
+ for i, d in enumerate(texts):
+ documents.append(
+ ZepDocument(
+ content=d,
+ metadata=metadatas[i] if metadatas else None,
+ document_id=document_ids[i] if document_ids else None,
+ embedding=embeddings[i] if embeddings else None,
+ )
+ )
+ return documents
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ document_ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ metadatas: Optional list of metadatas associated with the texts.
+ document_ids: Optional list of document ids associated with the texts.
+ kwargs: vectorstore specific parameters
+
+ Returns:
+ List of ids from adding the texts into the vectorstore.
+ """
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ documents = self._generate_documents_to_add(texts, metadatas, document_ids)
+ uuids = self._collection.add_documents(documents)
+
+ return uuids
+
+ async def aadd_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ document_ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Run more texts through the embeddings and add to the vectorstore."""
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ documents = self._generate_documents_to_add(texts, metadatas, document_ids)
+ uuids = await self._collection.aadd_documents(documents)
+
+ return uuids
+
+ def search(
+ self,
+ query: str,
+ search_type: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ k: int = 3,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query using specified search type."""
+ if search_type == "similarity":
+ return self.similarity_search(query, k=k, metadata=metadata, **kwargs)
+ elif search_type == "mmr":
+ return self.max_marginal_relevance_search(
+ query, k=k, metadata=metadata, **kwargs
+ )
+ else:
+ raise ValueError(
+ f"search_type of {search_type} not allowed. Expected "
+ "search_type to be 'similarity' or 'mmr'."
+ )
+
+ async def asearch(
+ self,
+ query: str,
+ search_type: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ k: int = 3,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query using specified search type."""
+ if search_type == "similarity":
+ return await self.asimilarity_search(
+ query, k=k, metadata=metadata, **kwargs
+ )
+ elif search_type == "mmr":
+ return await self.amax_marginal_relevance_search(
+ query, k=k, metadata=metadata, **kwargs
+ )
+ else:
+ raise ValueError(
+ f"search_type of {search_type} not allowed. Expected "
+ "search_type to be 'similarity' or 'mmr'."
+ )
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+
+ results = self._similarity_search_with_relevance_scores(
+ query, k=k, metadata=metadata, **kwargs
+ )
+ return [doc for doc, _ in results]
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Run similarity search with distance."""
+
+ return self._similarity_search_with_relevance_scores(
+ query, k=k, metadata=metadata, **kwargs
+ )
+
+ def _similarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """
+ Default similarity search with relevance scores. Modify if necessary
+ in subclass.
+ Return docs and relevance scores in the range [0, 1].
+
+ 0 is dissimilar, 1 is most similar.
+
+ Args:
+ query: input text
+ k: Number of Documents to return. Defaults to 4.
+ metadata: Optional, metadata filter
+ **kwargs: kwargs to be passed to similarity search. Should include:
+                score_threshold: Optional, a floating point value between 0 and 1
+                    used to filter the resulting set of retrieved docs
+
+ Returns:
+ List of Tuples of (doc, similarity_score)
+ """
+
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ if not self._collection.is_auto_embedded and self._embedding:
+ query_vector = self._embedding.embed_query(query)
+ results = self._collection.search(
+ embedding=query_vector, limit=k, metadata=metadata, **kwargs
+ )
+ else:
+ results = self._collection.search(
+ query, limit=k, metadata=metadata, **kwargs
+ )
+
+ return [
+ (
+ Document(
+ page_content=doc.content,
+ metadata=doc.metadata,
+ ),
+ doc.score or 0.0,
+ )
+ for doc in results
+ ]
+
+ async def asimilarity_search_with_relevance_scores(
+ self,
+ query: str,
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Tuple[Document, float]]:
+ """Return docs most similar to query."""
+
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ if not self._collection.is_auto_embedded and self._embedding:
+ query_vector = self._embedding.embed_query(query)
+ results = await self._collection.asearch(
+ embedding=query_vector, limit=k, metadata=metadata, **kwargs
+ )
+ else:
+ results = await self._collection.asearch(
+ query, limit=k, metadata=metadata, **kwargs
+ )
+
+ return [
+ (
+ Document(
+ page_content=doc.content,
+ metadata=doc.metadata,
+ ),
+ doc.score or 0.0,
+ )
+ for doc in results
+ ]
+
+ async def asimilarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to query."""
+
+ results = await self.asimilarity_search_with_relevance_scores(
+ query, k, metadata=metadata, **kwargs
+ )
+
+ return [doc for doc, _ in results]
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ metadata: Optional, metadata filter
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ results = self._collection.search(
+ embedding=embedding, limit=k, metadata=metadata, **kwargs
+ )
+
+ return [
+ Document(
+ page_content=doc.content,
+ metadata=doc.metadata,
+ )
+ for doc in results
+ ]
+
+ async def asimilarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs most similar to embedding vector."""
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+        results = await self._collection.asearch(
+            embedding=embedding, limit=k, metadata=metadata, **kwargs
+        )
+
+ return [
+ Document(
+ page_content=doc.content,
+ metadata=doc.metadata,
+ )
+ for doc in results
+ ]
+
+ def max_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ query: Text to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Zep determines this automatically and this parameter is
+ ignored.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ metadata: Optional, metadata to filter the resulting set of retrieved docs
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ if not self._collection.is_auto_embedded and self._embedding:
+ query_vector = self._embedding.embed_query(query)
+ results = self._collection.search(
+ embedding=query_vector,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+ else:
+ results, query_vector = self._collection.search_return_query_vector(
+ query,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+
+ return [Document(page_content=d.content, metadata=d.metadata) for d in results]
+
+ async def amax_marginal_relevance_search(
+ self,
+ query: str,
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance."""
+
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ if not self._collection.is_auto_embedded and self._embedding:
+ query_vector = self._embedding.embed_query(query)
+ results = await self._collection.asearch(
+ embedding=query_vector,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+ else:
+ results, query_vector = await self._collection.asearch_return_query_vector(
+ query,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+
+ return [Document(page_content=d.content, metadata=d.metadata) for d in results]
+
+ def max_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance.
+
+ Maximal marginal relevance optimizes for similarity to query AND diversity
+ among selected documents.
+
+ Args:
+ embedding: Embedding to look up documents similar to.
+ k: Number of Documents to return. Defaults to 4.
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+ Zep determines this automatically and this parameter is
+ ignored.
+ lambda_mult: Number between 0 and 1 that determines the degree
+ of diversity among the results with 0 corresponding
+ to maximum diversity and 1 to minimum diversity.
+ Defaults to 0.5.
+ metadata: Optional, metadata to filter the resulting set of retrieved docs
+ Returns:
+ List of Documents selected by maximal marginal relevance.
+ """
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ results = self._collection.search(
+ embedding=embedding,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+
+ return [Document(page_content=d.content, metadata=d.metadata) for d in results]
+
+ async def amax_marginal_relevance_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ fetch_k: int = 20,
+ lambda_mult: float = 0.5,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+ """Return docs selected using the maximal marginal relevance."""
+ if not self._collection:
+ raise ValueError(
+ "collection should be an instance of a Zep DocumentCollection"
+ )
+
+ results = await self._collection.asearch(
+ embedding=embedding,
+ limit=k,
+ metadata=metadata,
+ search_type="mmr",
+ mmr_lambda=lambda_mult,
+ **kwargs,
+ )
+
+ return [Document(page_content=d.content, metadata=d.metadata) for d in results]
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Optional[Embeddings] = None,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = "",
+ api_url: str = "",
+ api_key: Optional[str] = None,
+ config: Optional[CollectionConfig] = None,
+ **kwargs: Any,
+ ) -> ZepVectorStore:
+ """
+ Class method that returns a ZepVectorStore instance initialized from texts.
+
+ If the collection does not exist, it will be created.
+
+ Args:
+ texts (List[str]): The list of texts to add to the vectorstore.
+ embedding (Optional[Embeddings]): Optional embedding function to use to
+ embed the texts.
+ metadatas (Optional[List[Dict[str, Any]]]): Optional list of metadata
+ associated with the texts.
+ collection_name (str): The name of the collection in the Zep store.
+ api_url (str): The URL of the Zep API.
+ api_key (Optional[str]): The API key for the Zep API.
+ config (Optional[CollectionConfig]): The configuration for the collection.
+ **kwargs: Additional parameters specific to the vectorstore.
+
+ Returns:
+ ZepVectorStore: An instance of ZepVectorStore.
+ """
+ vecstore = cls(
+ collection_name,
+ api_url,
+ api_key=api_key,
+ config=config,
+ embedding=embedding,
+ )
+ vecstore.add_texts(texts, metadatas)
+ return vecstore
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+ """Delete by Zep vector UUIDs.
+
+ Parameters
+ ----------
+ ids : Optional[List[str]]
+ The UUIDs of the vectors to delete.
+
+ Raises
+ ------
+ ValueError
+ If no UUIDs are provided.
+ """
+
+ if ids is None or len(ids) == 0:
+ raise ValueError("No uuids provided to delete.")
+
+ if self._collection is None:
+ raise ValueError("No collection name provided.")
+
+ for u in ids:
+ self._collection.delete_document(u)
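A hedged usage sketch for the Zep store above (editorial, not part of the diff). It assumes a Zep server at a placeholder URL and a collection that the server auto-embeds, so no local `Embeddings` object is passed; the collection name, dimensions, and texts are illustrative.

```python
from langchain_community.vectorstores.zep import CollectionConfig, ZepVectorStore

# Placeholder server URL and collection settings.
config = CollectionConfig(
    name="langchain_docs",
    description="Illustrative collection",
    metadata={},
    embedding_dimensions=1536,  # must match the Zep server configuration
    is_auto_embedded=True,      # Zep embeds documents server-side
)

store = ZepVectorStore.from_texts(
    ["Zep collections can be embedded by the server itself."],
    collection_name="langchain_docs",
    api_url="http://localhost:8000",
    config=config,
)

# search() dispatches to similarity or MMR search based on search_type.
docs = store.search("server-side embedding", search_type="mmr", k=3)
```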
diff --git a/libs/community/langchain_community/vectorstores/zilliz.py b/libs/community/langchain_community/vectorstores/zilliz.py
new file mode 100644
index 00000000000..c66b9294e80
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/zilliz.py
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+
+from langchain_community.vectorstores.milvus import Milvus
+
+logger = logging.getLogger(__name__)
+
+
+class Zilliz(Milvus):
+ """`Zilliz` vector store.
+
+ You need to have `pymilvus` installed and a
+ running Zilliz database.
+
+ See the following documentation for how to run a Zilliz instance:
+ https://docs.zilliz.com/docs/create-cluster
+
+
+ IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
+
+ Args:
+ embedding_function (Embeddings): Function used to embed the text.
+ collection_name (str): Which Zilliz collection to use. Defaults to
+ "LangChainCollection".
+        connection_args (Optional[dict[str, any]]): The connection args used for
+            this class come in the form of a dict.
+ consistency_level (str): The consistency level to use for a collection.
+ Defaults to "Session".
+ index_params (Optional[dict]): Which index params to use. Defaults to
+ HNSW/AUTOINDEX depending on service.
+ search_params (Optional[dict]): Which search params to use. Defaults to
+ default of index.
+ drop_old (Optional[bool]): Whether to drop the current collection. Defaults
+ to False.
+
+    The connection args used for this class come in the form of a dict;
+    here are a few of the options:
+        address (str): The actual address of the Zilliz
+            instance. Example address: "localhost:19530"
+        uri (str): The uri of the Zilliz instance. Example uri:
+            "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
+        host (str): The host of the Zilliz instance. Defaults to "localhost";
+            PyMilvus will fill in the default host if only the port is provided.
+        port (str/int): The port of the Zilliz instance. Defaults to 19530;
+            PyMilvus will fill in the default port if only the host is provided.
+        user (str): Which user to connect to the Zilliz instance as. If user and
+            password are provided, the related header is added to every RPC call.
+        password (str): Required when user is provided. The password
+            corresponding to the user.
+        token (str): API key, for serverless clusters, which can be used as
+            a replacement for user and password.
+        secure (bool): Defaults to false. If set to true, TLS will be enabled.
+        client_key_path (str): If two-way TLS authentication is used, the
+            path to the client.key file.
+        client_pem_path (str): If two-way TLS authentication is used, the
+            path to the client.pem file.
+        ca_pem_path (str): If two-way TLS authentication is used, the path
+            to the ca.pem file.
+        server_pem_path (str): If one-way TLS authentication is used, the
+            path to the server.pem file.
+        server_name (str): If TLS is used, the common name to verify.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.vectorstores import Zilliz
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embedding = OpenAIEmbeddings()
+ # Connect to a Zilliz instance
+            vector_store = Zilliz(
+                embedding_function=embedding,
+                collection_name="LangChainCollection",
+                connection_args={
+                    "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
+                    "user": "temp",
+                    "password": "temp",
+                    "token": "temp",  # API key as replacement for user and password
+                    "secure": True,
+                },
+                drop_old=True,
+            )
+
+ Raises:
+ ValueError: If the pymilvus python package is not installed.
+ """
+
+ def _create_index(self) -> None:
+        """Create an index on the collection."""
+ from pymilvus import Collection, MilvusException
+
+ if isinstance(self.col, Collection) and self._get_index() is None:
+ try:
+ # If no index params, use a default AutoIndex based one
+ if self.index_params is None:
+ self.index_params = {
+ "metric_type": "L2",
+ "index_type": "AUTOINDEX",
+ "params": {},
+ }
+
+ try:
+ self.col.create_index(
+ self._vector_field,
+ index_params=self.index_params,
+ using=self.alias,
+ )
+
+ # If default did not work, most likely Milvus self-hosted
+ except MilvusException:
+ # Use HNSW based index
+ self.index_params = {
+ "metric_type": "L2",
+ "index_type": "HNSW",
+ "params": {"M": 8, "efConstruction": 64},
+ }
+ self.col.create_index(
+ self._vector_field,
+ index_params=self.index_params,
+ using=self.alias,
+ )
+ logger.debug(
+ "Successfully created an index on collection: %s",
+ self.collection_name,
+ )
+
+ except MilvusException as e:
+ logger.error(
+ "Failed to create an index on collection: %s", self.collection_name
+ )
+ raise e
+
+ @classmethod
+ def from_texts(
+ cls,
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = "LangChainCollection",
+ connection_args: Optional[Dict[str, Any]] = None,
+ consistency_level: str = "Session",
+ index_params: Optional[dict] = None,
+ search_params: Optional[dict] = None,
+ drop_old: bool = False,
+ **kwargs: Any,
+ ) -> Zilliz:
+        """Create a Zilliz collection, index it with HNSW, and insert data.
+
+ Args:
+ texts (List[str]): Text data.
+ embedding (Embeddings): Embedding function.
+ metadatas (Optional[List[dict]]): Metadata for each text if it exists.
+ Defaults to None.
+ collection_name (str, optional): Collection name to use. Defaults to
+ "LangChainCollection".
+ connection_args (dict[str, Any], optional): Connection args to use. Defaults
+ to DEFAULT_MILVUS_CONNECTION.
+ consistency_level (str, optional): Which consistency level to use. Defaults
+ to "Session".
+ index_params (Optional[dict], optional): Which index_params to use.
+ Defaults to None.
+ search_params (Optional[dict], optional): Which search params to use.
+ Defaults to None.
+ drop_old (Optional[bool], optional): Whether to drop the collection with
+ that name if it exists. Defaults to False.
+
+ Returns:
+ Zilliz: Zilliz Vector Store
+ """
+ vector_db = cls(
+ embedding_function=embedding,
+ collection_name=collection_name,
+ connection_args=connection_args or {},
+ consistency_level=consistency_level,
+ index_params=index_params,
+ search_params=search_params,
+ drop_old=drop_old,
+ **kwargs,
+ )
+ vector_db.add_texts(texts=texts, metadatas=metadatas)
+ return vector_db
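Finally, a hedged sketch of `Zilliz.from_texts` for a hypothetical Zilliz Cloud serverless cluster, where the `token` API key replaces `user`/`password`; the URI and key below are placeholders (editorial, not part of the diff).

```python
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Zilliz

# Placeholder connection args for a serverless cluster.
connection_args = {
    "uri": "https://in03-example.api.gcp-us-west1.zillizcloud.com",
    "token": "YOUR_ZILLIZ_API_KEY",
    "secure": True,
}

store = Zilliz.from_texts(
    ["Zilliz is the managed service built on Milvus."],
    OpenAIEmbeddings(),
    connection_args=connection_args,
    collection_name="LangChainCollection",
    drop_old=True,  # replace any existing collection with this name
)

docs = store.similarity_search("managed Milvus", k=1)
```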
diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock
new file mode 100644
index 00000000000..e443c892f44
--- /dev/null
+++ b/libs/community/poetry.lock
@@ -0,0 +1,8488 @@
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+
+[[package]]
+name = "aiodns"
+version = "3.1.1"
+description = "Simple DNS resolver for asyncio"
+optional = true
+python-versions = "*"
+files = [
+ {file = "aiodns-3.1.1-py3-none-any.whl", hash = "sha256:a387b63da4ced6aad35b1dda2d09620ad608a1c7c0fb71efa07ebb4cd511928d"},
+ {file = "aiodns-3.1.1.tar.gz", hash = "sha256:1073eac48185f7a4150cad7f96a5192d6911f12b4fb894de80a088508c9b3a99"},
+]
+
+[package.dependencies]
+pycares = ">=4.0.0"
+
+[[package]]
+name = "aiohttp"
+version = "3.9.1"
+description = "Async http client/server framework (asyncio)"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1f80197f8b0b846a8d5cf7b7ec6084493950d0882cc5537fb7b96a69e3c8590"},
+ {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72444d17777865734aa1a4d167794c34b63e5883abb90356a0364a28904e6c0"},
+ {file = "aiohttp-3.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b05d5cbe9dafcdc733262c3a99ccf63d2f7ce02543620d2bd8db4d4f7a22f83"},
+ {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c4fa235d534b3547184831c624c0b7c1e262cd1de847d95085ec94c16fddcd5"},
+ {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:289ba9ae8e88d0ba16062ecf02dd730b34186ea3b1e7489046fc338bdc3361c4"},
+ {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bff7e2811814fa2271be95ab6e84c9436d027a0e59665de60edf44e529a42c1f"},
+ {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81b77f868814346662c96ab36b875d7814ebf82340d3284a31681085c051320f"},
+ {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b9c7426923bb7bd66d409da46c41e3fb40f5caf679da624439b9eba92043fa6"},
+ {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8d44e7bf06b0c0a70a20f9100af9fcfd7f6d9d3913e37754c12d424179b4e48f"},
+ {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22698f01ff5653fe66d16ffb7658f582a0ac084d7da1323e39fd9eab326a1f26"},
+ {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ca7ca5abfbfe8d39e653870fbe8d7710be7a857f8a8386fc9de1aae2e02ce7e4"},
+ {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:8d7f98fde213f74561be1d6d3fa353656197f75d4edfbb3d94c9eb9b0fc47f5d"},
+ {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5216b6082c624b55cfe79af5d538e499cd5f5b976820eac31951fb4325974501"},
+ {file = "aiohttp-3.9.1-cp310-cp310-win32.whl", hash = "sha256:0e7ba7ff228c0d9a2cd66194e90f2bca6e0abca810b786901a569c0de082f489"},
+ {file = "aiohttp-3.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:c7e939f1ae428a86e4abbb9a7c4732bf4706048818dfd979e5e2839ce0159f23"},
+ {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:df9cf74b9bc03d586fc53ba470828d7b77ce51b0582d1d0b5b2fb673c0baa32d"},
+ {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecca113f19d5e74048c001934045a2b9368d77b0b17691d905af18bd1c21275e"},
+ {file = "aiohttp-3.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cef8710fb849d97c533f259103f09bac167a008d7131d7b2b0e3a33269185c0"},
+ {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea94403a21eb94c93386d559bce297381609153e418a3ffc7d6bf772f59cc35"},
+ {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91c742ca59045dce7ba76cab6e223e41d2c70d79e82c284a96411f8645e2afff"},
+ {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c93b7c2e52061f0925c3382d5cb8980e40f91c989563d3d32ca280069fd6a87"},
+ {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee2527134f95e106cc1653e9ac78846f3a2ec1004cf20ef4e02038035a74544d"},
+ {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11ff168d752cb41e8492817e10fb4f85828f6a0142b9726a30c27c35a1835f01"},
+ {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b8c3a67eb87394386847d188996920f33b01b32155f0a94f36ca0e0c635bf3e3"},
+ {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c7b5d5d64e2a14e35a9240b33b89389e0035e6de8dbb7ffa50d10d8b65c57449"},
+ {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:69985d50a2b6f709412d944ffb2e97d0be154ea90600b7a921f95a87d6f108a2"},
+ {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:c9110c06eaaac7e1f5562caf481f18ccf8f6fdf4c3323feab28a93d34cc646bd"},
+ {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737e69d193dac7296365a6dcb73bbbf53bb760ab25a3727716bbd42022e8d7a"},
+ {file = "aiohttp-3.9.1-cp311-cp311-win32.whl", hash = "sha256:4ee8caa925aebc1e64e98432d78ea8de67b2272252b0a931d2ac3bd876ad5544"},
+ {file = "aiohttp-3.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:a34086c5cc285be878622e0a6ab897a986a6e8bf5b67ecb377015f06ed316587"},
+ {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065"},
+ {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:500f1c59906cd142d452074f3811614be04819a38ae2b3239a48b82649c08821"},
+ {file = "aiohttp-3.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0b0a6a36ed7e164c6df1e18ee47afbd1990ce47cb428739d6c99aaabfaf1b3af"},
+ {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69da0f3ed3496808e8cbc5123a866c41c12c15baaaead96d256477edf168eb57"},
+ {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:176df045597e674fa950bf5ae536be85699e04cea68fa3a616cf75e413737eb5"},
+ {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b796b44111f0cab6bbf66214186e44734b5baab949cb5fb56154142a92989aeb"},
+ {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f27fdaadce22f2ef950fc10dcdf8048407c3b42b73779e48a4e76b3c35bca26c"},
+ {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcb6532b9814ea7c5a6a3299747c49de30e84472fa72821b07f5a9818bce0f66"},
+ {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:54631fb69a6e44b2ba522f7c22a6fb2667a02fd97d636048478db2fd8c4e98fe"},
+ {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4b4c452d0190c5a820d3f5c0f3cd8a28ace48c54053e24da9d6041bf81113183"},
+ {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:cae4c0c2ca800c793cae07ef3d40794625471040a87e1ba392039639ad61ab5b"},
+ {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:565760d6812b8d78d416c3c7cfdf5362fbe0d0d25b82fed75d0d29e18d7fc30f"},
+ {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54311eb54f3a0c45efb9ed0d0a8f43d1bc6060d773f6973efd90037a51cd0a3f"},
+ {file = "aiohttp-3.9.1-cp312-cp312-win32.whl", hash = "sha256:85c3e3c9cb1d480e0b9a64c658cd66b3cfb8e721636ab8b0e746e2d79a7a9eed"},
+ {file = "aiohttp-3.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:11cb254e397a82efb1805d12561e80124928e04e9c4483587ce7390b3866d213"},
+ {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a22a34bc594d9d24621091d1b91511001a7eea91d6652ea495ce06e27381f70"},
+ {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:598db66eaf2e04aa0c8900a63b0101fdc5e6b8a7ddd805c56d86efb54eb66672"},
+ {file = "aiohttp-3.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c9376e2b09895c8ca8b95362283365eb5c03bdc8428ade80a864160605715f1"},
+ {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41473de252e1797c2d2293804e389a6d6986ef37cbb4a25208de537ae32141dd"},
+ {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c5857612c9813796960c00767645cb5da815af16dafb32d70c72a8390bbf690"},
+ {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca"},
+ {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:219a16763dc0294842188ac8a12262b5671817042b35d45e44fd0a697d8c8361"},
+ {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f694dc8a6a3112059258a725a4ebe9acac5fe62f11c77ac4dcf896edfa78ca28"},
+ {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bcc0ea8d5b74a41b621ad4a13d96c36079c81628ccc0b30cfb1603e3dfa3a014"},
+ {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:90ec72d231169b4b8d6085be13023ece8fa9b1bb495e4398d847e25218e0f431"},
+ {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:cf2a0ac0615842b849f40c4d7f304986a242f1e68286dbf3bd7a835e4f83acfd"},
+ {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:0e49b08eafa4f5707ecfb321ab9592717a319e37938e301d462f79b4e860c32a"},
+ {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c59e0076ea31c08553e868cec02d22191c086f00b44610f8ab7363a11a5d9d8"},
+ {file = "aiohttp-3.9.1-cp38-cp38-win32.whl", hash = "sha256:4831df72b053b1eed31eb00a2e1aff6896fb4485301d4ccb208cac264b648db4"},
+ {file = "aiohttp-3.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:3135713c5562731ee18f58d3ad1bf41e1d8883eb68b363f2ffde5b2ea4b84cc7"},
+ {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cfeadf42840c1e870dc2042a232a8748e75a36b52d78968cda6736de55582766"},
+ {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:70907533db712f7aa791effb38efa96f044ce3d4e850e2d7691abd759f4f0ae0"},
+ {file = "aiohttp-3.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cdefe289681507187e375a5064c7599f52c40343a8701761c802c1853a504558"},
+ {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7481f581251bb5558ba9f635db70908819caa221fc79ee52a7f58392778c636"},
+ {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:49f0c1b3c2842556e5de35f122fc0f0b721334ceb6e78c3719693364d4af8499"},
+ {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d406b01a9f5a7e232d1b0d161b40c05275ffbcbd772dc18c1d5a570961a1ca4"},
+ {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d8e4450e7fe24d86e86b23cc209e0023177b6d59502e33807b732d2deb6975f"},
+ {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c0266cd6f005e99f3f51e583012de2778e65af6b73860038b968a0a8888487a"},
+ {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab221850108a4a063c5b8a70f00dd7a1975e5a1713f87f4ab26a46e5feac5a0e"},
+ {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c88a15f272a0ad3d7773cf3a37cc7b7d077cbfc8e331675cf1346e849d97a4e5"},
+ {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:237533179d9747080bcaad4d02083ce295c0d2eab3e9e8ce103411a4312991a0"},
+ {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:02ab6006ec3c3463b528374c4cdce86434e7b89ad355e7bf29e2f16b46c7dd6f"},
+ {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04fa38875e53eb7e354ece1607b1d2fdee2d175ea4e4d745f6ec9f751fe20c7c"},
+ {file = "aiohttp-3.9.1-cp39-cp39-win32.whl", hash = "sha256:82eefaf1a996060602f3cc1112d93ba8b201dbf5d8fd9611227de2003dddb3b7"},
+ {file = "aiohttp-3.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:9b05d33ff8e6b269e30a7957bd3244ffbce2a7a35a81b81c382629b80af1a8bf"},
+ {file = "aiohttp-3.9.1.tar.gz", hash = "sha256:8fc49a87ac269d4529da45871e2ffb6874e87779c3d0e2ccd813c0899221239d"},
+]
+
+[package.dependencies]
+aiosignal = ">=1.1.2"
+async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""}
+attrs = ">=17.3.0"
+frozenlist = ">=1.1.1"
+multidict = ">=4.5,<7.0"
+yarl = ">=1.0,<2.0"
+
+[package.extras]
+speedups = ["Brotli", "aiodns", "brotlicffi"]
+
+[[package]]
+name = "aiohttp-retry"
+version = "2.8.3"
+description = "Simple retry client for aiohttp"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "aiohttp_retry-2.8.3-py3-none-any.whl", hash = "sha256:3aeeead8f6afe48272db93ced9440cf4eda8b6fd7ee2abb25357b7eb28525b45"},
+ {file = "aiohttp_retry-2.8.3.tar.gz", hash = "sha256:9a8e637e31682ad36e1ff9f8bcba912fcfc7d7041722bc901a4b948da4d71ea9"},
+]
+
+[package.dependencies]
+aiohttp = "*"
+
+[[package]]
+name = "aiosignal"
+version = "1.3.1"
+description = "aiosignal: a list of registered asynchronous callbacks"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
+ {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
+]
+
+[package.dependencies]
+frozenlist = ">=1.1.0"
+
+[[package]]
+name = "aiosqlite"
+version = "0.19.0"
+description = "asyncio bridge to the standard sqlite3 module"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"},
+ {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"},
+]
+
+[package.extras]
+dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"]
+docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"]
+
+[[package]]
+name = "aleph-alpha-client"
+version = "2.17.0"
+description = "python client to interact with Aleph Alpha api endpoints"
+optional = true
+python-versions = "*"
+files = [
+ {file = "aleph-alpha-client-2.17.0.tar.gz", hash = "sha256:c2d664c7b829f4932306153bec45e11c08e03252f1dbfd9f48584c402d7050a3"},
+ {file = "aleph_alpha_client-2.17.0-py3-none-any.whl", hash = "sha256:9106a36a5e08dba6aea2b0b2a0de6ff0c3bb77926edc98226debae121b0925e2"},
+]
+
+[package.dependencies]
+aiodns = ">=3.0.0"
+aiohttp = ">=3.8.3"
+aiohttp-retry = ">=2.8.3"
+Pillow = ">=9.2.0"
+requests = ">=2.28"
+tokenizers = ">=0.13.2"
+typing-extensions = ">=4.5.0"
+urllib3 = ">=1.26"
+
+[package.extras]
+dev = ["black", "ipykernel", "mypy", "nbconvert", "pytest", "pytest-aiohttp", "pytest-cov", "pytest-dotenv", "pytest-httpserver", "types-Pillow", "types-requests"]
+docs = ["sphinx", "sphinx-rtd-theme"]
+test = ["pytest", "pytest-aiohttp", "pytest-cov", "pytest-dotenv", "pytest-httpserver"]
+types = ["mypy", "types-Pillow", "types-requests"]
+
+[[package]]
+name = "altair"
+version = "4.2.2"
+description = "Altair: A declarative statistical visualization library for Python."
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "altair-4.2.2-py3-none-any.whl", hash = "sha256:8b45ebeaf8557f2d760c5c77b79f02ae12aee7c46c27c06014febab6f849bc87"},
+ {file = "altair-4.2.2.tar.gz", hash = "sha256:39399a267c49b30d102c10411e67ab26374156a84b1aeb9fcd15140429ba49c5"},
+]
+
+[package.dependencies]
+entrypoints = "*"
+jinja2 = "*"
+jsonschema = ">=3.0"
+numpy = "*"
+pandas = ">=0.18"
+toolz = "*"
+
+[package.extras]
+dev = ["black", "docutils", "flake8", "ipython", "m2r", "mistune (<2.0.0)", "pytest", "recommonmark", "sphinx", "vega-datasets"]
+
+[[package]]
+name = "anthropic"
+version = "0.3.11"
+description = "Client library for the anthropic API"
+optional = false
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "anthropic-0.3.11-py3-none-any.whl", hash = "sha256:5c81105cd9ee7388bff3fdb739aaddedc83bbae9b95d51c2d50c13b1ad106138"},
+ {file = "anthropic-0.3.11.tar.gz", hash = "sha256:2e0fa5351c9b368cbed0bbd7217deaa9409b82b56afaf244e2196e99eb4fe20e"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<4"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+tokenizers = ">=0.13.0"
+typing-extensions = ">=4.5,<5"
+
+[[package]]
+name = "anyio"
+version = "3.7.1"
+description = "High level compatibility layer for multiple asynchronous event loop implementations"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"},
+ {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"},
+]
+
+[package.dependencies]
+exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
+idna = ">=2.8"
+sniffio = ">=1.1"
+
+[package.extras]
+doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"]
+test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
+trio = ["trio (<0.22)"]
+
+[[package]]
+name = "appnope"
+version = "0.1.3"
+description = "Disable App Nap on macOS >= 10.9"
+optional = false
+python-versions = "*"
+files = [
+ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"},
+ {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"},
+]
+
+[[package]]
+name = "argon2-cffi"
+version = "23.1.0"
+description = "Argon2 for Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"},
+ {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"},
+]
+
+[package.dependencies]
+argon2-cffi-bindings = "*"
+
+[package.extras]
+dev = ["argon2-cffi[tests,typing]", "tox (>4)"]
+docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"]
+tests = ["hypothesis", "pytest"]
+typing = ["mypy"]
+
+[[package]]
+name = "argon2-cffi-bindings"
+version = "21.2.0"
+description = "Low-level CFFI bindings for Argon2"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"},
+ {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"},
+]
+
+[package.dependencies]
+cffi = ">=1.0.1"
+
+[package.extras]
+dev = ["cogapp", "pre-commit", "pytest", "wheel"]
+tests = ["pytest"]
+
+[[package]]
+name = "arrow"
+version = "1.3.0"
+description = "Better dates & times for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"},
+ {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"},
+]
+
+[package.dependencies]
+python-dateutil = ">=2.7.0"
+types-python-dateutil = ">=2.8.10"
+
+[package.extras]
+doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"]
+test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"]
+
+[[package]]
+name = "arxiv"
+version = "1.4.8"
+description = "Python wrapper for the arXiv API: http://arxiv.org/help/api/"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "arxiv-1.4.8-py3-none-any.whl", hash = "sha256:c3dbef0fb7ed85c9b4c2157b40a62f5a04ce0d2f63c3ff7caa7798abf6166378"},
+ {file = "arxiv-1.4.8.tar.gz", hash = "sha256:2a818ea749eaa62a6e24fc31d53b769b4d33ff55cfc5dda7c7b7d309a3b29373"},
+]
+
+[package.dependencies]
+feedparser = "*"
+
+[[package]]
+name = "assemblyai"
+version = "0.17.0"
+description = "AssemblyAI Python SDK"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "assemblyai-0.17.0-py3-none-any.whl", hash = "sha256:3bad8cc7545b5b831f243f1b2f01bc4cc0e8aad78babf44c8008f2293c540e36"},
+ {file = "assemblyai-0.17.0.tar.gz", hash = "sha256:6d5bbfbbaa626ed021c3d3dec0ca52b3ebf6e6ef277ac76a7a6aed52182d531e"},
+]
+
+[package.dependencies]
+httpx = ">=0.19.0"
+pydantic = ">=1.7.0,<1.10.7 || >1.10.7"
+typing-extensions = ">=3.7"
+websockets = ">=11.0"
+
+[package.extras]
+extras = ["pyaudio (>=0.2.13)"]
+
+[[package]]
+name = "asttokens"
+version = "2.4.1"
+description = "Annotate AST trees with source code positions"
+optional = false
+python-versions = "*"
+files = [
+ {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"},
+ {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"},
+]
+
+[package.dependencies]
+six = ">=1.12.0"
+
+[package.extras]
+astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"]
+test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"]
+
+[[package]]
+name = "async-lru"
+version = "2.0.4"
+description = "Simple LRU cache for asyncio"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"},
+ {file = "async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"},
+]
+
+[package.dependencies]
+typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
+
+[[package]]
+name = "async-timeout"
+version = "4.0.3"
+description = "Timeout context manager for asyncio programs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
+ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
+]
+
+[[package]]
+name = "asyncpg"
+version = "0.29.0"
+description = "An asyncio PostgreSQL driver"
+optional = true
+python-versions = ">=3.8.0"
+files = [
+ {file = "asyncpg-0.29.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169"},
+ {file = "asyncpg-0.29.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385"},
+ {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22"},
+ {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610"},
+ {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397"},
+ {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb"},
+ {file = "asyncpg-0.29.0-cp310-cp310-win32.whl", hash = "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449"},
+ {file = "asyncpg-0.29.0-cp310-cp310-win_amd64.whl", hash = "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772"},
+ {file = "asyncpg-0.29.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4"},
+ {file = "asyncpg-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac"},
+ {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870"},
+ {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f"},
+ {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23"},
+ {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b"},
+ {file = "asyncpg-0.29.0-cp311-cp311-win32.whl", hash = "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675"},
+ {file = "asyncpg-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3"},
+ {file = "asyncpg-0.29.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178"},
+ {file = "asyncpg-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb"},
+ {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364"},
+ {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106"},
+ {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59"},
+ {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175"},
+ {file = "asyncpg-0.29.0-cp312-cp312-win32.whl", hash = "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02"},
+ {file = "asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe"},
+ {file = "asyncpg-0.29.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9"},
+ {file = "asyncpg-0.29.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc"},
+ {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548"},
+ {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7"},
+ {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775"},
+ {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9"},
+ {file = "asyncpg-0.29.0-cp38-cp38-win32.whl", hash = "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408"},
+ {file = "asyncpg-0.29.0-cp38-cp38-win_amd64.whl", hash = "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3"},
+ {file = "asyncpg-0.29.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da"},
+ {file = "asyncpg-0.29.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3"},
+ {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090"},
+ {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83"},
+ {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810"},
+ {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c"},
+ {file = "asyncpg-0.29.0-cp39-cp39-win32.whl", hash = "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2"},
+ {file = "asyncpg-0.29.0-cp39-cp39-win_amd64.whl", hash = "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8"},
+ {file = "asyncpg-0.29.0.tar.gz", hash = "sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.12.0\""}
+
+[package.extras]
+docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
+test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"]
+
+[[package]]
+name = "atlassian-python-api"
+version = "3.41.4"
+description = "Python Atlassian REST API Wrapper"
+optional = true
+python-versions = "*"
+files = [
+ {file = "atlassian-python-api-3.41.4.tar.gz", hash = "sha256:10e51fc2bd2d13423d6e34b1534600366758a711fd4fd8b9cdc2de658bf327fa"},
+ {file = "atlassian_python_api-3.41.4-py3-none-any.whl", hash = "sha256:4bfbf2b420addc77c636b67c3037d7056185662762a78f78d96ec1e267d0b435"},
+]
+
+[package.dependencies]
+deprecated = "*"
+oauthlib = "*"
+requests = "*"
+requests-oauthlib = "*"
+six = "*"
+
+[package.extras]
+kerberos = ["requests-kerberos"]
+
+[[package]]
+name = "attrs"
+version = "23.1.0"
+description = "Classes Without Boilerplate"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"},
+ {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"},
+]
+
+[package.extras]
+cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
+dev = ["attrs[docs,tests]", "pre-commit"]
+docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
+tests = ["attrs[tests-no-zope]", "zope-interface"]
+tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+
+[[package]]
+name = "babel"
+version = "2.13.1"
+description = "Internationalization utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"},
+ {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"},
+]
+
+[package.dependencies]
+pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""}
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+
+[package.extras]
+dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
+
+[[package]]
+name = "backcall"
+version = "0.2.0"
+description = "Specifications for callback functions passed in to an API"
+optional = false
+python-versions = "*"
+files = [
+ {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
+ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
+]
+
+[[package]]
+name = "backoff"
+version = "2.2.1"
+description = "Function decoration for backoff and retry"
+optional = true
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"},
+ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
+]
+
+[[package]]
+name = "backports-zoneinfo"
+version = "0.2.1"
+description = "Backport of the standard library zoneinfo module"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"},
+ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"},
+ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"},
+ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"},
+ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"},
+ {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"},
+ {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"},
+ {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"},
+ {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"},
+ {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"},
+ {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"},
+ {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"},
+ {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"},
+ {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"},
+ {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"},
+ {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"},
+]
+
+[package.extras]
+tzdata = ["tzdata"]
+
+[[package]]
+name = "beautifulsoup4"
+version = "4.12.2"
+description = "Screen-scraping library"
+optional = false
+python-versions = ">=3.6.0"
+files = [
+ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"},
+ {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"},
+]
+
+[package.dependencies]
+soupsieve = ">1.2"
+
+[package.extras]
+html5lib = ["html5lib"]
+lxml = ["lxml"]
+
+[[package]]
+name = "bibtexparser"
+version = "1.4.1"
+description = "Bibtex parser for python 3"
+optional = true
+python-versions = "*"
+files = [
+ {file = "bibtexparser-1.4.1.tar.gz", hash = "sha256:e00e29e24676c4808e0b4333b37bb55cca9cbb7871a56f63058509281588d789"},
+]
+
+[package.dependencies]
+pyparsing = ">=2.0.3"
+
+[[package]]
+name = "bleach"
+version = "6.1.0"
+description = "An easy safelist-based HTML-sanitizing tool."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"},
+ {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"},
+]
+
+[package.dependencies]
+six = ">=1.9.0"
+webencodings = "*"
+
+[package.extras]
+css = ["tinycss2 (>=1.1.0,<1.3)"]
+
+[[package]]
+name = "blinker"
+version = "1.7.0"
+description = "Fast, simple object-to-object and broadcast signaling"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "blinker-1.7.0-py3-none-any.whl", hash = "sha256:c3f865d4d54db7abc53758a01601cf343fe55b84c1de4e3fa910e420b438d5b9"},
+ {file = "blinker-1.7.0.tar.gz", hash = "sha256:e6820ff6fa4e4d1d8e2747c2283749c3f547e4fee112b98555cdcdae32996182"},
+]
+
+[[package]]
+name = "boto3"
+version = "1.33.11"
+description = "The AWS SDK for Python"
+optional = false
+python-versions = ">= 3.7"
+files = [
+ {file = "boto3-1.33.11-py3-none-any.whl", hash = "sha256:8d54fa3a9290020f9a7f488f9cbe821029de0af05a677751b12973a5f726a5e2"},
+ {file = "boto3-1.33.11.tar.gz", hash = "sha256:620f1eb3e18e780be58383b4a4e10db003d2314131190514153996032c8d932d"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.11,<1.34.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.8.2,<0.9.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.33.11"
+description = "Low-level, data-driven core of boto 3."
+optional = false
+python-versions = ">= 3.7"
+files = [
+ {file = "botocore-1.33.11-py3-none-any.whl", hash = "sha256:b46227eb3fa9cfdc8f5a83920ef347e67adea8095830ed265a3373b13b54421f"},
+ {file = "botocore-1.33.11.tar.gz", hash = "sha256:b14b328f902d120de0a09eaa657a9a701c0ceeb711197c2f01ef0523f855086c"},
+]
+
+[package.dependencies]
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = [
+ {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
+ {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
+]
+
+[package.extras]
+crt = ["awscrt (==0.19.17)"]
+
+[[package]]
+name = "cachetools"
+version = "5.3.2"
+description = "Extensible memoizing collections and decorators"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"},
+ {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"},
+]
+
+[[package]]
+name = "cassandra-driver"
+version = "3.28.0"
+description = "DataStax Driver for Apache Cassandra"
+optional = false
+python-versions = "*"
+files = [
+ {file = "cassandra-driver-3.28.0.tar.gz", hash = "sha256:64ff130d19f994b80997c14343a8306be52a0e7ab92520a534eed944c88d70df"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8cceb2cc658b3ebf28873f84aab4f28bbd5df23a6528a5b38ecf89a45232509"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:35aef74e2a593a969b77a3fcf02d27e9b82a078d9aa66caa3bd2d2583c46a82c"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:48f20e0d21b6c7406dfd8a4d9e07fddc3c7c3d6ad7d5b5d480bf82aac7068739"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3820a421fb7e4cf215718dc35522869c5f933d4fd4c50fd43307d3ce5d9dd138"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dd9511fe5b85010e92199f6589e0733ab14ed3d2279dcc6ae504c0cef11d652"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-win32.whl", hash = "sha256:887f7e3df9b34b41de6dfdd5f2ef8804c2d9782bbc39202eda9d3b67a3c8fe37"},
+ {file = "cassandra_driver-3.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:28c636239b15944103df18a12ef95e6401ceadd7b9aca2d59f4beccf9ca21e2d"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9232434459303b0e1a26fa65006fd8438475037aef4e6204a32dfaeb10e7f739"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:634a553a5309a9faa08c3256fe0237ff0308152210211f3b8eab0664335560e0"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4a101114a7d93505ee79272edc82dba0cfc706172ad7948a6e4fb3dc1eb8b59c"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d844ba0089111858fad3c53897b0fea7c91cedd8bd205eeb82fe22fd60e748"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3bf6bacb60dc8d1b8ba5ddd7d35772e3b98da951aed6bb148827aa9c38cd009"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-win32.whl", hash = "sha256:212eb39ca99ab5960eb5c31ce279b61e075df02ac7a6209415982a3f8cfe1126"},
+ {file = "cassandra_driver-3.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:777f60ed821ec43d5b3f7a65eaf02decbd9cbc11e32f2099bfe9d7a6bfe33da9"},
+ {file = "cassandra_driver-3.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b867c49c3c9efa21923845456cfb3e81ad13a33e40eb20279f58b3642d54614f"},
+ {file = "cassandra_driver-3.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1dc54edf3b664dc8e45a9c8fed163dacbad8bc92c788c84a371ccb700e18638"},
+ {file = "cassandra_driver-3.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e383aff200b7194d0d5625bf162bbc8471d05db7163c546341e5f27b36b53134"},
+ {file = "cassandra_driver-3.28.0-cp37-cp37m-win32.whl", hash = "sha256:a5e8b066f816868b344c108f34acc04b53c44caed2cdbcfe08ebdcbc1fd35046"},
+ {file = "cassandra_driver-3.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ae8c8e9a46e1b0174ace1e836d4ea97292aa6de509db0def0f816322468fb430"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d5e8cf7db955b113f51274f166be9db0f0a06620c894abc41159828f0aeda259"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:26cbdb0d04f749b78bf7de17fd6a713b90430d1c70d8aa442845d51db823b9eb"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fe302940780932d83414ad5282c8a6bd72b248f3b1fceff995f28c77a6ebc925"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3694c1e19d310668f5a60c16511fb12c3ad4c387d089a8080b74239a916620fb"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5690b7b121e82c4365d298bd49dc574ecd8eed3ec0bafdf43fce708f2f992b"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-win32.whl", hash = "sha256:d09c8b0b392064054656050448dece04e4fa890af3c677a2f2034af14983ceb5"},
+ {file = "cassandra_driver-3.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:e2342420bae4f80587e2ddebb38ade448c9ab1d210787a8030c1c04f54ef4a84"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c8d934cb7eac6586823a7eb69d40019154fd8e7d640bfaed49ac7edc373578df"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8b51805d57ff6ed73a95c83c25d0479391da28c765c2bf019ee1370d8ca64cd0"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5f05495ccabe5be046bb9f1c2cc3e3ff696a94fd4f2f2b1004c951e56b1ea38d"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59050666423c4ffdda9626676c18cce83a71c8331dd3d99f6b9306e0941348cf"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a665841c15f2fade6b00a8404d3424fed8757971b75e791b69bfedacc4753f7c"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-win32.whl", hash = "sha256:46433de332b8ef59ad44140f287b584303b90111cf6f355ec8c990830135dd21"},
+ {file = "cassandra_driver-3.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:5e6213f10d58b05a6120bcff4f479d89c152d3f4ba43b3bda3283ee67c3abe23"},
+]
+
+[package.dependencies]
+geomet = ">=0.1,<0.3"
+six = ">=1.9"
+
+[package.extras]
+cle = ["cryptography (>=35.0)"]
+graph = ["gremlinpython (==3.4.6)"]
+
+[[package]]
+name = "cassio"
+version = "0.1.3"
+description = "A framework-agnostic Python library to seamlessly integrate Apache Cassandra(R) with ML/LLM/genAI workloads."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "cassio-0.1.3-py3-none-any.whl", hash = "sha256:2ced5b7e5c6e58b7b4647388d8629c77fdb9a8d745f8763e7e87d1da924ff0f1"},
+ {file = "cassio-0.1.3.tar.gz", hash = "sha256:dbea30c1aa3014205fd48e036d2bcc8ba949e8b3f3351ca9cef698665cb40a18"},
+]
+
+[package.dependencies]
+cassandra-driver = ">=3.28.0"
+numpy = ">=1.0"
+requests = ">=2"
+
+[[package]]
+name = "certifi"
+version = "2023.11.17"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"},
+ {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"},
+]
+
+[[package]]
+name = "cffi"
+version = "1.16.0"
+description = "Foreign Function Interface for Python calling C code."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"},
+ {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"},
+ {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"},
+ {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"},
+ {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"},
+ {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"},
+ {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"},
+ {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"},
+ {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"},
+ {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"},
+ {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"},
+ {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"},
+ {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"},
+ {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"},
+ {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"},
+ {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"},
+ {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"},
+ {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"},
+ {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"},
+ {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"},
+ {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"},
+ {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"},
+ {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"},
+ {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"},
+ {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"},
+ {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"},
+ {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"},
+ {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"},
+ {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"},
+ {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"},
+ {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"},
+ {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"},
+ {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"},
+ {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"},
+ {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"},
+ {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"},
+ {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"},
+ {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"},
+ {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"},
+ {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"},
+ {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"},
+ {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"},
+ {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"},
+ {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"},
+ {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"},
+ {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"},
+ {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"},
+ {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"},
+ {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"},
+ {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"},
+ {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"},
+ {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"},
+]
+
+[package.dependencies]
+pycparser = "*"
+
+[[package]]
+name = "chardet"
+version = "5.2.0"
+description = "Universal encoding detector for Python 3"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970"},
+ {file = "chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.3.2"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
+ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
+]
+
+[[package]]
+name = "click"
+version = "8.1.7"
+description = "Composable command line interface toolkit"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
+ {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[[package]]
+name = "click-plugins"
+version = "1.1.1"
+description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
+optional = true
+python-versions = "*"
+files = [
+ {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
+ {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
+]
+
+[package.dependencies]
+click = ">=4.0"
+
+[package.extras]
+dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
+
+[[package]]
+name = "cligj"
+version = "0.7.2"
+description = "Click params for commmand line interfaces to GeoJSON"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
+files = [
+ {file = "cligj-0.7.2-py3-none-any.whl", hash = "sha256:c1ca117dbce1fe20a5809dc96f01e1c2840f6dcc939b3ddbb1111bf330ba82df"},
+ {file = "cligj-0.7.2.tar.gz", hash = "sha256:a4bc13d623356b373c2c27c53dbd9c68cae5d526270bfa71f6c6fa69669c6b27"},
+]
+
+[package.dependencies]
+click = ">=4.0"
+
+[package.extras]
+test = ["pytest-cov"]
+
+[[package]]
+name = "cloudpickle"
+version = "2.2.1"
+description = "Extended pickling support for Python objects"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
+ {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
+]
+
+[[package]]
+name = "codespell"
+version = "2.2.6"
+description = "Codespell"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "codespell-2.2.6-py3-none-any.whl", hash = "sha256:9ee9a3e5df0990604013ac2a9f22fa8e57669c827124a2e961fe8a1da4cacc07"},
+ {file = "codespell-2.2.6.tar.gz", hash = "sha256:a8c65d8eb3faa03deabab6b3bbe798bea72e1799c7e9e955d57eca4096abcff9"},
+]
+
+[package.extras]
+dev = ["Pygments", "build", "chardet", "pre-commit", "pytest", "pytest-cov", "pytest-dependency", "ruff", "tomli", "twine"]
+hard-encoding-detection = ["chardet"]
+toml = ["tomli"]
+types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency"]
+
+[[package]]
+name = "cohere"
+version = "4.37"
+description = "Python SDK for the Cohere API"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "cohere-4.37-py3-none-any.whl", hash = "sha256:f3fad3a0f8d86761d4de851dfd2233a1e5c7634a024102212d850bde9c9bb031"},
+ {file = "cohere-4.37.tar.gz", hash = "sha256:788021d9d992c6c31d1985d95cccb277c7265882c4acd7a49b3e47da77b4bec8"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.0,<4.0"
+backoff = ">=2.0,<3.0"
+fastavro = ">=1.8,<2.0"
+importlib_metadata = ">=6.0,<7.0"
+requests = ">=2.25.0,<3.0.0"
+urllib3 = ">=1.26,<3"
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "coloredlogs"
+version = "15.0.1"
+description = "Colored terminal output for Python's logging module"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
+ {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
+]
+
+[package.dependencies]
+humanfriendly = ">=9.1"
+
+[package.extras]
+cron = ["capturer (>=2.4)"]
+
+[[package]]
+name = "comm"
+version = "0.2.0"
+description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"},
+ {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"},
+]
+
+[package.dependencies]
+traitlets = ">=4"
+
+[package.extras]
+test = ["pytest"]
+
+[[package]]
+name = "coverage"
+version = "7.3.2"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"},
+ {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"},
+ {file = "coverage-7.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f47d39359e2c3779c5331fc740cf4bce6d9d680a7b4b4ead97056a0ae07cb49a"},
+ {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa72dbaf2c2068404b9870d93436e6d23addd8bbe9295f49cbca83f6e278179c"},
+ {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:beaa5c1b4777f03fc63dfd2a6bd820f73f036bfb10e925fce067b00a340d0f3f"},
+ {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:dbc1b46b92186cc8074fee9d9fbb97a9dd06c6cbbef391c2f59d80eabdf0faa6"},
+ {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:315a989e861031334d7bee1f9113c8770472db2ac484e5b8c3173428360a9148"},
+ {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d1bc430677773397f64a5c88cb522ea43175ff16f8bfcc89d467d974cb2274f9"},
+ {file = "coverage-7.3.2-cp310-cp310-win32.whl", hash = "sha256:a889ae02f43aa45032afe364c8ae84ad3c54828c2faa44f3bfcafecb5c96b02f"},
+ {file = "coverage-7.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c0ba320de3fb8c6ec16e0be17ee1d3d69adcda99406c43c0409cb5c41788a611"},
+ {file = "coverage-7.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac8c802fa29843a72d32ec56d0ca792ad15a302b28ca6203389afe21f8fa062c"},
+ {file = "coverage-7.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:89a937174104339e3a3ffcf9f446c00e3a806c28b1841c63edb2b369310fd074"},
+ {file = "coverage-7.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e267e9e2b574a176ddb983399dec325a80dbe161f1a32715c780b5d14b5f583a"},
+ {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2443cbda35df0d35dcfb9bf8f3c02c57c1d6111169e3c85fc1fcc05e0c9f39a3"},
+ {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4175e10cc8dda0265653e8714b3174430b07c1dca8957f4966cbd6c2b1b8065a"},
+ {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf38419fb1a347aaf63481c00f0bdc86889d9fbf3f25109cf96c26b403fda1"},
+ {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5c913b556a116b8d5f6ef834038ba983834d887d82187c8f73dec21049abd65c"},
+ {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1981f785239e4e39e6444c63a98da3a1db8e971cb9ceb50a945ba6296b43f312"},
+ {file = "coverage-7.3.2-cp311-cp311-win32.whl", hash = "sha256:43668cabd5ca8258f5954f27a3aaf78757e6acf13c17604d89648ecc0cc66640"},
+ {file = "coverage-7.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10c39c0452bf6e694511c901426d6b5ac005acc0f78ff265dbe36bf81f808a2"},
+ {file = "coverage-7.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4cbae1051ab791debecc4a5dcc4a1ff45fc27b91b9aee165c8a27514dd160836"},
+ {file = "coverage-7.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12d15ab5833a997716d76f2ac1e4b4d536814fc213c85ca72756c19e5a6b3d63"},
+ {file = "coverage-7.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c7bba973ebee5e56fe9251300c00f1579652587a9f4a5ed8404b15a0471f216"},
+ {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe494faa90ce6381770746077243231e0b83ff3f17069d748f645617cefe19d4"},
+ {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6e9589bd04d0461a417562649522575d8752904d35c12907d8c9dfeba588faf"},
+ {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d51ac2a26f71da1b57f2dc81d0e108b6ab177e7d30e774db90675467c847bbdf"},
+ {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:99b89d9f76070237975b315b3d5f4d6956ae354a4c92ac2388a5695516e47c84"},
+ {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fa28e909776dc69efb6ed975a63691bc8172b64ff357e663a1bb06ff3c9b589a"},
+ {file = "coverage-7.3.2-cp312-cp312-win32.whl", hash = "sha256:289fe43bf45a575e3ab10b26d7b6f2ddb9ee2dba447499f5401cfb5ecb8196bb"},
+ {file = "coverage-7.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7dbc3ed60e8659bc59b6b304b43ff9c3ed858da2839c78b804973f613d3e92ed"},
+ {file = "coverage-7.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f94b734214ea6a36fe16e96a70d941af80ff3bfd716c141300d95ebc85339738"},
+ {file = "coverage-7.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:af3d828d2c1cbae52d34bdbb22fcd94d1ce715d95f1a012354a75e5913f1bda2"},
+ {file = "coverage-7.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630b13e3036e13c7adc480ca42fa7afc2a5d938081d28e20903cf7fd687872e2"},
+ {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9eacf273e885b02a0273bb3a2170f30e2d53a6d53b72dbe02d6701b5296101c"},
+ {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f17966e861ff97305e0801134e69db33b143bbfb36436efb9cfff6ec7b2fd9"},
+ {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b4275802d16882cf9c8b3d057a0839acb07ee9379fa2749eca54efbce1535b82"},
+ {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:72c0cfa5250f483181e677ebc97133ea1ab3eb68645e494775deb6a7f6f83901"},
+ {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb536f0dcd14149425996821a168f6e269d7dcd2c273a8bff8201e79f5104e76"},
+ {file = "coverage-7.3.2-cp38-cp38-win32.whl", hash = "sha256:307adb8bd3abe389a471e649038a71b4eb13bfd6b7dd9a129fa856f5c695cf92"},
+ {file = "coverage-7.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:88ed2c30a49ea81ea3b7f172e0269c182a44c236eb394718f976239892c0a27a"},
+ {file = "coverage-7.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b631c92dfe601adf8f5ebc7fc13ced6bb6e9609b19d9a8cd59fa47c4186ad1ce"},
+ {file = "coverage-7.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d3d9df4051c4a7d13036524b66ecf7a7537d14c18a384043f30a303b146164e9"},
+ {file = "coverage-7.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f7363d3b6a1119ef05015959ca24a9afc0ea8a02c687fe7e2d557705375c01f"},
+ {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f11cc3c967a09d3695d2a6f03fb3e6236622b93be7a4b5dc09166a861be6d25"},
+ {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149de1d2401ae4655c436a3dced6dd153f4c3309f599c3d4bd97ab172eaf02d9"},
+ {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3a4006916aa6fee7cd38db3bfc95aa9c54ebb4ffbfc47c677c8bba949ceba0a6"},
+ {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9028a3871280110d6e1aa2df1afd5ef003bab5fb1ef421d6dc748ae1c8ef2ebc"},
+ {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f805d62aec8eb92bab5b61c0f07329275b6f41c97d80e847b03eb894f38d083"},
+ {file = "coverage-7.3.2-cp39-cp39-win32.whl", hash = "sha256:d1c88ec1a7ff4ebca0219f5b1ef863451d828cccf889c173e1253aa84b1e07ce"},
+ {file = "coverage-7.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b4767da59464bb593c07afceaddea61b154136300881844768037fd5e859353f"},
+ {file = "coverage-7.3.2-pp38.pp39.pp310-none-any.whl", hash = "sha256:ae97af89f0fbf373400970c0a21eef5aa941ffeed90aee43650b81f7d7f47637"},
+ {file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"},
+]
+
+[package.dependencies]
+tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
+
+[package.extras]
+toml = ["tomli"]
+
+[[package]]
+name = "cryptography"
+version = "41.0.7"
+description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "cryptography-41.0.7-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:3c78451b78313fa81607fa1b3f1ae0a5ddd8014c38a02d9db0616133987b9cdf"},
+ {file = "cryptography-41.0.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:928258ba5d6f8ae644e764d0f996d61a8777559f72dfeb2eea7e2fe0ad6e782d"},
+ {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a1b41bc97f1ad230a41657d9155113c7521953869ae57ac39ac7f1bb471469a"},
+ {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:841df4caa01008bad253bce2a6f7b47f86dc9f08df4b433c404def869f590a15"},
+ {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5429ec739a29df2e29e15d082f1d9ad683701f0ec7709ca479b3ff2708dae65a"},
+ {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:43f2552a2378b44869fe8827aa19e69512e3245a219104438692385b0ee119d1"},
+ {file = "cryptography-41.0.7-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:af03b32695b24d85a75d40e1ba39ffe7db7ffcb099fe507b39fd41a565f1b157"},
+ {file = "cryptography-41.0.7-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:49f0805fc0b2ac8d4882dd52f4a3b935b210935d500b6b805f321addc8177406"},
+ {file = "cryptography-41.0.7-cp37-abi3-win32.whl", hash = "sha256:f983596065a18a2183e7f79ab3fd4c475205b839e02cbc0efbbf9666c4b3083d"},
+ {file = "cryptography-41.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:90452ba79b8788fa380dfb587cca692976ef4e757b194b093d845e8d99f612f2"},
+ {file = "cryptography-41.0.7-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:079b85658ea2f59c4f43b70f8119a52414cdb7be34da5d019a77bf96d473b960"},
+ {file = "cryptography-41.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:b640981bf64a3e978a56167594a0e97db71c89a479da8e175d8bb5be5178c003"},
+ {file = "cryptography-41.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e3114da6d7f95d2dee7d3f4eec16dacff819740bbab931aff8648cb13c5ff5e7"},
+ {file = "cryptography-41.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d5ec85080cce7b0513cfd233914eb8b7bbd0633f1d1703aa28d1dd5a72f678ec"},
+ {file = "cryptography-41.0.7-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7a698cb1dac82c35fcf8fe3417a3aaba97de16a01ac914b89a0889d364d2f6be"},
+ {file = "cryptography-41.0.7-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:37a138589b12069efb424220bf78eac59ca68b95696fc622b6ccc1c0a197204a"},
+ {file = "cryptography-41.0.7-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:68a2dec79deebc5d26d617bfdf6e8aab065a4f34934b22d3b5010df3ba36612c"},
+ {file = "cryptography-41.0.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:09616eeaef406f99046553b8a40fbf8b1e70795a91885ba4c96a70793de5504a"},
+ {file = "cryptography-41.0.7-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48a0476626da912a44cc078f9893f292f0b3e4c739caf289268168d8f4702a39"},
+ {file = "cryptography-41.0.7-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c7f3201ec47d5207841402594f1d7950879ef890c0c495052fa62f58283fde1a"},
+ {file = "cryptography-41.0.7-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c5ca78485a255e03c32b513f8c2bc39fedb7f5c5f8535545bdc223a03b24f248"},
+ {file = "cryptography-41.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d6c391c021ab1f7a82da5d8d0b3cee2f4b2c455ec86c8aebbc84837a631ff309"},
+ {file = "cryptography-41.0.7.tar.gz", hash = "sha256:13f93ce9bea8016c253b34afc6bd6a75993e5c40672ed5405a9c832f0d4a00bc"},
+]
+
+[package.dependencies]
+cffi = ">=1.12"
+
+[package.extras]
+docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
+docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
+nox = ["nox"]
+pep8test = ["black", "check-sdist", "mypy", "ruff"]
+sdist = ["build"]
+ssh = ["bcrypt (>=3.1.5)"]
+test = ["pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
+test-randomorder = ["pytest-randomly"]
+
+[[package]]
+name = "cssselect"
+version = "1.2.0"
+description = "cssselect parses CSS3 Selectors and translates them to XPath 1.0"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "cssselect-1.2.0-py2.py3-none-any.whl", hash = "sha256:da1885f0c10b60c03ed5eccbb6b68d6eff248d91976fcde348f395d54c9fd35e"},
+ {file = "cssselect-1.2.0.tar.gz", hash = "sha256:666b19839cfaddb9ce9d36bfe4c969132c647b92fc9088c4e23f786b30f1b3dc"},
+]
+
+[[package]]
+name = "dashvector"
+version = "1.0.7"
+description = "DashVector Client Python Sdk Library"
+optional = true
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "dashvector-1.0.7-py3-none-any.whl", hash = "sha256:35457170edefb4c1b5a33229f8f7a8ab7135bf6938c1f88959304fe8a155fbb0"},
+ {file = "dashvector-1.0.7.tar.gz", hash = "sha256:75fb369756582384e449bcd98cd9f96568e5a5d03b35625f3ff08ab3a83c29a1"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.1.0,<4.0.0"
+certifi = ">=2023.7.22,<2024.0.0"
+grpcio = [
+ {version = ">=1.49.1,<=1.56.0", markers = "sys_platform == \"win32\" and python_version >= \"3.11\""},
+ {version = ">=1.49.1", markers = "python_version >= \"3.11\" and sys_platform != \"win32\""},
+ {version = ">=1.22.0,<=1.56.0", markers = "sys_platform == \"win32\" and python_version < \"3.11\""},
+ {version = ">=1.22.0", markers = "python_version < \"3.11\" and sys_platform != \"win32\""},
+]
+numpy = "*"
+protobuf = ">=3.8.0,<4.0.0"
+
+[[package]]
+name = "databricks-cli"
+version = "0.18.0"
+description = "A command line interface for Databricks"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "databricks-cli-0.18.0.tar.gz", hash = "sha256:87569709eda9af3e9db8047b691e420b5e980c62ef01675575c0d2b9b4211eb7"},
+ {file = "databricks_cli-0.18.0-py2.py3-none-any.whl", hash = "sha256:1176a5f42d3e8af4abfc915446fb23abc44513e325c436725f5898cbb9e3384b"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+oauthlib = ">=3.1.0"
+pyjwt = ">=1.7.0"
+requests = ">=2.17.3"
+six = ">=1.10.0"
+tabulate = ">=0.7.7"
+urllib3 = ">=1.26.7,<3"
+
+[[package]]
+name = "databricks-vectorsearch"
+version = "0.21"
+description = "Databricks Vector Search Client"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "databricks_vectorsearch-0.21-py3-none-any.whl", hash = "sha256:18265affdb38d44e7ec4cc95f8267379c5109bdb6e75bb61a729f126b2433868"},
+]
+
+[package.dependencies]
+mlflow-skinny = ">=2.4.0,<3"
+protobuf = ">=3.12.0,<5"
+requests = ">=2"
+
+[[package]]
+name = "dataclasses-json"
+version = "0.6.3"
+description = "Easily serialize dataclasses to and from JSON."
+optional = false
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "dataclasses_json-0.6.3-py3-none-any.whl", hash = "sha256:4aeb343357997396f6bca1acae64e486c3a723d8f5c76301888abeccf0c45176"},
+ {file = "dataclasses_json-0.6.3.tar.gz", hash = "sha256:35cb40aae824736fdf959801356641836365219cfe14caeb115c39136f775d2a"},
+]
+
+[package.dependencies]
+marshmallow = ">=3.18.0,<4.0.0"
+typing-inspect = ">=0.4.0,<1"
+
+[[package]]
+name = "datasets"
+version = "2.15.0"
+description = "HuggingFace community-driven open-source library of datasets"
+optional = true
+python-versions = ">=3.8.0"
+files = [
+ {file = "datasets-2.15.0-py3-none-any.whl", hash = "sha256:6d658d23811393dfc982d026082e1650bdaaae28f6a86e651966cb072229a228"},
+ {file = "datasets-2.15.0.tar.gz", hash = "sha256:a26d059370bd7503bd60e9337977199a13117a83f72fb61eda7e66f0c4d50b2b"},
+]
+
+[package.dependencies]
+aiohttp = "*"
+dill = ">=0.3.0,<0.3.8"
+fsspec = {version = ">=2023.1.0,<=2023.10.0", extras = ["http"]}
+huggingface-hub = ">=0.18.0"
+multiprocess = "*"
+numpy = ">=1.17"
+packaging = "*"
+pandas = "*"
+pyarrow = ">=8.0.0"
+pyarrow-hotfix = "*"
+pyyaml = ">=5.1"
+requests = ">=2.19.0"
+tqdm = ">=4.62.1"
+xxhash = "*"
+
+[package.extras]
+apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"]
+audio = ["librosa", "soundfile (>=0.12.1)"]
+benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
+dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
+jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
+metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
+quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"]
+s3 = ["s3fs"]
+tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
+tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
+tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+torch = ["torch"]
+vision = ["Pillow (>=6.2.1)"]
+
+[[package]]
+name = "debugpy"
+version = "1.8.0"
+description = "An implementation of the Debug Adapter Protocol for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"},
+ {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"},
+ {file = "debugpy-1.8.0-cp310-cp310-win32.whl", hash = "sha256:a8b7a2fd27cd9f3553ac112f356ad4ca93338feadd8910277aff71ab24d8775f"},
+ {file = "debugpy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9de202f5d42e62f932507ee8b21e30d49aae7e46d5b1dd5c908db1d7068637"},
+ {file = "debugpy-1.8.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:ef54404365fae8d45cf450d0544ee40cefbcb9cb85ea7afe89a963c27028261e"},
+ {file = "debugpy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60009b132c91951354f54363f8ebdf7457aeb150e84abba5ae251b8e9f29a8a6"},
+ {file = "debugpy-1.8.0-cp311-cp311-win32.whl", hash = "sha256:8cd0197141eb9e8a4566794550cfdcdb8b3db0818bdf8c49a8e8f8053e56e38b"},
+ {file = "debugpy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:a64093656c4c64dc6a438e11d59369875d200bd5abb8f9b26c1f5f723622e153"},
+ {file = "debugpy-1.8.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b05a6b503ed520ad58c8dc682749113d2fd9f41ffd45daec16e558ca884008cd"},
+ {file = "debugpy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c6fb41c98ec51dd010d7ed650accfd07a87fe5e93eca9d5f584d0578f28f35f"},
+ {file = "debugpy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:46ab6780159eeabb43c1495d9c84cf85d62975e48b6ec21ee10c95767c0590aa"},
+ {file = "debugpy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:bdc5ef99d14b9c0fcb35351b4fbfc06ac0ee576aeab6b2511702e5a648a2e595"},
+ {file = "debugpy-1.8.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:61eab4a4c8b6125d41a34bad4e5fe3d2cc145caecd63c3fe953be4cc53e65bf8"},
+ {file = "debugpy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:125b9a637e013f9faac0a3d6a82bd17c8b5d2c875fb6b7e2772c5aba6d082332"},
+ {file = "debugpy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:57161629133113c97b387382045649a2b985a348f0c9366e22217c87b68b73c6"},
+ {file = "debugpy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:e3412f9faa9ade82aa64a50b602544efcba848c91384e9f93497a458767e6926"},
+ {file = "debugpy-1.8.0-py2.py3-none-any.whl", hash = "sha256:9c9b0ac1ce2a42888199df1a1906e45e6f3c9555497643a85e0bf2406e3ffbc4"},
+ {file = "debugpy-1.8.0.zip", hash = "sha256:12af2c55b419521e33d5fb21bd022df0b5eb267c3e178f1d374a63a2a6bdccd0"},
+]
+
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
+[[package]]
+name = "defusedxml"
+version = "0.7.1"
+description = "XML bomb protection for Python stdlib modules"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
+ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
+]
+
+[[package]]
+name = "deprecated"
+version = "1.2.14"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
+ {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
+
+[[package]]
+name = "deprecation"
+version = "2.1.0"
+description = "A library to handle automated deprecations"
+optional = true
+python-versions = "*"
+files = [
+ {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
+ {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
+]
+
+[package.dependencies]
+packaging = "*"
+
+[[package]]
+name = "dgml-utils"
+version = "0.3.0"
+description = "Python utilities to work with the Docugami Markup Language (DGML) format."
+optional = true
+python-versions = ">=3.8.1,<4.0"
+files = [
+ {file = "dgml_utils-0.3.0-py3-none-any.whl", hash = "sha256:0cb8f6fd7f5fa31919343266260c166aa53009b42a11a172e808fc707e1ac5ba"},
+ {file = "dgml_utils-0.3.0.tar.gz", hash = "sha256:02722e899122caedfb1e90d0be557c7e6dddf86f7f4c19d7888212efde9f78c9"},
+]
+
+[package.dependencies]
+lxml = ">=4.9.3,<5.0.0"
+tabulate = ">=0.9.0,<0.10.0"
+
+[[package]]
+name = "dill"
+version = "0.3.7"
+description = "serialize all of Python"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"},
+ {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"},
+]
+
+[package.extras]
+graph = ["objgraph (>=1.7.2)"]
+
+[[package]]
+name = "distro"
+version = "1.8.0"
+description = "Distro - an OS platform information API"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"},
+ {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
+]
+
+[[package]]
+name = "dnspython"
+version = "2.4.2"
+description = "DNS toolkit"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"},
+ {file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"},
+]
+
+[package.extras]
+dnssec = ["cryptography (>=2.6,<42.0)"]
+doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"]
+doq = ["aioquic (>=0.9.20)"]
+idna = ["idna (>=2.1,<4.0)"]
+trio = ["trio (>=0.14,<0.23)"]
+wmi = ["wmi (>=1.5.1,<2.0.0)"]
+
+[[package]]
+name = "docopt"
+version = "0.6.2"
+description = "Pythonic argument parser, that will make you smile"
+optional = true
+python-versions = "*"
+files = [
+ {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
+]
+
+[[package]]
+name = "duckdb"
+version = "0.9.2"
+description = "DuckDB embedded database"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "duckdb-0.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:aadcea5160c586704c03a8a796c06a8afffbefefb1986601104a60cb0bfdb5ab"},
+ {file = "duckdb-0.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:08215f17147ed83cbec972175d9882387366de2ed36c21cbe4add04b39a5bcb4"},
+ {file = "duckdb-0.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ee6c2a8aba6850abef5e1be9dbc04b8e72a5b2c2b67f77892317a21fae868fe7"},
+ {file = "duckdb-0.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff49f3da9399900fd58b5acd0bb8bfad22c5147584ad2427a78d937e11ec9d0"},
+ {file = "duckdb-0.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5ac5baf8597efd2bfa75f984654afcabcd698342d59b0e265a0bc6f267b3f0"},
+ {file = "duckdb-0.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:81c6df905589a1023a27e9712edb5b724566587ef280a0c66a7ec07c8083623b"},
+ {file = "duckdb-0.9.2-cp310-cp310-win32.whl", hash = "sha256:a298cd1d821c81d0dec8a60878c4b38c1adea04a9675fb6306c8f9083bbf314d"},
+ {file = "duckdb-0.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:492a69cd60b6cb4f671b51893884cdc5efc4c3b2eb76057a007d2a2295427173"},
+ {file = "duckdb-0.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:061a9ea809811d6e3025c5de31bc40e0302cfb08c08feefa574a6491e882e7e8"},
+ {file = "duckdb-0.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a43f93be768af39f604b7b9b48891f9177c9282a408051209101ff80f7450d8f"},
+ {file = "duckdb-0.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ac29c8c8f56fff5a681f7bf61711ccb9325c5329e64f23cb7ff31781d7b50773"},
+ {file = "duckdb-0.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b14d98d26bab139114f62ade81350a5342f60a168d94b27ed2c706838f949eda"},
+ {file = "duckdb-0.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:796a995299878913e765b28cc2b14c8e44fae2f54ab41a9ee668c18449f5f833"},
+ {file = "duckdb-0.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6cb64ccfb72c11ec9c41b3cb6181b6fd33deccceda530e94e1c362af5f810ba1"},
+ {file = "duckdb-0.9.2-cp311-cp311-win32.whl", hash = "sha256:930740cb7b2cd9e79946e1d3a8f66e15dc5849d4eaeff75c8788d0983b9256a5"},
+ {file = "duckdb-0.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:c28f13c45006fd525001b2011cdf91fa216530e9751779651e66edc0e446be50"},
+ {file = "duckdb-0.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fbce7bbcb4ba7d99fcec84cec08db40bc0dd9342c6c11930ce708817741faeeb"},
+ {file = "duckdb-0.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15a82109a9e69b1891f0999749f9e3265f550032470f51432f944a37cfdc908b"},
+ {file = "duckdb-0.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9490fb9a35eb74af40db5569d90df8a04a6f09ed9a8c9caa024998c40e2506aa"},
+ {file = "duckdb-0.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:696d5c6dee86c1a491ea15b74aafe34ad2b62dcd46ad7e03b1d00111ca1a8c68"},
+ {file = "duckdb-0.9.2-cp37-cp37m-win32.whl", hash = "sha256:4f0935300bdf8b7631ddfc838f36a858c1323696d8c8a2cecbd416bddf6b0631"},
+ {file = "duckdb-0.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:0aab900f7510e4d2613263865570203ddfa2631858c7eb8cbed091af6ceb597f"},
+ {file = "duckdb-0.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7d8130ed6a0c9421b135d0743705ea95b9a745852977717504e45722c112bf7a"},
+ {file = "duckdb-0.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:974e5de0294f88a1a837378f1f83330395801e9246f4e88ed3bfc8ada65dcbee"},
+ {file = "duckdb-0.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4fbc297b602ef17e579bb3190c94d19c5002422b55814421a0fc11299c0c1100"},
+ {file = "duckdb-0.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1dd58a0d84a424924a35b3772419f8cd78a01c626be3147e4934d7a035a8ad68"},
+ {file = "duckdb-0.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11a1194a582c80dfb57565daa06141727e415ff5d17e022dc5f31888a5423d33"},
+ {file = "duckdb-0.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:be45d08541002a9338e568dca67ab4f20c0277f8f58a73dfc1435c5b4297c996"},
+ {file = "duckdb-0.9.2-cp38-cp38-win32.whl", hash = "sha256:dd6f88aeb7fc0bfecaca633629ff5c986ac966fe3b7dcec0b2c48632fd550ba2"},
+ {file = "duckdb-0.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:28100c4a6a04e69aa0f4a6670a6d3d67a65f0337246a0c1a429f3f28f3c40b9a"},
+ {file = "duckdb-0.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ae5bf0b6ad4278e46e933e51473b86b4b932dbc54ff097610e5b482dd125552"},
+ {file = "duckdb-0.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e5d0bb845a80aa48ed1fd1d2d285dd352e96dc97f8efced2a7429437ccd1fe1f"},
+ {file = "duckdb-0.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ce262d74a52500d10888110dfd6715989926ec936918c232dcbaddb78fc55b4"},
+ {file = "duckdb-0.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6935240da090a7f7d2666f6d0a5e45ff85715244171ca4e6576060a7f4a1200e"},
+ {file = "duckdb-0.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5cfb93e73911696a98b9479299d19cfbc21dd05bb7ab11a923a903f86b4d06e"},
+ {file = "duckdb-0.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:64e3bc01751f31e7572d2716c3e8da8fe785f1cdc5be329100818d223002213f"},
+ {file = "duckdb-0.9.2-cp39-cp39-win32.whl", hash = "sha256:6e5b80f46487636368e31b61461940e3999986359a78660a50dfdd17dd72017c"},
+ {file = "duckdb-0.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:e6142a220180dbeea4f341708bd5f9501c5c962ce7ef47c1cadf5e8810b4cb13"},
+ {file = "duckdb-0.9.2.tar.gz", hash = "sha256:3843afeab7c3fc4a4c0b53686a4cc1d9cdbdadcbb468d60fef910355ecafd447"},
+]
+
+[[package]]
+name = "duckdb-engine"
+version = "0.9.2"
+description = "SQLAlchemy driver for duckdb"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "duckdb_engine-0.9.2-py3-none-any.whl", hash = "sha256:764e83dfb37e2f0ce6afcb8e701299e7b28060a40fdae86cfd7f08e0fca4496a"},
+ {file = "duckdb_engine-0.9.2.tar.gz", hash = "sha256:efcd7b468f9b17e4480a97f0c60eade25cc081e8cfc04c46d63828677964b48f"},
+]
+
+[package.dependencies]
+duckdb = ">=0.4.0"
+sqlalchemy = ">=1.3.22"
+
+[[package]]
+name = "entrypoints"
+version = "0.4"
+description = "Discover and load entry points from installed packages."
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "entrypoints-0.4-py3-none-any.whl", hash = "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"},
+ {file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"},
+]
+
+[[package]]
+name = "esprima"
+version = "4.0.1"
+description = "ECMAScript parsing infrastructure for multipurpose analysis in Python"
+optional = true
+python-versions = "*"
+files = [
+ {file = "esprima-4.0.1.tar.gz", hash = "sha256:08db1a876d3c2910db9cfaeb83108193af5411fc3a3a66ebefacd390d21323ee"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.0"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"},
+ {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "executing"
+version = "2.0.1"
+description = "Get the currently executing AST node of a frame, and other information"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"},
+ {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"},
+]
+
+[package.extras]
+tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
+
+[[package]]
+name = "faiss-cpu"
+version = "1.7.4"
+description = "A library for efficient similarity search and clustering of dense vectors."
+optional = true
+python-versions = "*"
+files = [
+ {file = "faiss-cpu-1.7.4.tar.gz", hash = "sha256:265dc31b0c079bf4433303bf6010f73922490adff9188b915e2d3f5e9c82dd0a"},
+ {file = "faiss_cpu-1.7.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50d4ebe7f1869483751c558558504f818980292a9b55be36f9a1ee1009d9a686"},
+ {file = "faiss_cpu-1.7.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7b1db7fae7bd8312aeedd0c41536bcd19a6e297229e1dce526bde3a73ab8c0b5"},
+ {file = "faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17b7fa7194a228a84929d9e6619d0e7dbf00cc0f717e3462253766f5e3d07de8"},
+ {file = "faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dca531952a2e3eac56f479ff22951af4715ee44788a3fe991d208d766d3f95f3"},
+ {file = "faiss_cpu-1.7.4-cp310-cp310-win_amd64.whl", hash = "sha256:7173081d605e74766f950f2e3d6568a6f00c53f32fd9318063e96728c6c62821"},
+ {file = "faiss_cpu-1.7.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0bbd6f55d7940cc0692f79e32a58c66106c3c950cee2341b05722de9da23ea3"},
+ {file = "faiss_cpu-1.7.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e13c14280376100f143767d0efe47dcb32618f69e62bbd3ea5cd38c2e1755926"},
+ {file = "faiss_cpu-1.7.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c521cb8462f3b00c0c7dfb11caff492bb67816528b947be28a3b76373952c41d"},
+ {file = "faiss_cpu-1.7.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afdd9fe1141117fed85961fd36ee627c83fc3b9fd47bafb52d3c849cc2f088b7"},
+ {file = "faiss_cpu-1.7.4-cp311-cp311-win_amd64.whl", hash = "sha256:2ff7f57889ea31d945e3b87275be3cad5d55b6261a4e3f51c7aba304d76b81fb"},
+ {file = "faiss_cpu-1.7.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eeaf92f27d76249fb53c1adafe617b0f217ab65837acf7b4ec818511caf6e3d8"},
+ {file = "faiss_cpu-1.7.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:102b1bd763e9b0c281ac312590af3eaf1c8b663ccbc1145821fe6a9f92b8eaaf"},
+ {file = "faiss_cpu-1.7.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5512da6707c967310c46ff712b00418b7ae28e93cb609726136e826e9f2f14fa"},
+ {file = "faiss_cpu-1.7.4-cp37-cp37m-win_amd64.whl", hash = "sha256:0c2e5b9d8c28c99f990e87379d5bbcc6c914da91ebb4250166864fd12db5755b"},
+ {file = "faiss_cpu-1.7.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f67f325393145d360171cd98786fcea6120ce50397319afd3bb78be409fb8a"},
+ {file = "faiss_cpu-1.7.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6a4e4af194b8fce74c4b770cad67ad1dd1b4673677fc169723e4c50ba5bd97a8"},
+ {file = "faiss_cpu-1.7.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31bfb7b9cffc36897ae02a983e04c09fe3b8c053110a287134751a115334a1df"},
+ {file = "faiss_cpu-1.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52d7de96abef2340c0d373c1f5cbc78026a3cebb0f8f3a5920920a00210ead1f"},
+ {file = "faiss_cpu-1.7.4-cp38-cp38-win_amd64.whl", hash = "sha256:699feef85b23c2c729d794e26ca69bebc0bee920d676028c06fd0e0becc15c7e"},
+ {file = "faiss_cpu-1.7.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:559a0133f5ed44422acb09ee1ac0acffd90c6666d1bc0d671c18f6e93ad603e2"},
+ {file = "faiss_cpu-1.7.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1d71539fe3dc0f1bed41ef954ca701678776f231046bf0ca22ccea5cf5bef6"},
+ {file = "faiss_cpu-1.7.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12d45e0157024eb3249842163162983a1ac8b458f1a8b17bbf86f01be4585a99"},
+ {file = "faiss_cpu-1.7.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f0eab359e066d32c874f51a7d4bf6440edeec068b7fe47e6d803c73605a8b4c"},
+ {file = "faiss_cpu-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:98459ceeeb735b9df1a5b94572106ffe0a6ce740eb7e4626715dd218657bb4dc"},
+]
+
+[[package]]
+name = "fastavro"
+version = "1.9.0"
+description = "Fast read/write of AVRO files"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "fastavro-1.9.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:00826f295f290ba95f1f68d5c36970b4db7f9245a1b1a33dd9d464a382733894"},
+ {file = "fastavro-1.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ff7ac97cfe07ad90fdcca3ea90b14461ba8831bc45f02e13440b6c634f291c8"},
+ {file = "fastavro-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c251e7122b436458b8e1151c0613d6dac2b5edb6acbbc35de3b4c5f6ebb80b7"},
+ {file = "fastavro-1.9.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:35a32f5d33f91fcb7e8daf7afc82a75c8d7c774cf4d93937b2ad487d28f3f707"},
+ {file = "fastavro-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:228e7c525ff15a9f21f1adb2097ec87888933ef5c8a682c2f1d5d83796e4dd42"},
+ {file = "fastavro-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:d694bb1c2b20f1703bcb698a74f58f0f503eda8f49cb6d46209c8f3715098348"},
+ {file = "fastavro-1.9.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0f044b71d8b0ba6bbd6166be6836c3caeadd26eeaabee70b6ac7c6a9b884f6bf"},
+ {file = "fastavro-1.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172d6d5c186ba51ec6eaa98eaaadc8e859b5a56862ae724413424a858619da7f"},
+ {file = "fastavro-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07dee19dcc2797a8cb1b410d9e65febb55af2a18d9a7b85465b039d4276b9a29"},
+ {file = "fastavro-1.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:83402b450f718b690ebd88f1df2ea70609f1192bed1498308d29ac737e992391"},
+ {file = "fastavro-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b3704847d79377a5b4252ccf6d3a391497cdb8f57017cde2613f92f5274d6261"},
+ {file = "fastavro-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:602492ea0c458020cd19138ff2b9e97aa187ae01c290183dd9bbb7ff2d2e83c4"},
+ {file = "fastavro-1.9.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1cea6c2508dfb06d65cddb5b90bd6a79d3e481f1d80adc5f6ce6e3dacb4a8773"},
+ {file = "fastavro-1.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8629d4367373db7d195672834c59c86e2642172bbebd5ec6d83797b39ac4ef01"},
+ {file = "fastavro-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f45dfc29de276b509c8dbbfa6076ba6562be055c877928d4ffa1cf35b8ec59dc"},
+ {file = "fastavro-1.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc3b2de071e4d6de19974ffd328e63f7c85de2348d614222238fda2b35578b63"},
+ {file = "fastavro-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0d2570052b4e2d7b46bec4cd74c8b12d8e21cd151f5bfc837da990cb62385c5"},
+ {file = "fastavro-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:718e5df505029269e7a80afdd7e5f196d24f1473ad47eea41061ce630609f80e"},
+ {file = "fastavro-1.9.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:6cebcc09c932931e3084c96fe2c666c9cfc8c4043520651fbfeb58575edeb7da"},
+ {file = "fastavro-1.9.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb7e3a058a169d2c8bd19dfcbc7ae14c879750ce49fbaf3c436af683991f7eae"},
+ {file = "fastavro-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5af71895a01618c98ae7c563ee75b18f721d8a66324d66613bd2fcd8b2f8ac9"},
+ {file = "fastavro-1.9.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:db30121ce34f5a0a4c368504a5e2df05449382e8d4918c0b43058ffb1d31d723"},
+ {file = "fastavro-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:48d9214982c0c0f29e583df11781dc6884e8f3f3336b97991c6e7587f509a02b"},
+ {file = "fastavro-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d4a71d39760de455dbe0b2121ea1bbd85fc851e8bab2970d9e9d6d8825277d2"},
+ {file = "fastavro-1.9.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:f803c33f4fd4e3bfc17bbdbf3c036fbcb92a1f8e6bd19a035800518479ce6b36"},
+ {file = "fastavro-1.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00361ea6d5a46813f3758511153fed9698308cae175500ff62562893d3570156"},
+ {file = "fastavro-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44fc998387271d57d0e3b29c30049ba903d2aead9471b12c20725284d60dd57e"},
+ {file = "fastavro-1.9.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52e7df50431c21543682afd0ca95c40569c49e4c4599dcb78343f7c24fda6145"},
+ {file = "fastavro-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:215f40921d3f1f229cea89af25533e7be3fde16dd85c55436c15fb1ad067b486"},
+ {file = "fastavro-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0c046ed9759d1100df59dc18452901253cff5a37d9e8e8701d0102116c3202cb"},
+ {file = "fastavro-1.9.0.tar.gz", hash = "sha256:71aad82b17442dc41223f8351b9f28a60dd877a8e5a7525eaf6342f45f6d23e1"},
+]
+
+[package.extras]
+codecs = ["cramjam", "lz4", "zstandard"]
+lz4 = ["lz4"]
+snappy = ["cramjam"]
+zstandard = ["zstandard"]
+
+[[package]]
+name = "fastjsonschema"
+version = "2.19.0"
+description = "Fastest Python implementation of JSON schema"
+optional = false
+python-versions = "*"
+files = [
+ {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"},
+ {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"},
+]
+
+[package.extras]
+devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
+
+[[package]]
+name = "feedfinder2"
+version = "0.0.4"
+description = "Find the feed URLs for a website."
+optional = true
+python-versions = "*"
+files = [
+ {file = "feedfinder2-0.0.4.tar.gz", hash = "sha256:3701ee01a6c85f8b865a049c30ba0b4608858c803fe8e30d1d289fdbe89d0efe"},
+]
+
+[package.dependencies]
+beautifulsoup4 = "*"
+requests = "*"
+six = "*"
+
+[[package]]
+name = "feedparser"
+version = "6.0.10"
+description = "Universal feed parser, handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "feedparser-6.0.10-py3-none-any.whl", hash = "sha256:79c257d526d13b944e965f6095700587f27388e50ea16fd245babe4dfae7024f"},
+ {file = "feedparser-6.0.10.tar.gz", hash = "sha256:27da485f4637ce7163cdeab13a80312b93b7d0c1b775bef4a47629a3110bca51"},
+]
+
+[package.dependencies]
+sgmllib3k = "*"
+
+[[package]]
+name = "filelock"
+version = "3.13.1"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
+ {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"},
+]
+
+[package.extras]
+docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
+typing = ["typing-extensions (>=4.8)"]
+
+[[package]]
+name = "fiona"
+version = "1.9.5"
+description = "Fiona reads and writes spatial data files"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "fiona-1.9.5-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:5f40a40529ecfca5294260316cf987a0420c77a2f0cf0849f529d1afbccd093e"},
+ {file = "fiona-1.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:374efe749143ecb5cfdd79b585d83917d2bf8ecfbfc6953c819586b336ce9c63"},
+ {file = "fiona-1.9.5-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:35dae4b0308eb44617cdc4461ceb91f891d944fdebbcba5479efe524ec5db8de"},
+ {file = "fiona-1.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:5b4c6a3df53bee8f85bb46685562b21b43346be1fe96419f18f70fa1ab8c561c"},
+ {file = "fiona-1.9.5-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:6ad04c1877b9fd742871b11965606c6a52f40706f56a48d66a87cc3073943828"},
+ {file = "fiona-1.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9fb9a24a8046c724787719e20557141b33049466145fc3e665764ac7caf5748c"},
+ {file = "fiona-1.9.5-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:d722d7f01a66f4ab6cd08d156df3fdb92f0669cf5f8708ddcb209352f416f241"},
+ {file = "fiona-1.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:7ede8ddc798f3d447536080c6db9a5fb73733ad8bdb190cb65eed4e289dd4c50"},
+ {file = "fiona-1.9.5-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8b098054a27c12afac4f819f98cb4d4bf2db9853f70b0c588d7d97d26e128c39"},
+ {file = "fiona-1.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d9f29e9bcbb33232ff7fa98b4a3c2234db910c1dc6c4147fc36c0b8b930f2e0"},
+ {file = "fiona-1.9.5-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:f1af08da4ecea5036cb81c9131946be4404245d1b434b5b24fd3871a1d4030d9"},
+ {file = "fiona-1.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:c521e1135c78dec0d7774303e5a1b4c62e0efb0e602bb8f167550ef95e0a2691"},
+ {file = "fiona-1.9.5-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:fce4b1dd98810cabccdaa1828430c7402d283295c2ae31bea4f34188ea9e88d7"},
+ {file = "fiona-1.9.5-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:3ea04ec2d8c57b5f81a31200fb352cb3242aa106fc3e328963f30ffbdf0ff7c8"},
+ {file = "fiona-1.9.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4877cc745d9e82b12b3eafce3719db75759c27bd8a695521202135b36b58c2e7"},
+ {file = "fiona-1.9.5-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ac2c250f509ec19fad7959d75b531984776517ef3c1222d1cc5b4f962825880b"},
+ {file = "fiona-1.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4df21906235928faad856c288cfea0298e9647f09c9a69a230535cbc8eadfa21"},
+ {file = "fiona-1.9.5-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:81d502369493687746cb8d3cd77e5ada4447fb71d513721c9a1826e4fb32b23a"},
+ {file = "fiona-1.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:ce3b29230ef70947ead4e701f3f82be81082b7f37fd4899009b1445cc8fc276a"},
+ {file = "fiona-1.9.5-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:8b53ce8de773fcd5e2e102e833c8c58479edd8796a522f3d83ef9e08b62bfeea"},
+ {file = "fiona-1.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bd2355e859a1cd24a3e485c6dc5003129f27a2051629def70036535ffa7e16a4"},
+ {file = "fiona-1.9.5-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:9a2da52f865db1aff0eaf41cdd4c87a7c079b3996514e8e7a1ca38457309e825"},
+ {file = "fiona-1.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:cfef6db5b779d463298b1113b50daa6c5b55f26f834dc9e37752116fa17277c1"},
+ {file = "fiona-1.9.5.tar.gz", hash = "sha256:99e2604332caa7692855c2ae6ed91e1fffdf9b59449aa8032dd18e070e59a2f7"},
+]
+
+[package.dependencies]
+attrs = ">=19.2.0"
+certifi = "*"
+click = ">=8.0,<9.0"
+click-plugins = ">=1.0"
+cligj = ">=0.5"
+importlib-metadata = {version = "*", markers = "python_version < \"3.10\""}
+setuptools = "*"
+six = "*"
+
+[package.extras]
+all = ["Fiona[calc,s3,test]"]
+calc = ["shapely"]
+s3 = ["boto3 (>=1.3.1)"]
+test = ["Fiona[s3]", "pytest (>=7)", "pytest-cov", "pytz"]
+
+[[package]]
+name = "fireworks-ai"
+version = "0.9.0"
+description = "Python client library for the Fireworks.ai Generative AI Platform"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "fireworks-ai-0.9.0.tar.gz", hash = "sha256:0aa8ec092d0b05e9b509e33c887142521251f89d8a709524529fff058ba1e09a"},
+ {file = "fireworks_ai-0.9.0-py3-none-any.whl", hash = "sha256:bef6ef19423885316bc70ff0c967a2f1936070827ff0a5c3581f6a2059b11f68"},
+]
+
+[package.dependencies]
+httpx = "*"
+httpx-sse = "*"
+Pillow = "*"
+pydantic = "*"
+
+[[package]]
+name = "flatbuffers"
+version = "23.5.26"
+description = "The FlatBuffers serialization format for Python"
+optional = true
+python-versions = "*"
+files = [
+ {file = "flatbuffers-23.5.26-py2.py3-none-any.whl", hash = "sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1"},
+ {file = "flatbuffers-23.5.26.tar.gz", hash = "sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89"},
+]
+
+[[package]]
+name = "fqdn"
+version = "1.5.1"
+description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers"
+optional = false
+python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4"
+files = [
+ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"},
+ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"},
+]
+
+[[package]]
+name = "freezegun"
+version = "1.3.1"
+description = "Let your Python tests travel through time"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "freezegun-1.3.1-py3-none-any.whl", hash = "sha256:065e77a12624d05531afa87ade12a0b9bdb53495c4573893252a055b545ce3ea"},
+ {file = "freezegun-1.3.1.tar.gz", hash = "sha256:48984397b3b58ef5dfc645d6a304b0060f612bcecfdaaf45ce8aff0077a6cb6a"},
+]
+
+[package.dependencies]
+python-dateutil = ">=2.7"
+
+[[package]]
+name = "frozenlist"
+version = "1.4.0"
+description = "A list-like structure which implements collections.abc.MutableSequence"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"},
+ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"},
+ {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"},
+ {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"},
+ {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"},
+ {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"},
+ {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"},
+ {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"},
+ {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"},
+ {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"},
+ {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"},
+ {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"},
+ {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"},
+ {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"},
+ {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"},
+ {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"},
+ {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"},
+ {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"},
+ {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"},
+ {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"},
+ {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"},
+ {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"},
+ {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"},
+ {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"},
+ {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"},
+ {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"},
+ {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"},
+ {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"},
+ {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"},
+ {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"},
+ {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"},
+ {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"},
+ {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"},
+ {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"},
+ {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"},
+ {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"},
+ {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"},
+ {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"},
+ {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"},
+ {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"},
+ {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"},
+ {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"},
+ {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"},
+ {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"},
+ {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"},
+ {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"},
+ {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"},
+ {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"},
+ {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"},
+ {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"},
+ {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"},
+ {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"},
+ {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"},
+ {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"},
+ {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"},
+ {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"},
+ {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"},
+ {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"},
+ {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"},
+ {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"},
+ {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"},
+]
+
+[[package]]
+name = "fsspec"
+version = "2023.10.0"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"},
+ {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"},
+]
+
+[package.dependencies]
+aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
+requests = {version = "*", optional = true, markers = "extra == \"http\""}
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+devel = ["pytest", "pytest-cov"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+tqdm = ["tqdm"]
+
+[[package]]
+name = "geomet"
+version = "0.2.1.post1"
+description = "GeoJSON <-> WKT/WKB conversion utilities"
+optional = false
+python-versions = ">2.6, !=3.3.*, <4"
+files = [
+ {file = "geomet-0.2.1.post1-py3-none-any.whl", hash = "sha256:a41a1e336b381416d6cbed7f1745c848e91defaa4d4c1bdc1312732e46ffad2b"},
+ {file = "geomet-0.2.1.post1.tar.gz", hash = "sha256:91d754f7c298cbfcabd3befdb69c641c27fe75e808b27aa55028605761d17e95"},
+]
+
+[package.dependencies]
+click = "*"
+six = "*"
+
+[[package]]
+name = "geopandas"
+version = "0.13.2"
+description = "Geographic pandas extensions"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "geopandas-0.13.2-py3-none-any.whl", hash = "sha256:101cfd0de54bcf9e287a55b5ea17ebe0db53a5e25a28bacf100143d0507cabd9"},
+ {file = "geopandas-0.13.2.tar.gz", hash = "sha256:e5b56d9c20800c77bcc0c914db3f27447a37b23b2cd892be543f5001a694a968"},
+]
+
+[package.dependencies]
+fiona = ">=1.8.19"
+packaging = "*"
+pandas = ">=1.1.0"
+pyproj = ">=3.0.1"
+shapely = ">=1.7.1"
+
+[[package]]
+name = "gitdb"
+version = "4.0.11"
+description = "Git Object Database"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
+ {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
+]
+
+[package.dependencies]
+smmap = ">=3.0.1,<6"
+
+[[package]]
+name = "gitpython"
+version = "3.1.40"
+description = "GitPython is a Python library used to interact with Git repositories"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"},
+ {file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"},
+]
+
+[package.dependencies]
+gitdb = ">=4.0.1,<5"
+
+[package.extras]
+test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
+
+[[package]]
+name = "google-api-core"
+version = "1.34.0"
+description = "Google API client core library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-api-core-1.34.0.tar.gz", hash = "sha256:6fb380f49d19ee1d09a9722d0379042b7edb06c0112e4796c7a395078a043e71"},
+ {file = "google_api_core-1.34.0-py3-none-any.whl", hash = "sha256:7421474c39d396a74dfa317dddbc69188f2336835f526087c7648f91105e32ff"},
+]
+
+[package.dependencies]
+google-auth = ">=1.25.0,<3.0dev"
+googleapis-common-protos = ">=1.56.2,<2.0dev"
+grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}
+grpcio-status = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.0.0dev"
+requests = ">=2.18.0,<3.0.0dev"
+
+[package.extras]
+grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0dev)"]
+grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
+grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
+
+[[package]]
+name = "google-api-core"
+version = "2.14.0"
+description = "Google API client core library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-api-core-2.14.0.tar.gz", hash = "sha256:5368a4502b793d9bbf812a5912e13e4e69f9bd87f6efb508460c43f5bbd1ce41"},
+ {file = "google_api_core-2.14.0-py3-none-any.whl", hash = "sha256:de2fb50ed34d47ddbb2bd2dcf680ee8fead46279f4ed6b16de362aca23a18952"},
+]
+
+[package.dependencies]
+google-auth = ">=2.14.1,<3.0.dev0"
+googleapis-common-protos = ">=1.56.2,<2.0.dev0"
+grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
+grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
+requests = ">=2.18.0,<3.0.0.dev0"
+
+[package.extras]
+grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
+grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+
+[[package]]
+name = "google-auth"
+version = "2.24.0"
+description = "Google Authentication Library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-auth-2.24.0.tar.gz", hash = "sha256:2ec7b2a506989d7dbfdbe81cb8d0ead8876caaed14f86d29d34483cbe99c57af"},
+ {file = "google_auth-2.24.0-py2.py3-none-any.whl", hash = "sha256:9b82d5c8d3479a5391ea0a46d81cca698d328459da31d4a459d4e901a5d927e0"},
+]
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = ">=3.1.4,<5"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
+enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
+pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
+
+[[package]]
+name = "google-cloud-aiplatform"
+version = "1.37.0"
+description = "Vertex AI API client library"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "google-cloud-aiplatform-1.37.0.tar.gz", hash = "sha256:51cac0334fc7274142b50363dd10cbb3d303ff6354e5c06a7fb51e1b7db02dfb"},
+ {file = "google_cloud_aiplatform-1.37.0-py2.py3-none-any.whl", hash = "sha256:9b3d6e681084a60b7e4de7063d41ca2d354b3546a918609ddaf9d8d1f8b19c36"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.32.0,<2.0.dev0 || >=2.8.dev0,<3.0.0dev", extras = ["grpc"]}
+google-cloud-bigquery = ">=1.15.0,<4.0.0dev"
+google-cloud-resource-manager = ">=1.3.3,<3.0.0dev"
+google-cloud-storage = ">=1.32.0,<3.0.0dev"
+packaging = ">=14.3"
+proto-plus = ">=1.22.0,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+shapely = "<3.0.0dev"
+
+[package.extras]
+autologging = ["mlflow (>=1.27.0,<=2.1.1)"]
+cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"]
+datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)"]
+endpoint = ["requests (>=2.28.1)"]
+full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "starlette (>=0.17.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)"]
+lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"]
+metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"]
+pipelines = ["pyyaml (==5.3.1)"]
+prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<0.103.1)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"]
+preview = ["cloudpickle (<3.0)", "google-cloud-logging (<4.0)"]
+private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"]
+ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)"]
+tensorboard = ["tensorflow (>=2.3.0,<3.0.0dev)"]
+testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "ipython", "kfp", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<=2.12.0)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost", "xgboost-ray"]
+vizier = ["google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)"]
+xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]
+
+[[package]]
+name = "google-cloud-bigquery"
+version = "3.13.0"
+description = "Google BigQuery API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-cloud-bigquery-3.13.0.tar.gz", hash = "sha256:794ccfc93ccb0e0ad689442f896f9c82de56da0fe18a195531bb37096c2657d6"},
+ {file = "google_cloud_bigquery-3.13.0-py2.py3-none-any.whl", hash = "sha256:eda3dbcff676e17962c54e5224e415b55e4f6833a5c896c6c8902b69e7dba4b4"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras = ["grpc"]}
+google-cloud-core = ">=1.6.0,<3.0.0dev"
+google-resumable-media = ">=0.6.0,<3.0dev"
+grpcio = [
+ {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
+ {version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
+]
+packaging = ">=20.0.0"
+proto-plus = ">=1.15.0,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+python-dateutil = ">=2.7.2,<3.0dev"
+requests = ">=2.21.0,<3.0.0dev"
+
+[package.extras]
+all = ["Shapely (>=1.8.4,<3.0.0dev)", "db-dtypes (>=0.3.0,<2.0.0dev)", "geopandas (>=0.9.0,<1.0dev)", "google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "ipykernel (>=6.0.0)", "ipython (>=7.23.1,!=8.1.0)", "ipywidgets (>=7.7.0)", "opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)", "pandas (>=1.1.0)", "pyarrow (>=3.0.0)", "tqdm (>=4.7.4,<5.0.0dev)"]
+bqstorage = ["google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "pyarrow (>=3.0.0)"]
+geopandas = ["Shapely (>=1.8.4,<3.0.0dev)", "geopandas (>=0.9.0,<1.0dev)"]
+ipython = ["ipykernel (>=6.0.0)", "ipython (>=7.23.1,!=8.1.0)"]
+ipywidgets = ["ipykernel (>=6.0.0)", "ipywidgets (>=7.7.0)"]
+opentelemetry = ["opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)"]
+pandas = ["db-dtypes (>=0.3.0,<2.0.0dev)", "pandas (>=1.1.0)", "pyarrow (>=3.0.0)"]
+tqdm = ["tqdm (>=4.7.4,<5.0.0dev)"]
+
+[[package]]
+name = "google-cloud-core"
+version = "2.4.1"
+description = "Google Cloud API client core library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-cloud-core-2.4.1.tar.gz", hash = "sha256:9b7749272a812bde58fff28868d0c5e2f585b82f37e09a1f6ed2d4d10f134073"},
+ {file = "google_cloud_core-2.4.1-py2.py3-none-any.whl", hash = "sha256:a9e6a4422b9ac5c29f79a0ede9485473338e2ce78d91f2370c01e730eab22e61"},
+]
+
+[package.dependencies]
+google-api-core = ">=1.31.6,<2.0.dev0 || >2.3.0,<3.0.0dev"
+google-auth = ">=1.25.0,<3.0dev"
+
+[package.extras]
+grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]
+
+[[package]]
+name = "google-cloud-documentai"
+version = "2.20.2"
+description = "Google Cloud Documentai API client library"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "google-cloud-documentai-2.20.2.tar.gz", hash = "sha256:097694574489041765a3f26ba346a27a71a17a5f8a4a6a959d3fca84025c144f"},
+ {file = "google_cloud_documentai-2.20.2-py2.py3-none-any.whl", hash = "sha256:d3f401740de4a36d79b0179593f61a9da05602b1686f8a2f0f25b8a4c5b9d6cb"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+proto-plus = [
+ {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
+]
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
+[[package]]
+name = "google-cloud-resource-manager"
+version = "1.11.0"
+description = "Google Cloud Resource Manager API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-cloud-resource-manager-1.11.0.tar.gz", hash = "sha256:a64ba6bb595634ecd2472b8b0322e8f012a76327756659a2dde9f392d7fa1af2"},
+ {file = "google_cloud_resource_manager-1.11.0-py2.py3-none-any.whl", hash = "sha256:bafde909b1d434a620eefcd144b14fcccb72f268afcf158c5bcfcdce5e04a72b"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
+[[package]]
+name = "google-cloud-storage"
+version = "2.13.0"
+description = "Google Cloud Storage API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-cloud-storage-2.13.0.tar.gz", hash = "sha256:f62dc4c7b6cd4360d072e3deb28035fbdad491ac3d9b0b1815a12daea10f37c7"},
+ {file = "google_cloud_storage-2.13.0-py2.py3-none-any.whl", hash = "sha256:ab0bf2e1780a1b74cf17fccb13788070b729f50c252f0c94ada2aae0ca95437d"},
+]
+
+[package.dependencies]
+google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev"
+google-auth = ">=2.23.3,<3.0dev"
+google-cloud-core = ">=2.3.0,<3.0dev"
+google-crc32c = ">=1.0,<2.0dev"
+google-resumable-media = ">=2.6.0"
+requests = ">=2.18.0,<3.0.0dev"
+
+[package.extras]
+protobuf = ["protobuf (<5.0.0dev)"]
+
+[[package]]
+name = "google-crc32c"
+version = "1.5.0"
+description = "A python wrapper of the C library 'Google CRC32C'"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google-crc32c-1.5.0.tar.gz", hash = "sha256:89284716bc6a5a415d4eaa11b1726d2d60a0cd12aadf5439828353662ede9dd7"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:596d1f98fc70232fcb6590c439f43b350cb762fb5d61ce7b0e9db4539654cc13"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be82c3c8cfb15b30f36768797a640e800513793d6ae1724aaaafe5bf86f8f346"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:461665ff58895f508e2866824a47bdee72497b091c730071f2b7575d5762ab65"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2096eddb4e7c7bdae4bd69ad364e55e07b8316653234a56552d9c988bd2d61b"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:116a7c3c616dd14a3de8c64a965828b197e5f2d121fedd2f8c5585c547e87b02"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5829b792bf5822fd0a6f6eb34c5f81dd074f01d570ed7f36aa101d6fc7a0a6e4"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:64e52e2b3970bd891309c113b54cf0e4384762c934d5ae56e283f9a0afcd953e"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:02ebb8bf46c13e36998aeaad1de9b48f4caf545e91d14041270d9dca767b780c"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-win32.whl", hash = "sha256:2e920d506ec85eb4ba50cd4228c2bec05642894d4c73c59b3a2fe20346bd00ee"},
+ {file = "google_crc32c-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:07eb3c611ce363c51a933bf6bd7f8e3878a51d124acfc89452a75120bc436289"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cae0274952c079886567f3f4f685bcaf5708f0a23a5f5216fdab71f81a6c0273"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1034d91442ead5a95b5aaef90dbfaca8633b0247d1e41621d1e9f9db88c36298"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c42c70cd1d362284289c6273adda4c6af8039a8ae12dc451dcd61cdabb8ab57"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8485b340a6a9e76c62a7dce3c98e5f102c9219f4cfbf896a00cf48caf078d438"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77e2fd3057c9d78e225fa0a2160f96b64a824de17840351b26825b0848022906"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f583edb943cf2e09c60441b910d6a20b4d9d626c75a36c8fcac01a6c96c01183"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a1fd716e7a01f8e717490fbe2e431d2905ab8aa598b9b12f8d10abebb36b04dd"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:72218785ce41b9cfd2fc1d6a017dc1ff7acfc4c17d01053265c41a2c0cc39b8c"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-win32.whl", hash = "sha256:66741ef4ee08ea0b2cc3c86916ab66b6aef03768525627fd6a1b34968b4e3709"},
+ {file = "google_crc32c-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:ba1eb1843304b1e5537e1fca632fa894d6f6deca8d6389636ee5b4797affb968"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:98cb4d057f285bd80d8778ebc4fde6b4d509ac3f331758fb1528b733215443ae"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd8536e902db7e365f49e7d9029283403974ccf29b13fc7028b97e2295b33556"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19e0a019d2c4dcc5e598cd4a4bc7b008546b0358bd322537c74ad47a5386884f"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c65b9817512edc6a4ae7c7e987fea799d2e0ee40c53ec573a692bee24de876"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6ac08d24c1f16bd2bf5eca8eaf8304812f44af5cfe5062006ec676e7e1d50afc"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3359fc442a743e870f4588fcf5dcbc1bf929df1fad8fb9905cd94e5edb02e84c"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e986b206dae4476f41bcec1faa057851f3889503a70e1bdb2378d406223994a"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de06adc872bcd8c2a4e0dc51250e9e65ef2ca91be023b9d13ebd67c2ba552e1e"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-win32.whl", hash = "sha256:d3515f198eaa2f0ed49f8819d5732d70698c3fa37384146079b3799b97667a94"},
+ {file = "google_crc32c-1.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:67b741654b851abafb7bc625b6d1cdd520a379074e64b6a128e3b688c3c04740"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c02ec1c5856179f171e032a31d6f8bf84e5a75c45c33b2e20a3de353b266ebd8"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edfedb64740750e1a3b16152620220f51d58ff1b4abceb339ca92e934775c27a"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84e6e8cd997930fc66d5bb4fde61e2b62ba19d62b7abd7a69920406f9ecca946"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:024894d9d3cfbc5943f8f230e23950cd4906b2fe004c72e29b209420a1e6b05a"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:998679bf62b7fb599d2878aa3ed06b9ce688b8974893e7223c60db155f26bd8d"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:83c681c526a3439b5cf94f7420471705bbf96262f49a6fe546a6db5f687a3d4a"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4c6fdd4fccbec90cc8a01fc00773fcd5fa28db683c116ee3cb35cd5da9ef6c37"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5ae44e10a8e3407dbe138984f21e536583f2bba1be9491239f942c2464ac0894"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37933ec6e693e51a5b07505bd05de57eee12f3e8c32b07da7e73669398e6630a"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-win32.whl", hash = "sha256:fe70e325aa68fa4b5edf7d1a4b6f691eb04bbccac0ace68e34820d283b5f80d4"},
+ {file = "google_crc32c-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:74dea7751d98034887dbd821b7aae3e1d36eda111d6ca36c206c44478035709c"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c6c777a480337ac14f38564ac88ae82d4cd238bf293f0a22295b66eb89ffced7"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:759ce4851a4bb15ecabae28f4d2e18983c244eddd767f560165563bf9aefbc8d"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f13cae8cc389a440def0c8c52057f37359014ccbc9dc1f0827936bcd367c6100"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e560628513ed34759456a416bf86b54b2476c59144a9138165c9a1575801d0d9"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1674e4307fa3024fc897ca774e9c7562c957af85df55efe2988ed9056dc4e57"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:278d2ed7c16cfc075c91378c4f47924c0625f5fc84b2d50d921b18b7975bd210"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d5280312b9af0976231f9e317c20e4a61cd2f9629b7bfea6a693d1878a264ebd"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8b87e1a59c38f275c0e3676fc2ab6d59eccecfd460be267ac360cc31f7bcde96"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7c074fece789b5034b9b1404a1f8208fc2d4c6ce9decdd16e8220c5a793e6f61"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-win32.whl", hash = "sha256:7f57f14606cd1dd0f0de396e1e53824c371e9544a822648cd76c034d209b559c"},
+ {file = "google_crc32c-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2355cba1f4ad8b6988a4ca3feed5bff33f6af2d7f134852cf279c2aebfde541"},
+ {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f314013e7dcd5cf45ab1945d92e713eec788166262ae8deb2cfacd53def27325"},
+ {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b747a674c20a67343cb61d43fdd9207ce5da6a99f629c6e2541aa0e89215bcd"},
+ {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f24ed114432de109aa9fd317278518a5af2d31ac2ea6b952b2f7782b43da091"},
+ {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8667b48e7a7ef66afba2c81e1094ef526388d35b873966d8a9a447974ed9178"},
+ {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:1c7abdac90433b09bad6c43a43af253e688c9cfc1c86d332aed13f9a7c7f65e2"},
+ {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f998db4e71b645350b9ac28a2167e6632c239963ca9da411523bb439c5c514d"},
+ {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c99616c853bb585301df6de07ca2cadad344fd1ada6d62bb30aec05219c45d2"},
+ {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ad40e31093a4af319dadf503b2467ccdc8f67c72e4bcba97f8c10cb078207b5"},
+ {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd67cf24a553339d5062eff51013780a00d6f97a39ca062781d06b3a73b15462"},
+ {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:398af5e3ba9cf768787eef45c803ff9614cc3e22a5b2f7d7ae116df8b11e3314"},
+ {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b1f8133c9a275df5613a451e73f36c2aea4fe13c5c8997e22cf355ebd7bd0728"},
+ {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba053c5f50430a3fcfd36f75aff9caeba0440b2d076afdb79a318d6ca245f88"},
+ {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:272d3892a1e1a2dbc39cc5cde96834c236d5327e2122d3aaa19f6614531bb6eb"},
+ {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:635f5d4dd18758a1fbd1049a8e8d2fee4ffed124462d837d1a02a0e009c3ab31"},
+ {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c672d99a345849301784604bfeaeba4db0c7aae50b95be04dd651fd2a7310b93"},
+]
+
+[package.extras]
+testing = ["pytest"]
+
+[[package]]
+name = "google-resumable-media"
+version = "2.6.0"
+description = "Utilities for Google Media Downloads and Resumable Uploads"
+optional = false
+python-versions = ">= 3.7"
+files = [
+ {file = "google-resumable-media-2.6.0.tar.gz", hash = "sha256:972852f6c65f933e15a4a210c2b96930763b47197cdf4aa5f5bea435efb626e7"},
+ {file = "google_resumable_media-2.6.0-py2.py3-none-any.whl", hash = "sha256:fc03d344381970f79eebb632a3c18bb1828593a2dc5572b5f90115ef7d11e81b"},
+]
+
+[package.dependencies]
+google-crc32c = ">=1.0,<2.0dev"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "google-auth (>=1.22.0,<2.0dev)"]
+requests = ["requests (>=2.18.0,<3.0.0dev)"]
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.61.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "googleapis-common-protos-1.61.0.tar.gz", hash = "sha256:8a64866a97f6304a7179873a465d6eee97b7a24ec6cfd78e0f575e96b821240b"},
+ {file = "googleapis_common_protos-1.61.0-py2.py3-none-any.whl", hash = "sha256:22f1915393bb3245343f6efe87f6fe868532efc12aa26b391b15132e1279f1c0"},
+]
+
+[package.dependencies]
+grpcio = {version = ">=1.44.0,<2.0.0.dev0", optional = true, markers = "extra == \"grpc\""}
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
+
+[[package]]
+name = "gql"
+version = "3.4.1"
+description = "GraphQL client for Python"
+optional = true
+python-versions = "*"
+files = [
+ {file = "gql-3.4.1-py2.py3-none-any.whl", hash = "sha256:315624ca0f4d571ef149d455033ebd35e45c1a13f18a059596aeddcea99135cf"},
+ {file = "gql-3.4.1.tar.gz", hash = "sha256:11dc5d8715a827f2c2899593439a4f36449db4f0eafa5b1ea63948f8a2f8c545"},
+]
+
+[package.dependencies]
+backoff = ">=1.11.1,<3.0"
+graphql-core = ">=3.2,<3.3"
+yarl = ">=1.6,<2.0"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.7.1,<3.9.0)"]
+all = ["aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
+botocore = ["botocore (>=1.21,<2)"]
+dev = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "sphinx (>=3.0.0,<4)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "urllib3 (>=1.26,<2)", "vcrpy (==4.0.2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
+requests = ["requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)"]
+test = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)", "vcrpy (==4.0.2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
+test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.0.2)"]
+websockets = ["websockets (>=10,<11)", "websockets (>=9,<10)"]
+
+[[package]]
+name = "graphql-core"
+version = "3.2.3"
+description = "GraphQL implementation for Python, a port of GraphQL.js, the JavaScript reference implementation for GraphQL."
+optional = true
+python-versions = ">=3.6,<4"
+files = [
+ {file = "graphql-core-3.2.3.tar.gz", hash = "sha256:06d2aad0ac723e35b1cb47885d3e5c45e956a53bc1b209a9fc5369007fe46676"},
+ {file = "graphql_core-3.2.3-py3-none-any.whl", hash = "sha256:5766780452bd5ec8ba133f8bf287dc92713e3868ddd83aee4faab9fc3e303dc3"},
+]
+
+[[package]]
+name = "greenlet"
+version = "3.0.1"
+description = "Lightweight in-process concurrent programming"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"},
+ {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"},
+ {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8ba29306c5de7717b5761b9ea74f9c72b9e2b834e24aa984da99cbfc70157fd"},
+ {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19bbdf1cce0346ef7341705d71e2ecf6f41a35c311137f29b8a2dc2341374565"},
+ {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599daf06ea59bfedbec564b1692b0166a0045f32b6f0933b0dd4df59a854caf2"},
+ {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b641161c302efbb860ae6b081f406839a8b7d5573f20a455539823802c655f63"},
+ {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d57e20ba591727da0c230ab2c3f200ac9d6d333860d85348816e1dca4cc4792e"},
+ {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5805e71e5b570d490938d55552f5a9e10f477c19400c38bf1d5190d760691846"},
+ {file = "greenlet-3.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:52e93b28db27ae7d208748f45d2db8a7b6a380e0d703f099c949d0f0d80b70e9"},
+ {file = "greenlet-3.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f7bfb769f7efa0eefcd039dd19d843a4fbfbac52f1878b1da2ed5793ec9b1a65"},
+ {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e6c7db42638dc45cf2e13c73be16bf83179f7859b07cfc139518941320be96"},
+ {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1757936efea16e3f03db20efd0cd50a1c86b06734f9f7338a90c4ba85ec2ad5a"},
+ {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19075157a10055759066854a973b3d1325d964d498a805bb68a1f9af4aaef8ec"},
+ {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9d21aaa84557d64209af04ff48e0ad5e28c5cca67ce43444e939579d085da72"},
+ {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2847e5d7beedb8d614186962c3d774d40d3374d580d2cbdab7f184580a39d234"},
+ {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:97e7ac860d64e2dcba5c5944cfc8fa9ea185cd84061c623536154d5a89237884"},
+ {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b2c02d2ad98116e914d4f3155ffc905fd0c025d901ead3f6ed07385e19122c94"},
+ {file = "greenlet-3.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:22f79120a24aeeae2b4471c711dcf4f8c736a2bb2fabad2a67ac9a55ea72523c"},
+ {file = "greenlet-3.0.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:100f78a29707ca1525ea47388cec8a049405147719f47ebf3895e7509c6446aa"},
+ {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60d5772e8195f4e9ebf74046a9121bbb90090f6550f81d8956a05387ba139353"},
+ {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:daa7197b43c707462f06d2c693ffdbb5991cbb8b80b5b984007de431493a319c"},
+ {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea6b8aa9e08eea388c5f7a276fabb1d4b6b9d6e4ceb12cc477c3d352001768a9"},
+ {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d11ebbd679e927593978aa44c10fc2092bc454b7d13fdc958d3e9d508aba7d0"},
+ {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbd4c177afb8a8d9ba348d925b0b67246147af806f0b104af4d24f144d461cd5"},
+ {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20107edf7c2c3644c67c12205dc60b1bb11d26b2610b276f97d666110d1b511d"},
+ {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8bef097455dea90ffe855286926ae02d8faa335ed8e4067326257cb571fc1445"},
+ {file = "greenlet-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:b2d3337dcfaa99698aa2377c81c9ca72fcd89c07e7eb62ece3f23a3fe89b2ce4"},
+ {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80ac992f25d10aaebe1ee15df45ca0d7571d0f70b645c08ec68733fb7a020206"},
+ {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:337322096d92808f76ad26061a8f5fccb22b0809bea39212cd6c406f6a7060d2"},
+ {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9934adbd0f6e476f0ecff3c94626529f344f57b38c9a541f87098710b18af0a"},
+ {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc4d815b794fd8868c4d67602692c21bf5293a75e4b607bb92a11e821e2b859a"},
+ {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41bdeeb552d814bcd7fb52172b304898a35818107cc8778b5101423c9017b3de"},
+ {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6e6061bf1e9565c29002e3c601cf68569c450be7fc3f7336671af7ddb4657166"},
+ {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fa24255ae3c0ab67e613556375a4341af04a084bd58764731972bcbc8baeba36"},
+ {file = "greenlet-3.0.1-cp37-cp37m-win32.whl", hash = "sha256:b489c36d1327868d207002391f662a1d163bdc8daf10ab2e5f6e41b9b96de3b1"},
+ {file = "greenlet-3.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f33f3258aae89da191c6ebaa3bc517c6c4cbc9b9f689e5d8452f7aedbb913fa8"},
+ {file = "greenlet-3.0.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d2905ce1df400360463c772b55d8e2518d0e488a87cdea13dd2c71dcb2a1fa16"},
+ {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a02d259510b3630f330c86557331a3b0e0c79dac3d166e449a39363beaae174"},
+ {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55d62807f1c5a1682075c62436702aaba941daa316e9161e4b6ccebbbf38bda3"},
+ {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3fcc780ae8edbb1d050d920ab44790201f027d59fdbd21362340a85c79066a74"},
+ {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eddd98afc726f8aee1948858aed9e6feeb1758889dfd869072d4465973f6bfd"},
+ {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eabe7090db68c981fca689299c2d116400b553f4b713266b130cfc9e2aa9c5a9"},
+ {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f2f6d303f3dee132b322a14cd8765287b8f86cdc10d2cb6a6fae234ea488888e"},
+ {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d923ff276f1c1f9680d32832f8d6c040fe9306cbfb5d161b0911e9634be9ef0a"},
+ {file = "greenlet-3.0.1-cp38-cp38-win32.whl", hash = "sha256:0b6f9f8ca7093fd4433472fd99b5650f8a26dcd8ba410e14094c1e44cd3ceddd"},
+ {file = "greenlet-3.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:990066bff27c4fcf3b69382b86f4c99b3652bab2a7e685d968cd4d0cfc6f67c6"},
+ {file = "greenlet-3.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ce85c43ae54845272f6f9cd8320d034d7a946e9773c693b27d620edec825e376"},
+ {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89ee2e967bd7ff85d84a2de09df10e021c9b38c7d91dead95b406ed6350c6997"},
+ {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87c8ceb0cf8a5a51b8008b643844b7f4a8264a2c13fcbcd8a8316161725383fe"},
+ {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6a8c9d4f8692917a3dc7eb25a6fb337bff86909febe2f793ec1928cd97bedfc"},
+ {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fbc5b8f3dfe24784cee8ce0be3da2d8a79e46a276593db6868382d9c50d97b1"},
+ {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85d2b77e7c9382f004b41d9c72c85537fac834fb141b0296942d52bf03fe4a3d"},
+ {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:696d8e7d82398e810f2b3622b24e87906763b6ebfd90e361e88eb85b0e554dc8"},
+ {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:329c5a2e5a0ee942f2992c5e3ff40be03e75f745f48847f118a3cfece7a28546"},
+ {file = "greenlet-3.0.1-cp39-cp39-win32.whl", hash = "sha256:cf868e08690cb89360eebc73ba4be7fb461cfbc6168dd88e2fbbe6f31812cd57"},
+ {file = "greenlet-3.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:ac4a39d1abae48184d420aa8e5e63efd1b75c8444dd95daa3e03f6c6310e9619"},
+ {file = "greenlet-3.0.1.tar.gz", hash = "sha256:816bd9488a94cba78d93e1abb58000e8266fa9cc2aa9ccdd6eb0696acb24005b"},
+]
+
+[package.extras]
+docs = ["Sphinx"]
+test = ["objgraph", "psutil"]
+
+[[package]]
+name = "grpc-google-iam-v1"
+version = "0.13.0"
+description = "IAM API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "grpc-google-iam-v1-0.13.0.tar.gz", hash = "sha256:fad318608b9e093258fbf12529180f400d1c44453698a33509cc6ecf005b294e"},
+ {file = "grpc_google_iam_v1-0.13.0-py2.py3-none-any.whl", hash = "sha256:53902e2af7de8df8c1bd91373d9be55b0743ec267a7428ea638db3775becae89"},
+]
+
+[package.dependencies]
+googleapis-common-protos = {version = ">=1.56.0,<2.0.0dev", extras = ["grpc"]}
+grpcio = ">=1.44.0,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
+[[package]]
+name = "grpcio"
+version = "1.56.0"
+description = "HTTP/2-based RPC framework"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "grpcio-1.56.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:fb34ace11419f1ae321c36ccaa18d81cd3f20728cd191250be42949d6845bb2d"},
+ {file = "grpcio-1.56.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:008767c0aed4899e657b50f2e0beacbabccab51359eba547f860e7c55f2be6ba"},
+ {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:17f47aeb9be0da5337f9ff33ebb8795899021e6c0741ee68bd69774a7804ca86"},
+ {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43c50d810cc26349b093bf2cfe86756ab3e9aba3e7e681d360930c1268e1399a"},
+ {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:187b8f71bad7d41eea15e0c9812aaa2b87adfb343895fffb704fb040ca731863"},
+ {file = "grpcio-1.56.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:881575f240eb5db72ddca4dc5602898c29bc082e0d94599bf20588fb7d1ee6a0"},
+ {file = "grpcio-1.56.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c243b158dd7585021d16c50498c4b2ec0a64a6119967440c5ff2d8c89e72330e"},
+ {file = "grpcio-1.56.0-cp310-cp310-win32.whl", hash = "sha256:8b3b2c7b5feef90bc9a5fa1c7f97637e55ec3e76460c6d16c3013952ee479cd9"},
+ {file = "grpcio-1.56.0-cp310-cp310-win_amd64.whl", hash = "sha256:03a80451530fd3b8b155e0c4480434f6be669daf7ecba56f73ef98f94222ee01"},
+ {file = "grpcio-1.56.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:64bd3abcf9fb4a9fa4ede8d0d34686314a7075f62a1502217b227991d9ca4245"},
+ {file = "grpcio-1.56.0-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:fdc3a895791af4addbb826808d4c9c35917c59bb5c430d729f44224e51c92d61"},
+ {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:4f84a6fd4482e5fe73b297d4874b62a535bc75dc6aec8e9fe0dc88106cd40397"},
+ {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14e70b4dda3183abea94c72d41d5930c333b21f8561c1904a372d80370592ef3"},
+ {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b5ce42a5ebe3e04796246ba50357f1813c44a6efe17a37f8dc7a5c470377312"},
+ {file = "grpcio-1.56.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8219f17baf069fe8e42bd8ca0b312b875595e43a70cabf397be4fda488e2f27d"},
+ {file = "grpcio-1.56.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:defdd14b518e6e468466f799aaa69db0355bca8d3a5ea75fb912d28ba6f8af31"},
+ {file = "grpcio-1.56.0-cp311-cp311-win32.whl", hash = "sha256:50f4daa698835accbbcc60e61e0bc29636c0156ddcafb3891c987e533a0031ba"},
+ {file = "grpcio-1.56.0-cp311-cp311-win_amd64.whl", hash = "sha256:59c4e606993a47146fbeaf304b9e78c447f5b9ee5641cae013028c4cca784617"},
+ {file = "grpcio-1.56.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:b1f4b6f25a87d80b28dd6d02e87d63fe1577fe6d04a60a17454e3f8077a38279"},
+ {file = "grpcio-1.56.0-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:c2148170e01d464d41011a878088444c13413264418b557f0bdcd1bf1b674a0e"},
+ {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:0409de787ebbf08c9d2bca2bcc7762c1efe72eada164af78b50567a8dfc7253c"},
+ {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66f0369d27f4c105cd21059d635860bb2ea81bd593061c45fb64875103f40e4a"},
+ {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38fdf5bd0a1c754ce6bf9311a3c2c7ebe56e88b8763593316b69e0e9a56af1de"},
+ {file = "grpcio-1.56.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:79d4c5911d12a7aa671e5eb40cbb50a830396525014d2d6f254ea2ba180ce637"},
+ {file = "grpcio-1.56.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:5d2fc471668a7222e213f86ef76933b18cdda6a51ea1322034478df8c6519959"},
+ {file = "grpcio-1.56.0-cp37-cp37m-win_amd64.whl", hash = "sha256:991224fd485e088d3cb5e34366053691a4848a6b7112b8f5625a411305c26691"},
+ {file = "grpcio-1.56.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c6f36621aabecbaff3e70c4d1d924c76c8e6a7ffec60c331893640a4af0a8037"},
+ {file = "grpcio-1.56.0-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:1eadd6de258901929223f422ffed7f8b310c0323324caf59227f9899ea1b1674"},
+ {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:72836b5a1d4f508ffbcfe35033d027859cc737972f9dddbe33fb75d687421e2e"},
+ {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f92a99ab0c7772fb6859bf2e4f44ad30088d18f7c67b83205297bfb229e0d2cf"},
+ {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa08affbf672d051cd3da62303901aeb7042a2c188c03b2c2a2d346fc5e81c14"},
+ {file = "grpcio-1.56.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2db108b4c8e29c145e95b0226973a66d73ae3e3e7fae00329294af4e27f1c42"},
+ {file = "grpcio-1.56.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8674fdbd28266d8efbcddacf4ec3643f76fe6376f73283fd63a8374c14b0ef7c"},
+ {file = "grpcio-1.56.0-cp38-cp38-win32.whl", hash = "sha256:bd55f743e654fb050c665968d7ec2c33f03578a4bbb163cfce38024775ff54cc"},
+ {file = "grpcio-1.56.0-cp38-cp38-win_amd64.whl", hash = "sha256:c63bc5ac6c7e646c296fed9139097ae0f0e63f36f0864d7ce431cce61fe0118a"},
+ {file = "grpcio-1.56.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:c0bc9dda550785d23f4f025be614b7faa8d0293e10811f0f8536cf50435b7a30"},
+ {file = "grpcio-1.56.0-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d596408bab632ec7b947761e83ce6b3e7632e26b76d64c239ba66b554b7ee286"},
+ {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:76b6e6e1ee9bda32e6e933efd61c512e9a9f377d7c580977f090d1a9c78cca44"},
+ {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7beb84ebd0a3f732625124b73969d12b7350c5d9d64ddf81ae739bbc63d5b1ed"},
+ {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83ec714bbbe9b9502177c842417fde39f7a267031e01fa3cd83f1ca49688f537"},
+ {file = "grpcio-1.56.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4feee75565d1b5ab09cb3a5da672b84ca7f6dd80ee07a50f5537207a9af543a4"},
+ {file = "grpcio-1.56.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b4638a796778329cc8e142e4f57c705adb286b3ba64e00b0fa91eeb919611be8"},
+ {file = "grpcio-1.56.0-cp39-cp39-win32.whl", hash = "sha256:437af5a7673bca89c4bc0a993382200592d104dd7bf55eddcd141cef91f40bab"},
+ {file = "grpcio-1.56.0-cp39-cp39-win_amd64.whl", hash = "sha256:4241a1c2c76e748023c834995cd916570e7180ee478969c2d79a60ce007bc837"},
+ {file = "grpcio-1.56.0.tar.gz", hash = "sha256:4c08ee21b3d10315b8dc26f6c13917b20ed574cdbed2d2d80c53d5508fdcc0f2"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.56.0)"]
+
+[[package]]
+name = "grpcio"
+version = "1.59.3"
+description = "HTTP/2-based RPC framework"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "grpcio-1.59.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:aca028a6c7806e5b61e5f9f4232432c52856f7fcb98e330b20b6bc95d657bdcc"},
+ {file = "grpcio-1.59.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19ad26a7967f7999c8960d2b9fe382dae74c55b0c508c613a6c2ba21cddf2354"},
+ {file = "grpcio-1.59.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:72b71dad2a3d1650e69ad42a5c4edbc59ee017f08c32c95694172bc501def23c"},
+ {file = "grpcio-1.59.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0f0a11d82d0253656cc42e04b6a149521e02e755fe2e4edd21123de610fd1d4"},
+ {file = "grpcio-1.59.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60cddafb70f9a2c81ba251b53b4007e07cca7389e704f86266e22c4bffd8bf1d"},
+ {file = "grpcio-1.59.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6c75a1fa0e677c1d2b6d4196ad395a5c381dfb8385f07ed034ef667cdcdbcc25"},
+ {file = "grpcio-1.59.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e1d8e01438d5964a11167eec1edb5f85ed8e475648f36c834ed5db4ffba24ac8"},
+ {file = "grpcio-1.59.3-cp310-cp310-win32.whl", hash = "sha256:c4b0076f0bf29ee62335b055a9599f52000b7941f577daa001c7ef961a1fbeab"},
+ {file = "grpcio-1.59.3-cp310-cp310-win_amd64.whl", hash = "sha256:b1f00a3e6e0c3dccccffb5579fc76ebfe4eb40405ba308505b41ef92f747746a"},
+ {file = "grpcio-1.59.3-cp311-cp311-linux_armv7l.whl", hash = "sha256:3996aaa21231451161dc29df6a43fcaa8b332042b6150482c119a678d007dd86"},
+ {file = "grpcio-1.59.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:cb4e9cbd9b7388fcb06412da9f188c7803742d06d6f626304eb838d1707ec7e3"},
+ {file = "grpcio-1.59.3-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:8022ca303d6c694a0d7acfb2b472add920217618d3a99eb4b14edc7c6a7e8fcf"},
+ {file = "grpcio-1.59.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b36683fad5664283755a7f4e2e804e243633634e93cd798a46247b8e54e3cb0d"},
+ {file = "grpcio-1.59.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8239b853226e4824e769517e1b5232e7c4dda3815b200534500338960fcc6118"},
+ {file = "grpcio-1.59.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0511af8653fbda489ff11d542a08505d56023e63cafbda60e6e00d4e0bae86ea"},
+ {file = "grpcio-1.59.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e78dc982bda74cef2ddfce1c91d29b96864c4c680c634e279ed204d51e227473"},
+ {file = "grpcio-1.59.3-cp311-cp311-win32.whl", hash = "sha256:6a5c3a96405966c023e139c3bcccb2c7c776a6f256ac6d70f8558c9041bdccc3"},
+ {file = "grpcio-1.59.3-cp311-cp311-win_amd64.whl", hash = "sha256:ed26826ee423b11477297b187371cdf4fa1eca874eb1156422ef3c9a60590dd9"},
+ {file = "grpcio-1.59.3-cp312-cp312-linux_armv7l.whl", hash = "sha256:45dddc5cb5227d30fa43652d8872dc87f086d81ab4b500be99413bad0ae198d7"},
+ {file = "grpcio-1.59.3-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:1736496d74682e53dd0907fd515f2694d8e6a96c9a359b4080b2504bf2b2d91b"},
+ {file = "grpcio-1.59.3-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ddbd1a16138e52e66229047624de364f88a948a4d92ba20e4e25ad7d22eef025"},
+ {file = "grpcio-1.59.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fcfa56f8d031ffda902c258c84c4b88707f3a4be4827b4e3ab8ec7c24676320d"},
+ {file = "grpcio-1.59.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2eb8f0c7c0c62f7a547ad7a91ba627a5aa32a5ae8d930783f7ee61680d7eb8d"},
+ {file = "grpcio-1.59.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8d993399cc65e3a34f8fd48dd9ad7a376734564b822e0160dd18b3d00c1a33f9"},
+ {file = "grpcio-1.59.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c0bd141f4f41907eb90bda74d969c3cb21c1c62779419782a5b3f5e4b5835718"},
+ {file = "grpcio-1.59.3-cp312-cp312-win32.whl", hash = "sha256:33b8fd65d4e97efa62baec6171ce51f9cf68f3a8ba9f866f4abc9d62b5c97b79"},
+ {file = "grpcio-1.59.3-cp312-cp312-win_amd64.whl", hash = "sha256:0e735ed002f50d4f3cb9ecfe8ac82403f5d842d274c92d99db64cfc998515e07"},
+ {file = "grpcio-1.59.3-cp37-cp37m-linux_armv7l.whl", hash = "sha256:ea40ce4404e7cca0724c91a7404da410f0144148fdd58402a5942971e3469b94"},
+ {file = "grpcio-1.59.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83113bcc393477b6f7342b9f48e8a054330c895205517edc66789ceea0796b53"},
+ {file = "grpcio-1.59.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:73afbac602b8f1212a50088193601f869b5073efa9855b3e51aaaec97848fc8a"},
+ {file = "grpcio-1.59.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:575d61de1950b0b0699917b686b1ca108690702fcc2df127b8c9c9320f93e069"},
+ {file = "grpcio-1.59.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cd76057b5c9a4d68814610ef9226925f94c1231bbe533fdf96f6181f7d2ff9e"},
+ {file = "grpcio-1.59.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:95d6fd804c81efe4879e38bfd84d2b26e339a0a9b797e7615e884ef4686eb47b"},
+ {file = "grpcio-1.59.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0d42048b8a3286ea4134faddf1f9a59cf98192b94aaa10d910a25613c5eb5bfb"},
+ {file = "grpcio-1.59.3-cp37-cp37m-win_amd64.whl", hash = "sha256:4619fea15c64bcdd9d447cdbdde40e3d5f1da3a2e8ae84103d94a9c1df210d7e"},
+ {file = "grpcio-1.59.3-cp38-cp38-linux_armv7l.whl", hash = "sha256:95b5506e70284ac03b2005dd9ffcb6708c9ae660669376f0192a710687a22556"},
+ {file = "grpcio-1.59.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:9e17660947660ccfce56c7869032910c179a5328a77b73b37305cd1ee9301c2e"},
+ {file = "grpcio-1.59.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:00912ce19914d038851be5cd380d94a03f9d195643c28e3ad03d355cc02ce7e8"},
+ {file = "grpcio-1.59.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e58b3cadaa3c90f1efca26ba33e0d408b35b497307027d3d707e4bcd8de862a6"},
+ {file = "grpcio-1.59.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d787ecadea865bdf78f6679f6f5bf4b984f18f659257ba612979df97a298b3c3"},
+ {file = "grpcio-1.59.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0814942ba1bba269db4e760a34388640c601dece525c6a01f3b4ff030cc0db69"},
+ {file = "grpcio-1.59.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fb111aa99d3180c361a35b5ae1e2c63750220c584a1344229abc139d5c891881"},
+ {file = "grpcio-1.59.3-cp38-cp38-win32.whl", hash = "sha256:eb8ba504c726befe40a356ecbe63c6c3c64c9a439b3164f5a718ec53c9874da0"},
+ {file = "grpcio-1.59.3-cp38-cp38-win_amd64.whl", hash = "sha256:cdbc6b32fadab9bebc6f49d3e7ec4c70983c71e965497adab7f87de218e84391"},
+ {file = "grpcio-1.59.3-cp39-cp39-linux_armv7l.whl", hash = "sha256:c82ca1e4be24a98a253d6dbaa216542e4163f33f38163fc77964b0f0d255b552"},
+ {file = "grpcio-1.59.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:36636babfda14f9e9687f28d5b66d349cf88c1301154dc71c6513de2b6c88c59"},
+ {file = "grpcio-1.59.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5f9b2e591da751ac7fdd316cc25afafb7a626dededa9b414f90faad7f3ccebdb"},
+ {file = "grpcio-1.59.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a93a82876a4926bf451db82ceb725bd87f42292bacc94586045261f501a86994"},
+ {file = "grpcio-1.59.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce31fa0bfdd1f2bb15b657c16105c8652186eab304eb512e6ae3b99b2fdd7d13"},
+ {file = "grpcio-1.59.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:16da0e40573962dab6cba16bec31f25a4f468e6d05b658e589090fe103b03e3d"},
+ {file = "grpcio-1.59.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d1a17372fd425addd5812049fa7374008ffe689585f27f802d0935522cf4b7"},
+ {file = "grpcio-1.59.3-cp39-cp39-win32.whl", hash = "sha256:52cc38a7241b5f7b4a91aaf9000fdd38e26bb00d5e8a71665ce40cfcee716281"},
+ {file = "grpcio-1.59.3-cp39-cp39-win_amd64.whl", hash = "sha256:b491e5bbcad3020a96842040421e508780cade35baba30f402df9d321d1c423e"},
+ {file = "grpcio-1.59.3.tar.gz", hash = "sha256:7800f99568a74a06ebdccd419dd1b6e639b477dcaf6da77ea702f8fb14ce5f80"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.59.3)"]
+
+[[package]]
+name = "grpcio-status"
+version = "1.48.2"
+description = "Status proto mapping for gRPC"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "grpcio-status-1.48.2.tar.gz", hash = "sha256:53695f45da07437b7c344ee4ef60d370fd2850179f5a28bb26d8e2aa1102ec11"},
+ {file = "grpcio_status-1.48.2-py3-none-any.whl", hash = "sha256:2c33bbdbe20188b2953f46f31af669263b6ee2a9b2d38fa0d36ee091532e21bf"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.5.5"
+grpcio = ">=1.48.2"
+protobuf = ">=3.12.0"
+
+[[package]]
+name = "h11"
+version = "0.14.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
+ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
+]
+
+[[package]]
+name = "hologres-vector"
+version = "0.0.6"
+description = ""
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "hologres_vector-0.0.6-py3-none-any.whl", hash = "sha256:c506eaafd9ae8c529955605fae71856e95191a64dde144d0a25b06536e6544a4"},
+ {file = "hologres_vector-0.0.6.tar.gz", hash = "sha256:13251b74bcb9ef2af61cc39c6f155e16452e03891c2f0a07f708f0157baf7b08"},
+]
+
+[package.dependencies]
+psycopg2-binary = "*"
+typing = "*"
+uuid = "*"
+
+[[package]]
+name = "html2text"
+version = "2020.1.16"
+description = "Turn HTML into equivalent Markdown-structured text."
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "html2text-2020.1.16-py3-none-any.whl", hash = "sha256:c7c629882da0cf377d66f073329ccf34a12ed2adf0169b9285ae4e63ef54c82b"},
+ {file = "html2text-2020.1.16.tar.gz", hash = "sha256:e296318e16b059ddb97f7a8a1d6a5c1d7af4544049a01e261731d2d5cc277bbb"},
+]
+
+[[package]]
+name = "httpcore"
+version = "0.17.3"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "httpcore-0.17.3-py3-none-any.whl", hash = "sha256:c2789b767ddddfa2a5782e3199b2b7f6894540b17b16ec26b2c4d8e103510b87"},
+ {file = "httpcore-0.17.3.tar.gz", hash = "sha256:a6f30213335e34c1ade7be6ec7c47f19f50c56db36abef1a9dfa3815b1cb3888"},
+]
+
+[package.dependencies]
+anyio = ">=3.0,<5.0"
+certifi = "*"
+h11 = ">=0.13,<0.15"
+sniffio = "==1.*"
+
+[package.extras]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+
+[[package]]
+name = "httpx"
+version = "0.24.1"
+description = "The next generation HTTP client."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "httpx-0.24.1-py3-none-any.whl", hash = "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd"},
+ {file = "httpx-0.24.1.tar.gz", hash = "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd"},
+]
+
+[package.dependencies]
+certifi = "*"
+httpcore = ">=0.15.0,<0.18.0"
+idna = "*"
+sniffio = "*"
+
+[package.extras]
+brotli = ["brotli", "brotlicffi"]
+cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+
+[[package]]
+name = "httpx-sse"
+version = "0.3.1"
+description = "Consume Server-Sent Event (SSE) messages with HTTPX."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "httpx-sse-0.3.1.tar.gz", hash = "sha256:3bb3289b2867f50cbdb2fee3eeeefecb1e86653122e164faac0023f1ffc88aea"},
+ {file = "httpx_sse-0.3.1-py3-none-any.whl", hash = "sha256:7376dd88732892f9b6b549ac0ad05a8e2341172fe7dcf9f8f9c8050934297316"},
+]
+
+[[package]]
+name = "huggingface-hub"
+version = "0.19.4"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+ {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"},
+ {file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = ">=2023.5.0"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
+quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
+
+[[package]]
+name = "humanfriendly"
+version = "10.0"
+description = "Human friendly output for text interfaces using Python"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
+ {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
+]
+
+[package.dependencies]
+pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
+
+[[package]]
+name = "idna"
+version = "3.6"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"},
+ {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"},
+]
+
+[[package]]
+name = "importlib-metadata"
+version = "6.11.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"},
+ {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+
+[[package]]
+name = "importlib-resources"
+version = "6.1.1"
+description = "Read resources from Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"},
+ {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"},
+]
+
+[package.dependencies]
+zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"]
+
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
+[[package]]
+name = "ipykernel"
+version = "6.27.1"
+description = "IPython Kernel for Jupyter"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "ipykernel-6.27.1-py3-none-any.whl", hash = "sha256:dab88b47f112f9f7df62236511023c9bdeef67abc73af7c652e4ce4441601686"},
+ {file = "ipykernel-6.27.1.tar.gz", hash = "sha256:7d5d594b6690654b4d299edba5e872dc17bb7396a8d0609c97cb7b8a1c605de6"},
+]
+
+[package.dependencies]
+appnope = {version = "*", markers = "platform_system == \"Darwin\""}
+comm = ">=0.1.1"
+debugpy = ">=1.6.5"
+ipython = ">=7.23.1"
+jupyter-client = ">=6.1.12"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+matplotlib-inline = ">=0.1"
+nest-asyncio = "*"
+packaging = "*"
+psutil = "*"
+pyzmq = ">=20"
+tornado = ">=6.1"
+traitlets = ">=5.4.0"
+
+[package.extras]
+cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"]
+pyqt5 = ["pyqt5"]
+pyside6 = ["pyside6"]
+test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "ipython"
+version = "8.12.3"
+description = "IPython: Productive Interactive Computing"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "ipython-8.12.3-py3-none-any.whl", hash = "sha256:b0340d46a933d27c657b211a329d0be23793c36595acf9e6ef4164bc01a1804c"},
+ {file = "ipython-8.12.3.tar.gz", hash = "sha256:3910c4b54543c2ad73d06579aa771041b7d5707b033bd488669b4cf544e3b363"},
+]
+
+[package.dependencies]
+appnope = {version = "*", markers = "sys_platform == \"darwin\""}
+backcall = "*"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+decorator = "*"
+jedi = ">=0.16"
+matplotlib-inline = "*"
+pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""}
+pickleshare = "*"
+prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0"
+pygments = ">=2.4.0"
+stack-data = "*"
+traitlets = ">=5"
+typing-extensions = {version = "*", markers = "python_version < \"3.10\""}
+
+[package.extras]
+all = ["black", "curio", "docrepr", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"]
+black = ["black"]
+doc = ["docrepr", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"]
+kernel = ["ipykernel"]
+nbconvert = ["nbconvert"]
+nbformat = ["nbformat"]
+notebook = ["ipywidgets", "notebook"]
+parallel = ["ipyparallel"]
+qtconsole = ["qtconsole"]
+test = ["pytest (<7.1)", "pytest-asyncio", "testpath"]
+test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"]
+
+[[package]]
+name = "ipywidgets"
+version = "8.1.1"
+description = "Jupyter interactive widgets"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"},
+ {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"},
+]
+
+[package.dependencies]
+comm = ">=0.1.3"
+ipython = ">=6.1.0"
+jupyterlab-widgets = ">=3.0.9,<3.1.0"
+traitlets = ">=4.3.1"
+widgetsnbextension = ">=4.0.9,<4.1.0"
+
+[package.extras]
+test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
+
+[[package]]
+name = "isoduration"
+version = "20.11.0"
+description = "Operations with ISO 8601 durations"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042"},
+ {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"},
+]
+
+[package.dependencies]
+arrow = ">=0.15.0"
+
+[[package]]
+name = "javelin-sdk"
+version = "0.1.8"
+description = "Python client for Javelin"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "javelin_sdk-0.1.8-py3-none-any.whl", hash = "sha256:7843e278f99fa04fcc659b31844f6205141b956e24f331a1cac1ae30d9eb3a55"},
+ {file = "javelin_sdk-0.1.8.tar.gz", hash = "sha256:57fa669c68f75296fdce20242023429a79755be22e0d3182dbad62d8f6bb1dd7"},
+]
+
+[package.dependencies]
+httpx = ">=0.24.0,<0.25.0"
+pydantic = ">=1.10.7,<2.0.0"
+
+[[package]]
+name = "jedi"
+version = "0.19.1"
+description = "An autocompletion tool for Python that can be used for text editors."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"},
+ {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"},
+]
+
+[package.dependencies]
+parso = ">=0.8.3,<0.9.0"
+
+[package.extras]
+docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"]
+qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
+testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
+
+[[package]]
+name = "jieba3k"
+version = "0.35.1"
+description = "Chinese Words Segementation Utilities"
+optional = true
+python-versions = "*"
+files = [
+ {file = "jieba3k-0.35.1.zip", hash = "sha256:980a4f2636b778d312518066be90c7697d410dd5a472385f5afced71a2db1c10"},
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.2"
+description = "A very fast and expressive template engine."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
+ {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+
+[[package]]
+name = "joblib"
+version = "1.3.2"
+description = "Lightweight pipelining with Python functions"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
+ {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
+]
+
+[[package]]
+name = "jq"
+version = "1.6.0"
+description = "jq is a lightweight and flexible JSON processor."
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "jq-1.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5773851cfb9ec6525f362f5bf7f18adab5c1fd1f0161c3599264cd0118c799da"},
+ {file = "jq-1.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a758df4eae767a21ebd8466dfd0066d99c9741d9f7fd4a7e1d5b5227e1924af7"},
+ {file = "jq-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15cf9dd3e7fb40d029f12f60cf418374c0b830a6ea6267dd285b48809069d6af"},
+ {file = "jq-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7e768cf5c25d703d944ef81c787d745da0eb266a97768f3003f91c4c828118d"},
+ {file = "jq-1.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:85a697b3cdc65e787f90faa1237caa44c117b6b2853f21263c3f0b16661b192c"},
+ {file = "jq-1.6.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:944e081c328501ddc0a22a8f08196df72afe7910ca11e1a1f21244410dbdd3b3"},
+ {file = "jq-1.6.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:09262d0e0cafb03acc968622e6450bb08abfb14c793bab47afd2732b47c655fd"},
+ {file = "jq-1.6.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:611f460f616f957d57e0da52ac6e1e6294b073c72a89651da5546a31347817bd"},
+ {file = "jq-1.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aba35b5cc07cd75202148e55f47ede3f4d0819b51c80f6d0c82a2ca47db07189"},
+ {file = "jq-1.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5ddb76b03610df19a53583348aed3604f21d0ba6b583ee8d079e8df026cd47"},
+ {file = "jq-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:872f322ff7bfd7daff41b7e8248d414a88722df0e82d1027f3b091a438543e63"},
+ {file = "jq-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca7a2982ff26f4620ac03099542a0230dabd8787af3f03ac93660598e26acbf0"},
+ {file = "jq-1.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:316affc6debf15eb05b7fd8e84ebf8993042b10b840e8d2a504659fb3ba07992"},
+ {file = "jq-1.6.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9bc42ade4de77fe4370c0e8e105ef10ad1821ef74d61dcc70982178b9ecfdc72"},
+ {file = "jq-1.6.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:02da59230912b886ed45489f3693ce75877f3e99c9e490c0a2dbcf0db397e0df"},
+ {file = "jq-1.6.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7ea39f89aa469eb12145ddd686248916cd6d186647aa40b319af8444b1f45a2d"},
+ {file = "jq-1.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6e9016f5ba064fabc527adb609ebae1f27cac20c8e0da990abae1cfb12eca706"},
+ {file = "jq-1.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:022be104a548f7fbddf103ce749937956df9d37a4f2f1650396dacad73bce7ee"},
+ {file = "jq-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d5a7f31f779e1aa3d165eaec237d74c7f5728227e81023a576c939ba3da34f8"},
+ {file = "jq-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f1533a2a15c42be3368878b4031b12f30441246878e0b5f6bedfdd7828cdb1f"},
+ {file = "jq-1.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8aa67a304e58aa85c550ec011a68754ae49abe227b37d63a351feef4eea4c7a7"},
+ {file = "jq-1.6.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0893d1590cfa6facaf787cc6c28ac51e47d0d06a303613f84d4943ac0ca98e32"},
+ {file = "jq-1.6.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:63db80b4803905a4f4f6c87a17aa1816c530f6262bc795773ebe60f8ab259092"},
+ {file = "jq-1.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e2c1f429e644cb962e846a6157b5352c3c556fbd0b22bba9fc2fea0710333369"},
+ {file = "jq-1.6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:bcf574f28809ec63b8df6456fdd4a981751b7466851e80621993b4e9d3e3c8ee"},
+ {file = "jq-1.6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49dbe0f003b411ca52b5d0afaf09cad8e430a1011181c86f2ef720a0956f31c1"},
+ {file = "jq-1.6.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5a9c4185269a5faf395aa7ca086c7b02c9c8b448d542be3b899041d06e0970"},
+ {file = "jq-1.6.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8265f3badcd125f234e55dfc02a078c5decdc6faafcd453fde04d4c0d2699886"},
+ {file = "jq-1.6.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c6c39b53d000d2f7f9f6338061942b83c9034d04f3bc99acae0867d23c9e7127"},
+ {file = "jq-1.6.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:9897931ea7b9a46f8165ee69737ece4a2e6dbc8e10ececb81f459d51d71401df"},
+ {file = "jq-1.6.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:6312237159e88e92775ea497e0c739590528062d4074544aacf12a08d252f966"},
+ {file = "jq-1.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:aa786a60bdd1a3571f092a4021dd9abf6c46798530fa99f19ecf4f0fceaa7eaf"},
+ {file = "jq-1.6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22495573d8221320d3433e1aeded40132bd8e1726845629558bd73aaa66eef7b"},
+ {file = "jq-1.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711eabc5d33ef3ec581e0744d9cff52f43896d84847a2692c287a0140a29c915"},
+ {file = "jq-1.6.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57e75c1563d083b0424690b3c3ef2bb519e670770931fe633101ede16615d6ee"},
+ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c795f175b1a13bd716a0c180d062cc8e305271f47bbdb9eb0f0f62f7e4f5def4"},
+ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
+ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
+ {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
+ {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
+ {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
+ {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
+ {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
+ {file = "jq-1.6.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1b3b95d5fd20e51f18a42647fdb52e5d8aaf150b7a666dd659cf282a2221ee3f"},
+ {file = "jq-1.6.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a8d98f72111043e75610cad7fa9ec5aec0b1ee2f7332dc7fd0f6603ea8144f8"},
+ {file = "jq-1.6.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:487483f10ae8f70e6acf7723f31b329736de4b421ce56b2f43b46d5cbd7337b0"},
+ {file = "jq-1.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:18a700f55b7ef83a1382edf0a48cb176b22bacd155e097375ef2345ff8621d97"},
+ {file = "jq-1.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68aec8534ac3c4705e524b4ef54f66b8bdc867df9e0af2c3895e82c6774b5374"},
+ {file = "jq-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7a164748dbd03bb06d23bab7ead7ba7e5c4fcfebea7b082bdcd21d14136931e"},
+ {file = "jq-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa22d24740276a8ce82411e4960ed2b5fab476230f913f9d9cf726f766a22208"},
+ {file = "jq-1.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c1a6fae1b74b3e0478e281eb6addedad7b32421221ac685e21c1d49af5e997f"},
+ {file = "jq-1.6.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ce628546c22792b8870b9815086f65873ebb78d7bf617b5a16dd839adba36538"},
+ {file = "jq-1.6.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7bb685f337cf5d4f4fe210c46220e31a7baec02a0ca0df3ace3dd4780328fc30"},
+ {file = "jq-1.6.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bdbbc509a35ee6082d79c1f25eb97c08f1c59043d21e0772cd24baa909505899"},
+ {file = "jq-1.6.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1b332dfdf0d81fb7faf3d12aabf997565d7544bec9812e0ac5ee55e60ef4df8c"},
+ {file = "jq-1.6.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a4f6ef8c0bd19beae56074c50026665d66345d1908f050e5c442ceac2efe398"},
+ {file = "jq-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5184c2fcca40f8f2ab1b14662721accf68b4b5e772e2f5336fec24aa58fe235a"},
+ {file = "jq-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689429fe1e07a2d6041daba2c21ced3a24895b2745326deb0c90ccab9386e116"},
+ {file = "jq-1.6.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8405d1c996c83711570f16aac32e3bf2c116d6fa4254a820276b87aed544d7e8"},
+ {file = "jq-1.6.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:138d56c7efc8bb162c1cfc3806bd6b4d779115943af36c9e3b8ca644dde856c2"},
+ {file = "jq-1.6.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd28f8395687e45bba56dc771284ebb6492b02037f74f450176c102f3f4e86a3"},
+ {file = "jq-1.6.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2c783288bf10e67aad321b58735e663f4975d7ddfbfb0a5bca8428eee283bde"},
+ {file = "jq-1.6.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:206391ac5b2eb556720b94f0f131558cbf8d82d8cc7e0404e733eeef48bcd823"},
+ {file = "jq-1.6.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:35090fea1283402abc3a13b43261468162199d8b5dcdaba2d1029e557ed23070"},
+ {file = "jq-1.6.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:201c6384603aec87a744ad7b393cc4f1c58ece23d6e0a6c216a47bfcc405d231"},
+ {file = "jq-1.6.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3d8b075351c29653f29a1fec5d31bc88aa198a0843c0a9550b9be74d8fab33b"},
+ {file = "jq-1.6.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:132e41f6e988c42b91c04b1b60dd8fa185a5c0681de5438ea1e6c64f5329768c"},
+ {file = "jq-1.6.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1cb4751808b1d0dbddd37319e0c574fb0c3a29910d52ba35890b1343a1f1e59"},
+ {file = "jq-1.6.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bd158911ed5f5c644f557ad94d6424c411560632a885eae47d105f290f0109cb"},
+ {file = "jq-1.6.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:64bc09ae6a9d9b82b78e15d142f90b816228bd3ee48833ddca3ff8c08e163fa7"},
+ {file = "jq-1.6.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4eed167322662f4b7e65235723c54aa6879f6175b6f9b68bc24887549637ffb"},
+ {file = "jq-1.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64bb4b305e2fabe5b5161b599bf934aceb0e0e7d3dd8f79246737ea91a2bc9ae"},
+ {file = "jq-1.6.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:165bfbe29bf73878d073edf75f384b7da8a9657ba0ab9fb1e5fe6be65ab7debb"},
+ {file = "jq-1.6.0.tar.gz", hash = "sha256:c7711f0c913a826a00990736efa6ffc285f8ef433414516bb14b7df971d6c1ea"},
+]
+
+[[package]]
+name = "json5"
+version = "0.9.14"
+description = "A Python implementation of the JSON5 data format."
+optional = false
+python-versions = "*"
+files = [
+ {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"},
+ {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"},
+]
+
+[package.extras]
+dev = ["hypothesis"]
+
+[[package]]
+name = "jsonable"
+version = "0.3.1"
+description = "An abstract class that supports jsonserialization/deserialization."
+optional = true
+python-versions = "*"
+files = [
+ {file = "jsonable-0.3.1-py2.py3-none-any.whl", hash = "sha256:f7754dd27b4734e42e7f8a61c2336bc98082f715e31e29a061a95843b102dc3a"},
+ {file = "jsonable-0.3.1.tar.gz", hash = "sha256:137b676e8e5819fa58518678c3d1f5463cab7e8466f69b3641cbc438042eaee4"},
+]
+
+[[package]]
+name = "jsonpatch"
+version = "1.33"
+description = "Apply JSON-Patches (RFC 6902)"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
+files = [
+ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
+ {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
+]
+
+[package.dependencies]
+jsonpointer = ">=1.9"
+
+[[package]]
+name = "jsonpointer"
+version = "2.4"
+description = "Identify specific nodes in a JSON document (RFC 6901)"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
+files = [
+ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
+ {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
+]
+
+[[package]]
+name = "jsonschema"
+version = "4.20.0"
+description = "An implementation of JSON Schema validation for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"},
+ {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
+idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
+importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
+isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
+jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
+jsonschema-specifications = ">=2023.03.6"
+pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
+referencing = ">=0.28.4"
+rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
+rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
+rpds-py = ">=0.7.1"
+uri-template = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
+webcolors = {version = ">=1.11", optional = true, markers = "extra == \"format-nongpl\""}
+
+[package.extras]
+format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
+format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2023.11.2"
+description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jsonschema_specifications-2023.11.2-py3-none-any.whl", hash = "sha256:e74ba7c0a65e8cb49dc26837d6cfe576557084a8b423ed16a420984228104f93"},
+ {file = "jsonschema_specifications-2023.11.2.tar.gz", hash = "sha256:9472fc4fea474cd74bea4a2b190daeccb5a9e4db2ea80efcf7a1b582fc9a81b8"},
+]
+
+[package.dependencies]
+importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
+referencing = ">=0.31.0"
+
+[[package]]
+name = "jupyter"
+version = "1.0.0"
+description = "Jupyter metapackage. Install all the Jupyter components in one go."
+optional = false
+python-versions = "*"
+files = [
+ {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"},
+ {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"},
+ {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"},
+]
+
+[package.dependencies]
+ipykernel = "*"
+ipywidgets = "*"
+jupyter-console = "*"
+nbconvert = "*"
+notebook = "*"
+qtconsole = "*"
+
+[[package]]
+name = "jupyter-client"
+version = "8.6.0"
+description = "Jupyter protocol implementation and client libraries"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"},
+ {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+python-dateutil = ">=2.8.2"
+pyzmq = ">=23.0"
+tornado = ">=6.2"
+traitlets = ">=5.3"
+
+[package.extras]
+docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
+test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"]
+
+[[package]]
+name = "jupyter-console"
+version = "6.6.3"
+description = "Jupyter terminal console"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"},
+ {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"},
+]
+
+[package.dependencies]
+ipykernel = ">=6.14"
+ipython = "*"
+jupyter-client = ">=7.0.0"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+prompt-toolkit = ">=3.0.30"
+pygments = "*"
+pyzmq = ">=17"
+traitlets = ">=5.4"
+
+[package.extras]
+test = ["flaky", "pexpect", "pytest"]
+
+[[package]]
+name = "jupyter-core"
+version = "5.5.0"
+description = "Jupyter core package. A base package on which Jupyter projects rely."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"},
+ {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"},
+]
+
+[package.dependencies]
+platformdirs = ">=2.5"
+pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""}
+traitlets = ">=5.3"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
+test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "jupyter-events"
+version = "0.9.0"
+description = "Jupyter Event System library"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"},
+ {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"},
+]
+
+[package.dependencies]
+jsonschema = {version = ">=4.18.0", extras = ["format-nongpl"]}
+python-json-logger = ">=2.0.4"
+pyyaml = ">=5.3"
+referencing = "*"
+rfc3339-validator = "*"
+rfc3986-validator = ">=0.1.1"
+traitlets = ">=5.3"
+
+[package.extras]
+cli = ["click", "rich"]
+docs = ["jupyterlite-sphinx", "myst-parser", "pydata-sphinx-theme", "sphinxcontrib-spelling"]
+test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "pytest-console-scripts", "rich"]
+
+[[package]]
+name = "jupyter-lsp"
+version = "2.2.1"
+description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter-lsp-2.2.1.tar.gz", hash = "sha256:b17fab6d70fe83c8896b0cff59237640038247c196056b43684a0902b6a9e0fb"},
+ {file = "jupyter_lsp-2.2.1-py3-none-any.whl", hash = "sha256:17a689910c5e4ae5e7d334b02f31d08ffbe98108f6f658fb05e4304b4345368b"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
+jupyter-server = ">=1.1.2"
+
+[[package]]
+name = "jupyter-server"
+version = "2.11.2"
+description = "The backendβi.e. core services, APIs, and REST endpointsβto Jupyter web applications."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_server-2.11.2-py3-none-any.whl", hash = "sha256:0c548151b54bcb516ca466ec628f7f021545be137d01b5467877e87f6fff4374"},
+ {file = "jupyter_server-2.11.2.tar.gz", hash = "sha256:0c99f9367b0f24141e527544522430176613f9249849be80504c6d2b955004bb"},
+]
+
+[package.dependencies]
+anyio = ">=3.1.0"
+argon2-cffi = "*"
+jinja2 = "*"
+jupyter-client = ">=7.4.4"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+jupyter-events = ">=0.9.0"
+jupyter-server-terminals = "*"
+nbconvert = ">=6.4.4"
+nbformat = ">=5.3.0"
+overrides = "*"
+packaging = "*"
+prometheus-client = "*"
+pywinpty = {version = "*", markers = "os_name == \"nt\""}
+pyzmq = ">=24"
+send2trash = ">=1.8.2"
+terminado = ">=0.8.3"
+tornado = ">=6.2.0"
+traitlets = ">=5.6.0"
+websocket-client = "*"
+
+[package.extras]
+docs = ["ipykernel", "jinja2", "jupyter-client", "jupyter-server", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"]
+test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.4)", "pytest-timeout", "requests"]
+
+[[package]]
+name = "jupyter-server-terminals"
+version = "0.4.4"
+description = "A Jupyter Server Extension Providing Terminals."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"},
+ {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"},
+]
+
+[package.dependencies]
+pywinpty = {version = ">=2.0.3", markers = "os_name == \"nt\""}
+terminado = ">=0.8.3"
+
+[package.extras]
+docs = ["jinja2", "jupyter-server", "mistune (<3.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"]
+test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"]
+
+[[package]]
+name = "jupyterlab"
+version = "4.0.9"
+description = "JupyterLab computational environment"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyterlab-4.0.9-py3-none-any.whl", hash = "sha256:9f6f8e36d543fdbcc3df961a1d6a3f524b4a4001be0327a398f68fa4e534107c"},
+ {file = "jupyterlab-4.0.9.tar.gz", hash = "sha256:9ebada41d52651f623c0c9f069ddb8a21d6848e4c887d8e5ddc0613166ed5c0b"},
+]
+
+[package.dependencies]
+async-lru = ">=1.0.0"
+importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
+importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""}
+ipykernel = "*"
+jinja2 = ">=3.0.3"
+jupyter-core = "*"
+jupyter-lsp = ">=2.0.0"
+jupyter-server = ">=2.4.0,<3"
+jupyterlab-server = ">=2.19.0,<3"
+notebook-shim = ">=0.2"
+packaging = "*"
+tomli = {version = "*", markers = "python_version < \"3.11\""}
+tornado = ">=6.2.0"
+traitlets = "*"
+
+[package.extras]
+dev = ["black[jupyter] (==23.10.1)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.1.4)"]
+docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"]
+docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"]
+test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"]
+
+[[package]]
+name = "jupyterlab-pygments"
+version = "0.3.0"
+description = "Pygments theme using JupyterLab CSS variables"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780"},
+ {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"},
+]
+
+[[package]]
+name = "jupyterlab-server"
+version = "2.25.2"
+description = "A set of server components for JupyterLab and JupyterLab like applications."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyterlab_server-2.25.2-py3-none-any.whl", hash = "sha256:5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf"},
+ {file = "jupyterlab_server-2.25.2.tar.gz", hash = "sha256:bd0ec7a99ebcedc8bcff939ef86e52c378e44c2707e053fcd81d046ce979ee63"},
+]
+
+[package.dependencies]
+babel = ">=2.10"
+importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
+jinja2 = ">=3.0.3"
+json5 = ">=0.9.0"
+jsonschema = ">=4.18.0"
+jupyter-server = ">=1.21,<3"
+packaging = ">=21.3"
+requests = ">=2.31"
+
+[package.extras]
+docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinxcontrib-openapi (>0.8)"]
+openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"]
+test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"]
+
+[[package]]
+name = "jupyterlab-widgets"
+version = "3.0.9"
+description = "Jupyter interactive widgets for JupyterLab"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"},
+ {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"},
+]
+
+[[package]]
+name = "langchain-core"
+version = "0.0.13-rc.2"
+description = "Building applications with LLMs through composability"
+optional = false
+python-versions = ">=3.8.1,<4.0"
+files = []
+develop = true
+
+[package.dependencies]
+anyio = ">=3,<5"
+jsonpatch = "^1.33"
+langsmith = "~0.0.63"
+packaging = "^23.2"
+pydantic = ">=1,<3"
+PyYAML = ">=5.3"
+requests = "^2"
+tenacity = "^8.1.0"
+
+[package.extras]
+extended-testing = ["jinja2 (>=3,<4)"]
+
+[package.source]
+type = "directory"
+url = "../core"
+
+[[package]]
+name = "langsmith"
+version = "0.0.69"
+description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
+optional = false
+python-versions = ">=3.8.1,<4.0"
+files = [
+ {file = "langsmith-0.0.69-py3-none-any.whl", hash = "sha256:49a2546bb83eedb0552673cf81a068bb08078d6d48471f4f1018e1d5c6aa46b1"},
+ {file = "langsmith-0.0.69.tar.gz", hash = "sha256:8fb5297f274db0576ec650d9bab0319acfbb6622d62bc5bb9fe31c6235dc0358"},
+]
+
+[package.dependencies]
+pydantic = ">=1,<3"
+requests = ">=2,<3"
+
+[[package]]
+name = "lark"
+version = "1.1.8"
+description = "a modern parsing library"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "lark-1.1.8-py3-none-any.whl", hash = "sha256:7d2c221a66a8165f3f81aacb958d26033d40d972fdb70213ab0a2e0627e29c86"},
+ {file = "lark-1.1.8.tar.gz", hash = "sha256:7ef424db57f59c1ffd6f0d4c2b705119927f566b68c0fe1942dddcc0e44391a5"},
+]
+
+[package.extras]
+atomic-cache = ["atomicwrites"]
+interegular = ["interegular (>=0.3.1,<0.4.0)"]
+nearley = ["js2py"]
+regex = ["regex"]
+
+[[package]]
+name = "lxml"
+version = "4.9.3"
+description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*"
+files = [
+ {file = "lxml-4.9.3-cp27-cp27m-macosx_11_0_x86_64.whl", hash = "sha256:b0a545b46b526d418eb91754565ba5b63b1c0b12f9bd2f808c852d9b4b2f9b5c"},
+ {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:075b731ddd9e7f68ad24c635374211376aa05a281673ede86cbe1d1b3455279d"},
+ {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1e224d5755dba2f4a9498e150c43792392ac9b5380aa1b845f98a1618c94eeef"},
+ {file = "lxml-4.9.3-cp27-cp27m-win32.whl", hash = "sha256:2c74524e179f2ad6d2a4f7caf70e2d96639c0954c943ad601a9e146c76408ed7"},
+ {file = "lxml-4.9.3-cp27-cp27m-win_amd64.whl", hash = "sha256:4f1026bc732b6a7f96369f7bfe1a4f2290fb34dce00d8644bc3036fb351a4ca1"},
+ {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0781a98ff5e6586926293e59480b64ddd46282953203c76ae15dbbbf302e8bb"},
+ {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cef2502e7e8a96fe5ad686d60b49e1ab03e438bd9123987994528febd569868e"},
+ {file = "lxml-4.9.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b86164d2cff4d3aaa1f04a14685cbc072efd0b4f99ca5708b2ad1b9b5988a991"},
+ {file = "lxml-4.9.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:42871176e7896d5d45138f6d28751053c711ed4d48d8e30b498da155af39aebd"},
+ {file = "lxml-4.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ae8b9c6deb1e634ba4f1930eb67ef6e6bf6a44b6eb5ad605642b2d6d5ed9ce3c"},
+ {file = "lxml-4.9.3-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:411007c0d88188d9f621b11d252cce90c4a2d1a49db6c068e3c16422f306eab8"},
+ {file = "lxml-4.9.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:cd47b4a0d41d2afa3e58e5bf1f62069255aa2fd6ff5ee41604418ca925911d76"},
+ {file = "lxml-4.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e2cb47860da1f7e9a5256254b74ae331687b9672dfa780eed355c4c9c3dbd23"},
+ {file = "lxml-4.9.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1247694b26342a7bf47c02e513d32225ededd18045264d40758abeb3c838a51f"},
+ {file = "lxml-4.9.3-cp310-cp310-win32.whl", hash = "sha256:cdb650fc86227eba20de1a29d4b2c1bfe139dc75a0669270033cb2ea3d391b85"},
+ {file = "lxml-4.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:97047f0d25cd4bcae81f9ec9dc290ca3e15927c192df17331b53bebe0e3ff96d"},
+ {file = "lxml-4.9.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:1f447ea5429b54f9582d4b955f5f1985f278ce5cf169f72eea8afd9502973dd5"},
+ {file = "lxml-4.9.3-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:57d6ba0ca2b0c462f339640d22882acc711de224d769edf29962b09f77129cbf"},
+ {file = "lxml-4.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9767e79108424fb6c3edf8f81e6730666a50feb01a328f4a016464a5893f835a"},
+ {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:71c52db65e4b56b8ddc5bb89fb2e66c558ed9d1a74a45ceb7dcb20c191c3df2f"},
+ {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d73d8ecf8ecf10a3bd007f2192725a34bd62898e8da27eb9d32a58084f93962b"},
+ {file = "lxml-4.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0a3d3487f07c1d7f150894c238299934a2a074ef590b583103a45002035be120"},
+ {file = "lxml-4.9.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e28c51fa0ce5674be9f560c6761c1b441631901993f76700b1b30ca6c8378d6"},
+ {file = "lxml-4.9.3-cp311-cp311-win32.whl", hash = "sha256:0bfd0767c5c1de2551a120673b72e5d4b628737cb05414f03c3277bf9bed3305"},
+ {file = "lxml-4.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:25f32acefac14ef7bd53e4218fe93b804ef6f6b92ffdb4322bb6d49d94cad2bc"},
+ {file = "lxml-4.9.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:d3ff32724f98fbbbfa9f49d82852b159e9784d6094983d9a8b7f2ddaebb063d4"},
+ {file = "lxml-4.9.3-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:48d6ed886b343d11493129e019da91d4039826794a3e3027321c56d9e71505be"},
+ {file = "lxml-4.9.3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9a92d3faef50658dd2c5470af249985782bf754c4e18e15afb67d3ab06233f13"},
+ {file = "lxml-4.9.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b4e4bc18382088514ebde9328da057775055940a1f2e18f6ad2d78aa0f3ec5b9"},
+ {file = "lxml-4.9.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fc9b106a1bf918db68619fdcd6d5ad4f972fdd19c01d19bdb6bf63f3589a9ec5"},
+ {file = "lxml-4.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:d37017287a7adb6ab77e1c5bee9bcf9660f90ff445042b790402a654d2ad81d8"},
+ {file = "lxml-4.9.3-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56dc1f1ebccc656d1b3ed288f11e27172a01503fc016bcabdcbc0978b19352b7"},
+ {file = "lxml-4.9.3-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:578695735c5a3f51569810dfebd05dd6f888147a34f0f98d4bb27e92b76e05c2"},
+ {file = "lxml-4.9.3-cp35-cp35m-win32.whl", hash = "sha256:704f61ba8c1283c71b16135caf697557f5ecf3e74d9e453233e4771d68a1f42d"},
+ {file = "lxml-4.9.3-cp35-cp35m-win_amd64.whl", hash = "sha256:c41bfca0bd3532d53d16fd34d20806d5c2b1ace22a2f2e4c0008570bf2c58833"},
+ {file = "lxml-4.9.3-cp36-cp36m-macosx_11_0_x86_64.whl", hash = "sha256:64f479d719dc9f4c813ad9bb6b28f8390360660b73b2e4beb4cb0ae7104f1c12"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:dd708cf4ee4408cf46a48b108fb9427bfa00b9b85812a9262b5c668af2533ea5"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c31c7462abdf8f2ac0577d9f05279727e698f97ecbb02f17939ea99ae8daa98"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e3cd95e10c2610c360154afdc2f1480aea394f4a4f1ea0a5eacce49640c9b190"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:4930be26af26ac545c3dffb662521d4e6268352866956672231887d18f0eaab2"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4aec80cde9197340bc353d2768e2a75f5f60bacda2bab72ab1dc499589b3878c"},
+ {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:14e019fd83b831b2e61baed40cab76222139926b1fb5ed0e79225bc0cae14584"},
+ {file = "lxml-4.9.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0c0850c8b02c298d3c7006b23e98249515ac57430e16a166873fc47a5d549287"},
+ {file = "lxml-4.9.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:aca086dc5f9ef98c512bac8efea4483eb84abbf926eaeedf7b91479feb092458"},
+ {file = "lxml-4.9.3-cp36-cp36m-win32.whl", hash = "sha256:50baa9c1c47efcaef189f31e3d00d697c6d4afda5c3cde0302d063492ff9b477"},
+ {file = "lxml-4.9.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bef4e656f7d98aaa3486d2627e7d2df1157d7e88e7efd43a65aa5dd4714916cf"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:46f409a2d60f634fe550f7133ed30ad5321ae2e6630f13657fb9479506b00601"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:4c28a9144688aef80d6ea666c809b4b0e50010a2aca784c97f5e6bf143d9f129"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:141f1d1a9b663c679dc524af3ea1773e618907e96075262726c7612c02b149a4"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:53ace1c1fd5a74ef662f844a0413446c0629d151055340e9893da958a374f70d"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:17a753023436a18e27dd7769e798ce302963c236bc4114ceee5b25c18c52c693"},
+ {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7d298a1bd60c067ea75d9f684f5f3992c9d6766fadbc0bcedd39750bf344c2f4"},
+ {file = "lxml-4.9.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:081d32421db5df44c41b7f08a334a090a545c54ba977e47fd7cc2deece78809a"},
+ {file = "lxml-4.9.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:23eed6d7b1a3336ad92d8e39d4bfe09073c31bfe502f20ca5116b2a334f8ec02"},
+ {file = "lxml-4.9.3-cp37-cp37m-win32.whl", hash = "sha256:1509dd12b773c02acd154582088820893109f6ca27ef7291b003d0e81666109f"},
+ {file = "lxml-4.9.3-cp37-cp37m-win_amd64.whl", hash = "sha256:120fa9349a24c7043854c53cae8cec227e1f79195a7493e09e0c12e29f918e52"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4d2d1edbca80b510443f51afd8496be95529db04a509bc8faee49c7b0fb6d2cc"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:8d7e43bd40f65f7d97ad8ef5c9b1778943d02f04febef12def25f7583d19baac"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:71d66ee82e7417828af6ecd7db817913cb0cf9d4e61aa0ac1fde0583d84358db"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:6fc3c450eaa0b56f815c7b62f2b7fba7266c4779adcf1cece9e6deb1de7305ce"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:65299ea57d82fb91c7f019300d24050c4ddeb7c5a190e076b5f48a2b43d19c42"},
+ {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:eadfbbbfb41b44034a4c757fd5d70baccd43296fb894dba0295606a7cf3124aa"},
+ {file = "lxml-4.9.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3e9bdd30efde2b9ccfa9cb5768ba04fe71b018a25ea093379c857c9dad262c40"},
+ {file = "lxml-4.9.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcdd00edfd0a3001e0181eab3e63bd5c74ad3e67152c84f93f13769a40e073a7"},
+ {file = "lxml-4.9.3-cp38-cp38-win32.whl", hash = "sha256:57aba1bbdf450b726d58b2aea5fe47c7875f5afb2c4a23784ed78f19a0462574"},
+ {file = "lxml-4.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:92af161ecbdb2883c4593d5ed4815ea71b31fafd7fd05789b23100d081ecac96"},
+ {file = "lxml-4.9.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:9bb6ad405121241e99a86efff22d3ef469024ce22875a7ae045896ad23ba2340"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8ed74706b26ad100433da4b9d807eae371efaa266ffc3e9191ea436087a9d6a7"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fbf521479bcac1e25a663df882c46a641a9bff6b56dc8b0fafaebd2f66fb231b"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:303bf1edce6ced16bf67a18a1cf8339d0db79577eec5d9a6d4a80f0fb10aa2da"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:5515edd2a6d1a5a70bfcdee23b42ec33425e405c5b351478ab7dc9347228f96e"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:690dafd0b187ed38583a648076865d8c229661ed20e48f2335d68e2cf7dc829d"},
+ {file = "lxml-4.9.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b6420a005548ad52154c8ceab4a1290ff78d757f9e5cbc68f8c77089acd3c432"},
+ {file = "lxml-4.9.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bb3bb49c7a6ad9d981d734ef7c7193bc349ac338776a0360cc671eaee89bcf69"},
+ {file = "lxml-4.9.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d27be7405547d1f958b60837dc4c1007da90b8b23f54ba1f8b728c78fdb19d50"},
+ {file = "lxml-4.9.3-cp39-cp39-win32.whl", hash = "sha256:8df133a2ea5e74eef5e8fc6f19b9e085f758768a16e9877a60aec455ed2609b2"},
+ {file = "lxml-4.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:4dd9a263e845a72eacb60d12401e37c616438ea2e5442885f65082c276dfb2b2"},
+ {file = "lxml-4.9.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6689a3d7fd13dc687e9102a27e98ef33730ac4fe37795d5036d18b4d527abd35"},
+ {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:f6bdac493b949141b733c5345b6ba8f87a226029cbabc7e9e121a413e49441e0"},
+ {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:05186a0f1346ae12553d66df1cfce6f251589fea3ad3da4f3ef4e34b2d58c6a3"},
+ {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c2006f5c8d28dee289f7020f721354362fa304acbaaf9745751ac4006650254b"},
+ {file = "lxml-4.9.3-pp38-pypy38_pp73-macosx_11_0_x86_64.whl", hash = "sha256:5c245b783db29c4e4fbbbfc9c5a78be496c9fea25517f90606aa1f6b2b3d5f7b"},
+ {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4fb960a632a49f2f089d522f70496640fdf1218f1243889da3822e0a9f5f3ba7"},
+ {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:50670615eaf97227d5dc60de2dc99fb134a7130d310d783314e7724bf163f75d"},
+ {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9719fe17307a9e814580af1f5c6e05ca593b12fb7e44fe62450a5384dbf61b4b"},
+ {file = "lxml-4.9.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3331bece23c9ee066e0fb3f96c61322b9e0f54d775fccefff4c38ca488de283a"},
+ {file = "lxml-4.9.3-pp39-pypy39_pp73-macosx_11_0_x86_64.whl", hash = "sha256:ed667f49b11360951e201453fc3967344d0d0263aa415e1619e85ae7fd17b4e0"},
+ {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8b77946fd508cbf0fccd8e400a7f71d4ac0e1595812e66025bac475a8e811694"},
+ {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e4da8ca0c0c0aea88fd46be8e44bd49716772358d648cce45fe387f7b92374a7"},
+ {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fe4bda6bd4340caa6e5cf95e73f8fea5c4bfc55763dd42f1b50a94c1b4a2fbd4"},
+ {file = "lxml-4.9.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f3df3db1d336b9356dd3112eae5f5c2b8b377f3bc826848567f10bfddfee77e9"},
+ {file = "lxml-4.9.3.tar.gz", hash = "sha256:48628bd53a426c9eb9bc066a923acaa0878d1e86129fd5359aee99285f4eed9c"},
+]
+
+[package.extras]
+cssselect = ["cssselect (>=0.7)"]
+html5 = ["html5lib"]
+htmlsoup = ["BeautifulSoup4"]
+source = ["Cython (>=0.29.35)"]
+
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+ {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[[package]]
+name = "markdownify"
+version = "0.11.6"
+description = "Convert HTML to markdown."
+optional = true
+python-versions = "*"
+files = [
+ {file = "markdownify-0.11.6-py3-none-any.whl", hash = "sha256:ba35fe289d5e9073bcd7d2cad629278fe25f1a93741fcdc0bfb4f009076d8324"},
+ {file = "markdownify-0.11.6.tar.gz", hash = "sha256:009b240e0c9f4c8eaf1d085625dcd4011e12f0f8cec55dedf9ea6f7655e49bfe"},
+]
+
+[package.dependencies]
+beautifulsoup4 = ">=4.9,<5"
+six = ">=1.15,<2"
+
+[[package]]
+name = "markupsafe"
+version = "2.1.3"
+description = "Safely add untrusted strings to HTML/XML markup."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"},
+ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
+]
+
+[[package]]
+name = "marshmallow"
+version = "3.20.1"
+description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "marshmallow-3.20.1-py3-none-any.whl", hash = "sha256:684939db93e80ad3561392f47be0230743131560a41c5110684c16e21ade0a5c"},
+ {file = "marshmallow-3.20.1.tar.gz", hash = "sha256:5d2371bbe42000f2b3fb5eaa065224df7d8f8597bc19a1bbfa5bfe7fba8da889"},
+]
+
+[package.dependencies]
+packaging = ">=17.0"
+
+[package.extras]
+dev = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)", "pytest", "pytz", "simplejson", "tox"]
+docs = ["alabaster (==0.7.13)", "autodocsumm (==0.2.11)", "sphinx (==7.0.1)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"]
+lint = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)"]
+tests = ["pytest", "pytz", "simplejson"]
+
+[[package]]
+name = "matplotlib-inline"
+version = "0.1.6"
+description = "Inline Matplotlib backend for Jupyter"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"},
+ {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"},
+]
+
+[package.dependencies]
+traitlets = "*"
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
+[[package]]
+name = "mistune"
+version = "3.0.2"
+description = "A sane and fast Markdown parser with useful plugins and renderers"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"},
+ {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"},
+]
+
+[[package]]
+name = "mlflow-skinny"
+version = "2.8.1"
+description = "MLflow: A Platform for ML Development and Productionization"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "mlflow-skinny-2.8.1.tar.gz", hash = "sha256:8f46462e2df5ffd93a7f7d92ad1d3d7335adbe5e8e999543a3879963ae576d33"},
+ {file = "mlflow_skinny-2.8.1-py3-none-any.whl", hash = "sha256:8e2a1a5b8f1e2a3437c1fab972115a4df25934cd07cd83b8eb70202af8ad814a"},
+]
+
+[package.dependencies]
+click = ">=7.0,<9"
+cloudpickle = "<3"
+databricks-cli = ">=0.8.7,<1"
+entrypoints = "<1"
+gitpython = ">=2.1.0,<4"
+importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<7"
+packaging = "<24"
+protobuf = ">=3.12.0,<5"
+pytz = "<2024"
+pyyaml = ">=5.1,<7"
+requests = ">=2.17.3,<3"
+sqlparse = ">=0.4.0,<1"
+
+[package.extras]
+aliyun-oss = ["aliyunstoreplugin"]
+databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"]
+extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"]
+gateway = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>=1.0,<3)", "uvicorn[standard] (<1)", "watchfiles (<1)"]
+sqlserver = ["mlflow-dbstore"]
+xethub = ["mlflow-xethub"]
+
+[[package]]
+name = "motor"
+version = "3.3.2"
+description = "Non-blocking MongoDB driver for Tornado or asyncio"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "motor-3.3.2-py3-none-any.whl", hash = "sha256:6fe7e6f0c4f430b9e030b9d22549b732f7c2226af3ab71ecc309e4a1b7d19953"},
+ {file = "motor-3.3.2.tar.gz", hash = "sha256:d2fc38de15f1c8058f389c1a44a4d4105c0405c48c061cd492a654496f7bc26a"},
+]
+
+[package.dependencies]
+pymongo = ">=4.5,<5"
+
+[package.extras]
+aws = ["pymongo[aws] (>=4.5,<5)"]
+encryption = ["pymongo[encryption] (>=4.5,<5)"]
+gssapi = ["pymongo[gssapi] (>=4.5,<5)"]
+ocsp = ["pymongo[ocsp] (>=4.5,<5)"]
+snappy = ["pymongo[snappy] (>=4.5,<5)"]
+srv = ["pymongo[srv] (>=4.5,<5)"]
+test = ["aiohttp (<3.8.6)", "mockupdb", "motor[encryption]", "pytest (>=7)", "tornado (>=5)"]
+zstd = ["pymongo[zstd] (>=4.5,<5)"]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+description = "Python library for arbitrary-precision floating-point arithmetic"
+optional = true
+python-versions = "*"
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[package.extras]
+develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
+docs = ["sphinx"]
+gmpy = ["gmpy2 (>=2.1.0a4)"]
+tests = ["pytest (>=4.6)"]
+
+[[package]]
+name = "msal"
+version = "1.25.0"
+description = "The Microsoft Authentication Library (MSAL) for Python library"
+optional = true
+python-versions = ">=2.7"
+files = [
+ {file = "msal-1.25.0-py2.py3-none-any.whl", hash = "sha256:386df621becb506bc315a713ec3d4d5b5d6163116955c7dde23622f156b81af6"},
+ {file = "msal-1.25.0.tar.gz", hash = "sha256:f44329fdb59f4f044c779164a34474b8a44ad9e4940afbc4c3a3a2bbe90324d9"},
+]
+
+[package.dependencies]
+cryptography = ">=0.6,<44"
+PyJWT = {version = ">=1.0.0,<3", extras = ["crypto"]}
+requests = ">=2.0.0,<3"
+
+[package.extras]
+broker = ["pymsalruntime (>=0.13.2,<0.14)"]
+
+[[package]]
+name = "multidict"
+version = "6.0.4"
+description = "multidict implementation"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"},
+ {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"},
+ {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"},
+ {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"},
+ {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"},
+ {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"},
+ {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"},
+ {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"},
+ {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"},
+ {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"},
+ {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"},
+ {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"},
+ {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"},
+ {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"},
+ {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"},
+ {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"},
+ {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"},
+ {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"},
+ {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"},
+ {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"},
+ {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"},
+ {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"},
+ {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"},
+ {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"},
+ {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"},
+ {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"},
+ {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"},
+ {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"},
+ {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"},
+ {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"},
+ {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"},
+ {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"},
+ {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"},
+ {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"},
+ {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"},
+ {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"},
+ {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"},
+ {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"},
+ {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"},
+ {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"},
+ {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"},
+ {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"},
+ {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"},
+ {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"},
+ {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"},
+ {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"},
+ {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"},
+ {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"},
+ {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"},
+ {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"},
+ {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"},
+ {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"},
+ {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"},
+ {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"},
+ {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"},
+ {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"},
+ {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"},
+ {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"},
+ {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"},
+ {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"},
+ {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"},
+ {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"},
+ {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"},
+ {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"},
+ {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"},
+ {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"},
+ {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"},
+ {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"},
+ {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"},
+ {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"},
+ {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"},
+ {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"},
+ {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"},
+ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
+]
+
+[[package]]
+name = "multiprocess"
+version = "0.70.15"
+description = "better multiprocessing and multithreading in Python"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"},
+ {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"},
+ {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"},
+ {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"},
+ {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"},
+ {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"},
+ {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"},
+ {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"},
+ {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"},
+ {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"},
+ {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"},
+ {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"},
+ {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"},
+ {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"},
+ {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"},
+ {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"},
+]
+
+[package.dependencies]
+dill = ">=0.3.7"
+
+[[package]]
+name = "mwcli"
+version = "0.0.3"
+description = "Utilities for processing MediaWiki on the command line."
+optional = true
+python-versions = "*"
+files = [
+ {file = "mwcli-0.0.3-py2.py3-none-any.whl", hash = "sha256:24a7e53730e6fa7e55626e4f2a61a0b016d5e0a9798306c1d8c71bcead0ab239"},
+ {file = "mwcli-0.0.3.tar.gz", hash = "sha256:00331bd0ff16b5721c9c6274d91e25fd355f45ec0773c8a0e3926eac058719a0"},
+]
+
+[package.dependencies]
+docopt = "*"
+mwxml = "*"
+para = "*"
+
+[[package]]
+name = "mwparserfromhell"
+version = "0.6.5"
+description = "MWParserFromHell is a parser for MediaWiki wikicode."
+optional = true
+python-versions = ">= 3.7"
+files = [
+ {file = "mwparserfromhell-0.6.5-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b3a941ea35fc4fb49fc8d9087490ee8d94e09fb8e08b3bca83fc99cb4577bb81"},
+ {file = "mwparserfromhell-0.6.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a3b27580eebda2685ab5e54381df0845f13acb8ca7d50f754378184756e13bf"},
+ {file = "mwparserfromhell-0.6.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6169f314d6c28f0f373b5b2b346c51058248c8897493ed7c490db7caa65ea729"},
+ {file = "mwparserfromhell-0.6.5-cp310-cp310-win32.whl", hash = "sha256:b60e575e1e5c17a2e316b12a143de04665c4b1189a61a3a534967d33b57394cd"},
+ {file = "mwparserfromhell-0.6.5-cp310-cp310-win_amd64.whl", hash = "sha256:30747186171f6c58858c04eb617dd82dff2ae06d6f9e1b94714698daa32bc664"},
+ {file = "mwparserfromhell-0.6.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:837e6adf0963ddf5317f789541ea109108515ccd2405cd1437ff8224294c3fa7"},
+ {file = "mwparserfromhell-0.6.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3ba57207582de52345e69187218bd35cf3675497fd383bc70e46c0c728d50f"},
+ {file = "mwparserfromhell-0.6.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7984215a21b0778b90724643df24e8dbb89aecb95af2ba56a42a1956fcbeb571"},
+ {file = "mwparserfromhell-0.6.5-cp311-cp311-win32.whl", hash = "sha256:0c055324ad12c80f1ee2175c1d1b29b997aab57f6010174e704de15fdcb1757b"},
+ {file = "mwparserfromhell-0.6.5-cp311-cp311-win_amd64.whl", hash = "sha256:f252f09c4bf5432bd91a6aa79c707753ff084454cb24f8b513187531d5f6295f"},
+ {file = "mwparserfromhell-0.6.5-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:61d1e01cc027fe3d94c7d3620cb6ea9648305795a66bb93747d418a15c0d1860"},
+ {file = "mwparserfromhell-0.6.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4932d27cd3f00b451579c97c31e45d1e236b643bb93eeddde8d4aca50d87e3e6"},
+ {file = "mwparserfromhell-0.6.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51e56993b7a351a44bdb9af7abbb72f3383fcb46f69e556f6116397598f6f3bb"},
+ {file = "mwparserfromhell-0.6.5-cp37-cp37m-win32.whl", hash = "sha256:eb1afb65e5b8a0e3eba35644347cd5304c6e7803571db042850dc0697bbe49a3"},
+ {file = "mwparserfromhell-0.6.5-cp37-cp37m-win_amd64.whl", hash = "sha256:05b8262dc13c83e023ea6d17e5e5bcef225c2c172621c71cad947958afbaf4e4"},
+ {file = "mwparserfromhell-0.6.5-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7408b3ce5f0b328e86be3809e906fc378767ef5396565b7411963452ad3bbf12"},
+ {file = "mwparserfromhell-0.6.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1660c5558cd9a32b3e72c0e3aabdd6729a013d8e1b5695d4bdb478f691d9657e"},
+ {file = "mwparserfromhell-0.6.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b205ce558596c35eacec863b3c88e9081872aa56b471ffd4f54162480d75f8d1"},
+ {file = "mwparserfromhell-0.6.5-cp38-cp38-win32.whl", hash = "sha256:b09a62cac76ae0cb0daef309a93ecc23d3fbcd8e68a646517c6ac8479c4cc5fe"},
+ {file = "mwparserfromhell-0.6.5-cp38-cp38-win_amd64.whl", hash = "sha256:2ecc86c6b29354adb472553bf982b6bd05fd21ac41c44d454d2aac06ca456163"},
+ {file = "mwparserfromhell-0.6.5-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:76234df2d138542ae839bebe53d4e4f59b286d0287101f54d1b84d9d193d5848"},
+ {file = "mwparserfromhell-0.6.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5be5f476bf4077bfc6fefcb3ccb21900f63b36c09ef0bb63667e21f09be2198"},
+ {file = "mwparserfromhell-0.6.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b4d5e87d2b405eb493f86a3f0e513d4e2c30edde6b8d3b4f7d2a53ffac5d81a"},
+ {file = "mwparserfromhell-0.6.5-cp39-cp39-win32.whl", hash = "sha256:7f6e5505014d0e97e29bc01304e8f6a8d782dec55c53492cc7ca03d2a6d1e445"},
+ {file = "mwparserfromhell-0.6.5-cp39-cp39-win_amd64.whl", hash = "sha256:153243177c4c242880e9c4547880e834f01d04625ad0bc175693255dfb22dae5"},
+ {file = "mwparserfromhell-0.6.5.tar.gz", hash = "sha256:2bad0bff614576399e4470d6400ba29c52d595682a4b8de642afbb5bebf4a346"},
+]
+
+[[package]]
+name = "mwtypes"
+version = "0.3.2"
+description = "A set of types for processing MediaWiki data."
+optional = true
+python-versions = "*"
+files = [
+ {file = "mwtypes-0.3.2-py2.py3-none-any.whl", hash = "sha256:d6f3cae90eea4c88bc260101c8a082fb0ab22cca88e7474657b28cd9538794f3"},
+ {file = "mwtypes-0.3.2.tar.gz", hash = "sha256:dc1176c5965629c123e859b319ae6151d4e385531e9a781604c0d4ca3434e399"},
+]
+
+[package.dependencies]
+jsonable = ">=0.3.0"
+
+[[package]]
+name = "mwxml"
+version = "0.3.3"
+description = "A set of utilities for processing MediaWiki XML dump data."
+optional = true
+python-versions = "*"
+files = [
+ {file = "mwxml-0.3.3-py2.py3-none-any.whl", hash = "sha256:9695848b8b6987b6f6addc2a8accba5b2bcbc543702598194e182b508ab568a9"},
+ {file = "mwxml-0.3.3.tar.gz", hash = "sha256:0848df0cf2e293718f554311acf4715bd679f639f4e52cbe47d8206589db1d31"},
+]
+
+[package.dependencies]
+jsonschema = ">=2.5.1"
+mwcli = ">=0.0.2"
+mwtypes = ">=0.3.0"
+para = ">=0.0.1"
+
+[[package]]
+name = "mypy"
+version = "0.991"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"},
+ {file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"},
+ {file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"},
+ {file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"},
+ {file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"},
+ {file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"},
+ {file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"},
+ {file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"},
+ {file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"},
+ {file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"},
+ {file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"},
+ {file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = "sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"},
+ {file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"},
+ {file = "mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"},
+ {file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"},
+ {file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"},
+ {file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"},
+ {file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"},
+ {file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"},
+ {file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"},
+ {file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"},
+ {file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"},
+ {file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"},
+ {file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"},
+ {file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"},
+ {file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"},
+ {file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"},
+ {file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"},
+ {file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"},
+ {file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=0.4.3"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = ">=3.10"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+install-types = ["pip"]
+python2 = ["typed-ast (>=1.4.0,<2)"]
+reports = ["lxml"]
+
+[[package]]
+name = "mypy-extensions"
+version = "1.0.0"
+description = "Type system extensions for programs checked with the mypy type checker."
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
+ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
+]
+
+[[package]]
+name = "mypy-protobuf"
+version = "3.3.0"
+description = "Generate mypy stub files from protobuf specs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "mypy-protobuf-3.3.0.tar.gz", hash = "sha256:24f3b0aecb06656e983f58e07c732a90577b9d7af3e1066fc2b663bbf0370248"},
+ {file = "mypy_protobuf-3.3.0-py3-none-any.whl", hash = "sha256:15604f6943b16c05db646903261e3b3e775cf7f7990b7c37b03d043a907b650d"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.4"
+types-protobuf = ">=3.19.12"
+
+[[package]]
+name = "nbclient"
+version = "0.9.0"
+description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
+optional = false
+python-versions = ">=3.8.0"
+files = [
+ {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"},
+ {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"},
+]
+
+[package.dependencies]
+jupyter-client = ">=6.1.12"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+nbformat = ">=5.1"
+traitlets = ">=5.4"
+
+[package.extras]
+dev = ["pre-commit"]
+docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"]
+test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"]
+
+[[package]]
+name = "nbconvert"
+version = "7.12.0"
+description = "Converting Jupyter Notebooks"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "nbconvert-7.12.0-py3-none-any.whl", hash = "sha256:5b6c848194d270cc55fb691169202620d7b52a12fec259508d142ecbe4219310"},
+ {file = "nbconvert-7.12.0.tar.gz", hash = "sha256:b1564bd89f69a74cd6398b0362da94db07aafb991b7857216a766204a71612c0"},
+]
+
+[package.dependencies]
+beautifulsoup4 = "*"
+bleach = "!=5.0.0"
+defusedxml = "*"
+importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""}
+jinja2 = ">=3.0"
+jupyter-core = ">=4.7"
+jupyterlab-pygments = "*"
+markupsafe = ">=2.0"
+mistune = ">=2.0.3,<4"
+nbclient = ">=0.5.0"
+nbformat = ">=5.7"
+packaging = "*"
+pandocfilters = ">=1.4.1"
+pygments = ">=2.4.1"
+tinycss2 = "*"
+traitlets = ">=5.1"
+
+[package.extras]
+all = ["nbconvert[docs,qtpdf,serve,test,webpdf]"]
+docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)", "sphinxcontrib-spelling"]
+qtpdf = ["nbconvert[qtpng]"]
+qtpng = ["pyqtwebengine (>=5.15)"]
+serve = ["tornado (>=6.1)"]
+test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest"]
+webpdf = ["playwright"]
+
+[[package]]
+name = "nbformat"
+version = "5.9.2"
+description = "The Jupyter Notebook format"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"},
+ {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"},
+]
+
+[package.dependencies]
+fastjsonschema = "*"
+jsonschema = ">=2.6"
+jupyter-core = "*"
+traitlets = ">=5.1"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
+test = ["pep440", "pre-commit", "pytest", "testpath"]
+
+[[package]]
+name = "nest-asyncio"
+version = "1.5.8"
+description = "Patch asyncio to allow nested event loops"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"},
+ {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"},
+]
+
+[[package]]
+name = "newspaper3k"
+version = "0.2.8"
+description = "Simplified python article discovery & extraction."
+optional = true
+python-versions = "*"
+files = [
+ {file = "newspaper3k-0.2.8-py3-none-any.whl", hash = "sha256:44a864222633d3081113d1030615991c3dbba87239f6bbf59d91240f71a22e3e"},
+ {file = "newspaper3k-0.2.8.tar.gz", hash = "sha256:9f1bd3e1fb48f400c715abf875cc7b0a67b7ddcd87f50c9aeeb8fcbbbd9004fb"},
+]
+
+[package.dependencies]
+beautifulsoup4 = ">=4.4.1"
+cssselect = ">=0.9.2"
+feedfinder2 = ">=0.0.4"
+feedparser = ">=5.2.1"
+jieba3k = ">=0.35.1"
+lxml = ">=3.6.0"
+nltk = ">=3.2.1"
+Pillow = ">=3.3.0"
+python-dateutil = ">=2.5.3"
+PyYAML = ">=3.11"
+requests = ">=2.10.0"
+tinysegmenter = "0.3"
+tldextract = ">=2.0.1"
+
+[[package]]
+name = "nltk"
+version = "3.8.1"
+description = "Natural Language Toolkit"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"},
+ {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"},
+]
+
+[package.dependencies]
+click = "*"
+joblib = "*"
+regex = ">=2021.8.3"
+tqdm = "*"
+
+[package.extras]
+all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"]
+corenlp = ["requests"]
+machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"]
+plot = ["matplotlib"]
+tgrep = ["pyparsing"]
+twitter = ["twython"]
+
+[[package]]
+name = "notebook"
+version = "7.0.6"
+description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "notebook-7.0.6-py3-none-any.whl", hash = "sha256:0fe8f67102fea3744fedf652e4c15339390902ca70c5a31c4f547fa23da697cc"},
+ {file = "notebook-7.0.6.tar.gz", hash = "sha256:ec6113b06529019f7f287819af06c97a2baf7a95ac21a8f6e32192898e9f9a58"},
+]
+
+[package.dependencies]
+jupyter-server = ">=2.4.0,<3"
+jupyterlab = ">=4.0.2,<5"
+jupyterlab-server = ">=2.22.1,<3"
+notebook-shim = ">=0.2,<0.3"
+tornado = ">=6.2.0"
+
+[package.extras]
+dev = ["hatch", "pre-commit"]
+docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
+test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.22.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"]
+
+[[package]]
+name = "notebook-shim"
+version = "0.2.3"
+description = "A shim layer for notebook traits and config"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"},
+ {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"},
+]
+
+[package.dependencies]
+jupyter-server = ">=1.8,<3"
+
+[package.extras]
+test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
+
+[[package]]
+name = "numexpr"
+version = "2.8.6"
+description = "Fast numerical expression evaluator for NumPy"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "numexpr-2.8.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80acbfefb68bd92e708e09f0a02b29e04d388b9ae72f9fcd57988aca172a7833"},
+ {file = "numexpr-2.8.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6e884687da8af5955dc9beb6a12d469675c90b8fb38b6c93668c989cfc2cd982"},
+ {file = "numexpr-2.8.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ef7e8aaa84fce3aba2e65f243d14a9f8cc92aafd5d90d67283815febfe43eeb"},
+ {file = "numexpr-2.8.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dee04d72307c09599f786b9231acffb10df7d7a74b2ce3681d74a574880d13ce"},
+ {file = "numexpr-2.8.6-cp310-cp310-win32.whl", hash = "sha256:211804ec25a9f6d188eadf4198dd1a92b2f61d7d20993c6c7706139bc4199c5b"},
+ {file = "numexpr-2.8.6-cp310-cp310-win_amd64.whl", hash = "sha256:18b1804923cfa3be7bbb45187d01c0540c8f6df4928c22a0f786e15568e9ebc5"},
+ {file = "numexpr-2.8.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95b9da613761e4fc79748535b2a1f58cada22500e22713ae7d9571fa88d1c2e2"},
+ {file = "numexpr-2.8.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:47b45da5aa25600081a649f5e8b2aa640e35db3703f4631f34bb1f2f86d1b5b4"},
+ {file = "numexpr-2.8.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84979bf14143351c2db8d9dd7fef8aca027c66ad9df9cb5e75c93bf5f7b5a338"},
+ {file = "numexpr-2.8.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36528a33aa9c23743b3ea686e57526a4f71e7128a1be66210e1511b09c4e4e9"},
+ {file = "numexpr-2.8.6-cp311-cp311-win32.whl", hash = "sha256:681812e2e71ff1ba9145fac42d03f51ddf6ba911259aa83041323f68e7458002"},
+ {file = "numexpr-2.8.6-cp311-cp311-win_amd64.whl", hash = "sha256:27782177a0081bd0aab229be5d37674e7f0ab4264ef576697323dd047432a4cd"},
+ {file = "numexpr-2.8.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ef6e8896457a60a539cb6ba27da78315a9bb31edb246829b25b5b0304bfcee91"},
+ {file = "numexpr-2.8.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e640bc0eaf1b59f3dde52bc02bbfda98e62f9950202b0584deba28baf9f36bbb"},
+ {file = "numexpr-2.8.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d126938c2c3784673c9c58d94e00b1570aa65517d9c33662234d442fc9fb5795"},
+ {file = "numexpr-2.8.6-cp37-cp37m-win32.whl", hash = "sha256:e93d64cd20940b726477c3cb64926e683d31b778a1e18f9079a5088fd0d8e7c8"},
+ {file = "numexpr-2.8.6-cp37-cp37m-win_amd64.whl", hash = "sha256:31cf610c952eec57081171f0b4427f9bed2395ec70ec432bbf45d260c5c0cdeb"},
+ {file = "numexpr-2.8.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b5f96c89aa0b1f13685ec32fa3d71028db0b5981bfd99a0bbc271035949136b3"},
+ {file = "numexpr-2.8.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c8f37f7a6af3bdd61f2efd1cafcc083a9525ab0aaf5dc641e7ec8fc0ae2d3aa1"},
+ {file = "numexpr-2.8.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38b8b90967026bbc36c7aa6e8ca3b8906e1990914fd21f446e2a043f4ee3bc06"},
+ {file = "numexpr-2.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1967c16f61c27df1cdc43ba3c0ba30346157048dd420b4259832276144d0f64e"},
+ {file = "numexpr-2.8.6-cp38-cp38-win32.whl", hash = "sha256:15469dc722b5ceb92324ec8635411355ebc702303db901ae8cc87f47c5e3a124"},
+ {file = "numexpr-2.8.6-cp38-cp38-win_amd64.whl", hash = "sha256:95c09e814b0d6549de98b5ded7cdf7d954d934bb6b505432ff82e83a6d330bda"},
+ {file = "numexpr-2.8.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aa0f661f5f4872fd7350cc9895f5d2594794b2a7e7f1961649a351724c64acc9"},
+ {file = "numexpr-2.8.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8e3e6f1588d6c03877cb3b3dcc3096482da9d330013b886b29cb9586af5af3eb"},
+ {file = "numexpr-2.8.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8564186aad5a2c88d597ebc79b8171b52fd33e9b085013e1ff2208f7e4b387e3"},
+ {file = "numexpr-2.8.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6a88d71c166e86b98d34701285d23e3e89d548d9f5ae3f4b60919ac7151949f"},
+ {file = "numexpr-2.8.6-cp39-cp39-win32.whl", hash = "sha256:c48221b6a85494a7be5a022899764e58259af585dff031cecab337277278cc93"},
+ {file = "numexpr-2.8.6-cp39-cp39-win_amd64.whl", hash = "sha256:6d7003497d82ef19458dce380b36a99343b96a3bd5773465c2d898bf8f5a38f9"},
+ {file = "numexpr-2.8.6.tar.gz", hash = "sha256:6336f8dba3f456e41a4ffc3c97eb63d89c73589ff6e1707141224b930263260d"},
+]
+
+[package.dependencies]
+numpy = ">=1.13.3"
+
+[[package]]
+name = "numpy"
+version = "1.24.4"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"},
+ {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"},
+ {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"},
+ {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"},
+ {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"},
+ {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"},
+ {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"},
+ {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"},
+ {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"},
+ {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"},
+ {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"},
+ {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"},
+ {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"},
+ {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"},
+ {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"},
+ {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"},
+ {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"},
+ {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"},
+ {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"},
+ {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"},
+ {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"},
+ {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"},
+ {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"},
+ {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"},
+ {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"},
+ {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"},
+ {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"},
+ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
+]
+
+[[package]]
+name = "oauthlib"
+version = "3.2.2"
+description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"},
+ {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"},
+]
+
+[package.extras]
+rsa = ["cryptography (>=3.0.0)"]
+signals = ["blinker (>=1.4.0)"]
+signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
+
+[[package]]
+name = "onnxruntime"
+version = "1.16.3"
+description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
+optional = true
+python-versions = "*"
+files = [
+ {file = "onnxruntime-1.16.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:3bc41f323ac77acfed190be8ffdc47a6a75e4beeb3473fbf55eeb075ccca8df2"},
+ {file = "onnxruntime-1.16.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:212741b519ee61a4822c79c47147d63a8b0ffde25cd33988d3d7be9fbd51005d"},
+ {file = "onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f91f5497fe3df4ceee2f9e66c6148d9bfeb320cd6a71df361c66c5b8bac985a"},
+ {file = "onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b1fc269cabd27f129fb9058917d6fdc89b188c49ed8700f300b945c81f889"},
+ {file = "onnxruntime-1.16.3-cp310-cp310-win32.whl", hash = "sha256:f36b56a593b49a3c430be008c2aea6658d91a3030115729609ec1d5ffbaab1b6"},
+ {file = "onnxruntime-1.16.3-cp310-cp310-win_amd64.whl", hash = "sha256:3c467eaa3d2429c026b10c3d17b78b7f311f718ef9d2a0d6938e5c3c2611b0cf"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:a225bb683991001d111f75323d355b3590e75e16b5e0f07a0401e741a0143ea1"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9aded21fe3d898edd86be8aa2eb995aa375e800ad3dfe4be9f618a20b8ee3630"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00cccc37a5195c8fca5011b9690b349db435986bd508eb44c9fce432da9228a4"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e253e572021563226a86f1c024f8f70cdae28f2fb1cc8c3a9221e8b1ce37db5"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-win32.whl", hash = "sha256:a82a8f0b4c978d08f9f5c7a6019ae51151bced9fd91e5aaa0c20a9e4ac7a60b6"},
+ {file = "onnxruntime-1.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:78d81d9af457a1dc90db9a7da0d09f3ccb1288ea1236c6ab19f0ca61f3eee2d3"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:04ebcd29c20473596a1412e471524b2fb88d55e6301c40b98dd2407b5911595f"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9996bab0f202a6435ab867bc55598f15210d0b72794d5de83712b53d564084ae"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b8f5083f903408238883821dd8c775f8120cb4a604166dbdabe97f4715256d5"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c2dcf1b70f8434abb1116fe0975c00e740722aaf321997195ea3618cc00558e"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-win32.whl", hash = "sha256:d4a0151e1accd04da6711f6fd89024509602f82c65a754498e960b032359b02d"},
+ {file = "onnxruntime-1.16.3-cp38-cp38-win_amd64.whl", hash = "sha256:e8aa5bba78afbd4d8a2654b14ec7462ff3ce4a6aad312a3c2d2c2b65009f2541"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:6829dc2a79d48c911fedaf4c0f01e03c86297d32718a3fdee7a282766dfd282a"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:76f876c53bfa912c6c242fc38213a6f13f47612d4360bc9d599bd23753e53161"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4137e5d443e2dccebe5e156a47f1d6d66f8077b03587c35f11ee0c7eda98b533"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c56695c1a343c7c008b647fff3df44da63741fbe7b6003ef576758640719be7b"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-win32.whl", hash = "sha256:985a029798744ce4743fcf8442240fed35c8e4d4d30ec7d0c2cdf1388cd44408"},
+ {file = "onnxruntime-1.16.3-cp39-cp39-win_amd64.whl", hash = "sha256:28ff758b17ce3ca6bcad3d936ec53bd7f5482e7630a13f6dcae518eba8f71d85"},
+]
+
+[package.dependencies]
+coloredlogs = "*"
+flatbuffers = "*"
+numpy = ">=1.21.6"
+packaging = "*"
+protobuf = "*"
+sympy = "*"
+
+[[package]]
+name = "openai"
+version = "1.3.7"
+description = "The official Python library for the openai API"
+optional = false
+python-versions = ">=3.7.1"
+files = [
+ {file = "openai-1.3.7-py3-none-any.whl", hash = "sha256:e5c51367a910297e4d1cd33d2298fb87d7edf681edbe012873925ac16f95bee0"},
+ {file = "openai-1.3.7.tar.gz", hash = "sha256:18074a0f51f9b49d1ae268c7abc36f7f33212a0c0d08ce11b7053ab2d17798de"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<4"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tqdm = ">4"
+typing-extensions = ">=4.5,<5"
+
+[package.extras]
+datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
+
+[[package]]
+name = "openapi-pydantic"
+version = "0.3.2"
+description = "Pydantic OpenAPI schema implementation"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "openapi_pydantic-0.3.2-py3-none-any.whl", hash = "sha256:24488566a0a61bee3b55de6d3665329adaf2aadfe8f292ac0bddfe22155fadac"},
+ {file = "openapi_pydantic-0.3.2.tar.gz", hash = "sha256:685aa631395c469ecfd04f01a2ffedd541f94d372943868a501b412e9de6ba8b"},
+]
+
+[package.dependencies]
+pydantic = ">=1.8"
+
+[[package]]
+name = "opencv-python"
+version = "4.8.1.78"
+description = "Wrapper package for OpenCV python bindings."
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "opencv-python-4.8.1.78.tar.gz", hash = "sha256:cc7adbbcd1112877a39274106cb2752e04984bc01a031162952e97450d6117f6"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:91d5f6f5209dc2635d496f6b8ca6573ecdad051a09e6b5de4c399b8e673c60da"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31f47e05447da8b3089faa0a07ffe80e114c91ce0b171e6424f9badbd1c5cd"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9814beca408d3a0eca1bae7e3e5be68b07c17ecceb392b94170881216e09b319"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c406bdb41eb21ea51b4e90dfbc989c002786c3f601c236a99c59a54670a394"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-win32.whl", hash = "sha256:a7aac3900fbacf55b551e7b53626c3dad4c71ce85643645c43e91fcb19045e47"},
+ {file = "opencv_python-4.8.1.78-cp37-abi3-win_amd64.whl", hash = "sha256:b983197f97cfa6fcb74e1da1802c7497a6f94ed561aba6980f1f33123f904956"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
+ {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\" and python_version >= \"3.8\""},
+ {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+ {version = ">=1.17.3", markers = "(platform_system != \"Darwin\" and platform_system != \"Linux\") and python_version >= \"3.8\" and python_version < \"3.9\" or platform_system != \"Darwin\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_machine != \"aarch64\" or platform_machine != \"arm64\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_system != \"Linux\" or (platform_machine != \"arm64\" and platform_machine != \"aarch64\") and python_version >= \"3.8\" and python_version < \"3.9\""},
+ {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
+ {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
+]
+
+[[package]]
+name = "orjson"
+version = "3.9.10"
+description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "orjson-3.9.10-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c18a4da2f50050a03d1da5317388ef84a16013302a5281d6f64e4a3f406aabc4"},
+ {file = "orjson-3.9.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5148bab4d71f58948c7c39d12b14a9005b6ab35a0bdf317a8ade9a9e4d9d0bd5"},
+ {file = "orjson-3.9.10-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4cf7837c3b11a2dfb589f8530b3cff2bd0307ace4c301e8997e95c7468c1378e"},
+ {file = "orjson-3.9.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c62b6fa2961a1dcc51ebe88771be5319a93fd89bd247c9ddf732bc250507bc2b"},
+ {file = "orjson-3.9.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:deeb3922a7a804755bbe6b5be9b312e746137a03600f488290318936c1a2d4dc"},
+ {file = "orjson-3.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1234dc92d011d3554d929b6cf058ac4a24d188d97be5e04355f1b9223e98bbe9"},
+ {file = "orjson-3.9.10-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:06ad5543217e0e46fd7ab7ea45d506c76f878b87b1b4e369006bdb01acc05a83"},
+ {file = "orjson-3.9.10-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4fd72fab7bddce46c6826994ce1e7de145ae1e9e106ebb8eb9ce1393ca01444d"},
+ {file = "orjson-3.9.10-cp310-none-win32.whl", hash = "sha256:b5b7d4a44cc0e6ff98da5d56cde794385bdd212a86563ac321ca64d7f80c80d1"},
+ {file = "orjson-3.9.10-cp310-none-win_amd64.whl", hash = "sha256:61804231099214e2f84998316f3238c4c2c4aaec302df12b21a64d72e2a135c7"},
+ {file = "orjson-3.9.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cff7570d492bcf4b64cc862a6e2fb77edd5e5748ad715f487628f102815165e9"},
+ {file = "orjson-3.9.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed8bc367f725dfc5cabeed1ae079d00369900231fbb5a5280cf0736c30e2adf7"},
+ {file = "orjson-3.9.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c812312847867b6335cfb264772f2a7e85b3b502d3a6b0586aa35e1858528ab1"},
+ {file = "orjson-3.9.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edd2856611e5050004f4722922b7b1cd6268da34102667bd49d2a2b18bafb81"},
+ {file = "orjson-3.9.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:674eb520f02422546c40401f4efaf8207b5e29e420c17051cddf6c02783ff5ca"},
+ {file = "orjson-3.9.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0dc4310da8b5f6415949bd5ef937e60aeb0eb6b16f95041b5e43e6200821fb"},
+ {file = "orjson-3.9.10-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e99c625b8c95d7741fe057585176b1b8783d46ed4b8932cf98ee145c4facf499"},
+ {file = "orjson-3.9.10-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ec6f18f96b47299c11203edfbdc34e1b69085070d9a3d1f302810cc23ad36bf3"},
+ {file = "orjson-3.9.10-cp311-none-win32.whl", hash = "sha256:ce0a29c28dfb8eccd0f16219360530bc3cfdf6bf70ca384dacd36e6c650ef8e8"},
+ {file = "orjson-3.9.10-cp311-none-win_amd64.whl", hash = "sha256:cf80b550092cc480a0cbd0750e8189247ff45457e5a023305f7ef1bcec811616"},
+ {file = "orjson-3.9.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:602a8001bdf60e1a7d544be29c82560a7b49319a0b31d62586548835bbe2c862"},
+ {file = "orjson-3.9.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f295efcd47b6124b01255d1491f9e46f17ef40d3d7eabf7364099e463fb45f0f"},
+ {file = "orjson-3.9.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92af0d00091e744587221e79f68d617b432425a7e59328ca4c496f774a356071"},
+ {file = "orjson-3.9.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5a02360e73e7208a872bf65a7554c9f15df5fe063dc047f79738998b0506a14"},
+ {file = "orjson-3.9.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:858379cbb08d84fe7583231077d9a36a1a20eb72f8c9076a45df8b083724ad1d"},
+ {file = "orjson-3.9.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666c6fdcaac1f13eb982b649e1c311c08d7097cbda24f32612dae43648d8db8d"},
+ {file = "orjson-3.9.10-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3fb205ab52a2e30354640780ce4587157a9563a68c9beaf52153e1cea9aa0921"},
+ {file = "orjson-3.9.10-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:7ec960b1b942ee3c69323b8721df2a3ce28ff40e7ca47873ae35bfafeb4555ca"},
+ {file = "orjson-3.9.10-cp312-none-win_amd64.whl", hash = "sha256:3e892621434392199efb54e69edfff9f699f6cc36dd9553c5bf796058b14b20d"},
+ {file = "orjson-3.9.10-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8b9ba0ccd5a7f4219e67fbbe25e6b4a46ceef783c42af7dbc1da548eb28b6531"},
+ {file = "orjson-3.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e2ecd1d349e62e3960695214f40939bbfdcaeaaa62ccc638f8e651cf0970e5f"},
+ {file = "orjson-3.9.10-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f433be3b3f4c66016d5a20e5b4444ef833a1f802ced13a2d852c637f69729c1"},
+ {file = "orjson-3.9.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4689270c35d4bb3102e103ac43c3f0b76b169760aff8bcf2d401a3e0e58cdb7f"},
+ {file = "orjson-3.9.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4bd176f528a8151a6efc5359b853ba3cc0e82d4cd1fab9c1300c5d957dc8f48c"},
+ {file = "orjson-3.9.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a2ce5ea4f71681623f04e2b7dadede3c7435dfb5e5e2d1d0ec25b35530e277b"},
+ {file = "orjson-3.9.10-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:49f8ad582da6e8d2cf663c4ba5bf9f83cc052570a3a767487fec6af839b0e777"},
+ {file = "orjson-3.9.10-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2a11b4b1a8415f105d989876a19b173f6cdc89ca13855ccc67c18efbd7cbd1f8"},
+ {file = "orjson-3.9.10-cp38-none-win32.whl", hash = "sha256:a353bf1f565ed27ba71a419b2cd3db9d6151da426b61b289b6ba1422a702e643"},
+ {file = "orjson-3.9.10-cp38-none-win_amd64.whl", hash = "sha256:e28a50b5be854e18d54f75ef1bb13e1abf4bc650ab9d635e4258c58e71eb6ad5"},
+ {file = "orjson-3.9.10-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ee5926746232f627a3be1cc175b2cfad24d0170d520361f4ce3fa2fd83f09e1d"},
+ {file = "orjson-3.9.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a73160e823151f33cdc05fe2cea557c5ef12fdf276ce29bb4f1c571c8368a60"},
+ {file = "orjson-3.9.10-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c338ed69ad0b8f8f8920c13f529889fe0771abbb46550013e3c3d01e5174deef"},
+ {file = "orjson-3.9.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5869e8e130e99687d9e4be835116c4ebd83ca92e52e55810962446d841aba8de"},
+ {file = "orjson-3.9.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d2c1e559d96a7f94a4f581e2a32d6d610df5840881a8cba8f25e446f4d792df3"},
+ {file = "orjson-3.9.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a3a3a72c9811b56adf8bcc829b010163bb2fc308877e50e9910c9357e78521"},
+ {file = "orjson-3.9.10-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7f8fb7f5ecf4f6355683ac6881fd64b5bb2b8a60e3ccde6ff799e48791d8f864"},
+ {file = "orjson-3.9.10-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c943b35ecdf7123b2d81d225397efddf0bce2e81db2f3ae633ead38e85cd5ade"},
+ {file = "orjson-3.9.10-cp39-none-win32.whl", hash = "sha256:fb0b361d73f6b8eeceba47cd37070b5e6c9de5beaeaa63a1cb35c7e1a73ef088"},
+ {file = "orjson-3.9.10-cp39-none-win_amd64.whl", hash = "sha256:b90f340cb6397ec7a854157fac03f0c82b744abdd1c0941a024c3c29d1340aff"},
+ {file = "orjson-3.9.10.tar.gz", hash = "sha256:9ebbdbd6a046c304b1845e96fbcc5559cd296b4dfd3ad2509e33c4d9ce07d6a1"},
+]
+
+[[package]]
+name = "overrides"
+version = "7.4.0"
+description = "A decorator to automatically detect mismatch when overriding a method."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"},
+ {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"},
+]
+
+[[package]]
+name = "packaging"
+version = "23.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
+ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+]
+
+[[package]]
+name = "pandas"
+version = "2.0.3"
+description = "Powerful data structures for data analysis, time series, and statistics"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"},
+ {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"},
+ {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"},
+ {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"},
+ {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"},
+ {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"},
+ {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"},
+ {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"},
+ {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"},
+ {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"},
+ {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"},
+ {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"},
+ {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"},
+ {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"},
+ {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"},
+ {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"},
+ {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"},
+ {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"},
+ {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"},
+ {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"},
+ {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"},
+ {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"},
+ {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"},
+ {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"},
+ {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+ {version = ">=1.20.3", markers = "python_version < \"3.10\""},
+ {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
+]
+python-dateutil = ">=2.8.2"
+pytz = ">=2020.1"
+tzdata = ">=2022.1"
+
+[package.extras]
+all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
+aws = ["s3fs (>=2021.08.0)"]
+clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
+compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
+computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"]
+excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"]
+feather = ["pyarrow (>=7.0.0)"]
+fss = ["fsspec (>=2021.07.0)"]
+gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"]
+hdf5 = ["tables (>=3.6.1)"]
+html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"]
+mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"]
+output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"]
+parquet = ["pyarrow (>=7.0.0)"]
+performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"]
+plot = ["matplotlib (>=3.6.1)"]
+postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
+spss = ["pyreadstat (>=1.1.2)"]
+sql-other = ["SQLAlchemy (>=1.4.16)"]
+test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
+xml = ["lxml (>=4.6.3)"]
+
+[[package]]
+name = "pandocfilters"
+version = "1.5.0"
+description = "Utilities for writing pandoc filters in python"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"},
+ {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"},
+]
+
+[[package]]
+name = "para"
+version = "0.0.8"
+description = "a set utilities that ake advantage of python's 'multiprocessing' module to distribute CPU-intensive tasks"
+optional = true
+python-versions = "*"
+files = [
+ {file = "para-0.0.8-py3-none-any.whl", hash = "sha256:c63b030658cafd84f8fabfc000142324d51c7440e50ef5012fd1a54972ca25f4"},
+ {file = "para-0.0.8.tar.gz", hash = "sha256:46c3232ae9d8ea9d886cfd08cdd112892202bed8645f40b6255597ba4cfef217"},
+]
+
+[[package]]
+name = "parso"
+version = "0.8.3"
+description = "A Python Parser"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"},
+ {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"},
+]
+
+[package.extras]
+qa = ["flake8 (==3.8.3)", "mypy (==0.782)"]
+testing = ["docopt", "pytest (<6.0.0)"]
+
+[[package]]
+name = "pdfminer-six"
+version = "20221105"
+description = "PDF parser and analyzer"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "pdfminer.six-20221105-py3-none-any.whl", hash = "sha256:1eaddd712d5b2732f8ac8486824533514f8ba12a0787b3d5fe1e686cd826532d"},
+ {file = "pdfminer.six-20221105.tar.gz", hash = "sha256:8448ab7b939d18b64820478ecac5394f482d7a79f5f7eaa7703c6c959c175e1d"},
+]
+
+[package.dependencies]
+charset-normalizer = ">=2.0.0"
+cryptography = ">=36.0.0"
+
+[package.extras]
+dev = ["black", "mypy (==0.931)", "nox", "pytest"]
+docs = ["sphinx", "sphinx-argparse"]
+image = ["Pillow"]
+
+[[package]]
+name = "pexpect"
+version = "4.9.0"
+description = "Pexpect allows easy control of interactive console applications."
+optional = false
+python-versions = "*"
+files = [
+ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"},
+ {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"},
+]
+
+[package.dependencies]
+ptyprocess = ">=0.5"
+
+[[package]]
+name = "pgvector"
+version = "0.1.8"
+description = "pgvector support for Python"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "pgvector-0.1.8-py2.py3-none-any.whl", hash = "sha256:99dce3a6580ef73863edb9b8441937671f4e1a09383826e6b0838176cd441a96"},
+]
+
+[package.dependencies]
+numpy = "*"
+
+[[package]]
+name = "pickleshare"
+version = "0.7.5"
+description = "Tiny 'shelve'-like database with concurrency support"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"},
+ {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
+]
+
+[[package]]
+name = "pillow"
+version = "10.1.0"
+description = "Python Imaging Library (Fork)"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "Pillow-10.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1ab05f3db77e98f93964697c8efc49c7954b08dd61cff526b7f2531a22410106"},
+ {file = "Pillow-10.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6932a7652464746fcb484f7fc3618e6503d2066d853f68a4bd97193a3996e273"},
+ {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f63b5a68daedc54c7c3464508d8c12075e56dcfbd42f8c1bf40169061ae666"},
+ {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0949b55eb607898e28eaccb525ab104b2d86542a85c74baf3a6dc24002edec2"},
+ {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ae88931f93214777c7a3aa0a8f92a683f83ecde27f65a45f95f22d289a69e593"},
+ {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b0eb01ca85b2361b09480784a7931fc648ed8b7836f01fb9241141b968feb1db"},
+ {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d27b5997bdd2eb9fb199982bb7eb6164db0426904020dc38c10203187ae2ff2f"},
+ {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7df5608bc38bd37ef585ae9c38c9cd46d7c81498f086915b0f97255ea60c2818"},
+ {file = "Pillow-10.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:41f67248d92a5e0a2076d3517d8d4b1e41a97e2df10eb8f93106c89107f38b57"},
+ {file = "Pillow-10.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1fb29c07478e6c06a46b867e43b0bcdb241b44cc52be9bc25ce5944eed4648e7"},
+ {file = "Pillow-10.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2cdc65a46e74514ce742c2013cd4a2d12e8553e3a2563c64879f7c7e4d28bce7"},
+ {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50d08cd0a2ecd2a8657bd3d82c71efd5a58edb04d9308185d66c3a5a5bed9610"},
+ {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:062a1610e3bc258bff2328ec43f34244fcec972ee0717200cb1425214fe5b839"},
+ {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:61f1a9d247317fa08a308daaa8ee7b3f760ab1809ca2da14ecc88ae4257d6172"},
+ {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a646e48de237d860c36e0db37ecaecaa3619e6f3e9d5319e527ccbc8151df061"},
+ {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:47e5bf85b80abc03be7455c95b6d6e4896a62f6541c1f2ce77a7d2bb832af262"},
+ {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a92386125e9ee90381c3369f57a2a50fa9e6aa8b1cf1d9c4b200d41a7dd8e992"},
+ {file = "Pillow-10.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f7c276c05a9767e877a0b4c5050c8bee6a6d960d7f0c11ebda6b99746068c2a"},
+ {file = "Pillow-10.1.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:a89b8312d51715b510a4fe9fc13686283f376cfd5abca8cd1c65e4c76e21081b"},
+ {file = "Pillow-10.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:00f438bb841382b15d7deb9a05cc946ee0f2c352653c7aa659e75e592f6fa17d"},
+ {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d929a19f5469b3f4df33a3df2983db070ebb2088a1e145e18facbc28cae5b27"},
+ {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a92109192b360634a4489c0c756364c0c3a2992906752165ecb50544c251312"},
+ {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0248f86b3ea061e67817c47ecbe82c23f9dd5d5226200eb9090b3873d3ca32de"},
+ {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9882a7451c680c12f232a422730f986a1fcd808da0fd428f08b671237237d651"},
+ {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1c3ac5423c8c1da5928aa12c6e258921956757d976405e9467c5f39d1d577a4b"},
+ {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:806abdd8249ba3953c33742506fe414880bad78ac25cc9a9b1c6ae97bedd573f"},
+ {file = "Pillow-10.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:eaed6977fa73408b7b8a24e8b14e59e1668cfc0f4c40193ea7ced8e210adf996"},
+ {file = "Pillow-10.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:fe1e26e1ffc38be097f0ba1d0d07fcade2bcfd1d023cda5b29935ae8052bd793"},
+ {file = "Pillow-10.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7e3daa202beb61821c06d2517428e8e7c1aab08943e92ec9e5755c2fc9ba5e"},
+ {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24fadc71218ad2b8ffe437b54876c9382b4a29e030a05a9879f615091f42ffc2"},
+ {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1d323703cfdac2036af05191b969b910d8f115cf53093125e4058f62012c9a"},
+ {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:912e3812a1dbbc834da2b32299b124b5ddcb664ed354916fd1ed6f193f0e2d01"},
+ {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7dbaa3c7de82ef37e7708521be41db5565004258ca76945ad74a8e998c30af8d"},
+ {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9d7bc666bd8c5a4225e7ac71f2f9d12466ec555e89092728ea0f5c0c2422ea80"},
+ {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baada14941c83079bf84c037e2d8b7506ce201e92e3d2fa0d1303507a8538212"},
+ {file = "Pillow-10.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:2ef6721c97894a7aa77723740a09547197533146fba8355e86d6d9a4a1056b14"},
+ {file = "Pillow-10.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0a026c188be3b443916179f5d04548092e253beb0c3e2ee0a4e2cdad72f66099"},
+ {file = "Pillow-10.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:04f6f6149f266a100374ca3cc368b67fb27c4af9f1cc8cb6306d849dcdf12616"},
+ {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb40c011447712d2e19cc261c82655f75f32cb724788df315ed992a4d65696bb"},
+ {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a8413794b4ad9719346cd9306118450b7b00d9a15846451549314a58ac42219"},
+ {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c9aeea7b63edb7884b031a35305629a7593272b54f429a9869a4f63a1bf04c34"},
+ {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b4005fee46ed9be0b8fb42be0c20e79411533d1fd58edabebc0dd24626882cfd"},
+ {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4d0152565c6aa6ebbfb1e5d8624140a440f2b99bf7afaafbdbf6430426497f28"},
+ {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d921bc90b1defa55c9917ca6b6b71430e4286fc9e44c55ead78ca1a9f9eba5f2"},
+ {file = "Pillow-10.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfe96560c6ce2f4c07d6647af2d0f3c54cc33289894ebd88cfbb3bcd5391e256"},
+ {file = "Pillow-10.1.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:937bdc5a7f5343d1c97dc98149a0be7eb9704e937fe3dc7140e229ae4fc572a7"},
+ {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c25762197144e211efb5f4e8ad656f36c8d214d390585d1d21281f46d556ba"},
+ {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:afc8eef765d948543a4775f00b7b8c079b3321d6b675dde0d02afa2ee23000b4"},
+ {file = "Pillow-10.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:883f216eac8712b83a63f41b76ddfb7b2afab1b74abbb413c5df6680f071a6b9"},
+ {file = "Pillow-10.1.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b920e4d028f6442bea9a75b7491c063f0b9a3972520731ed26c83e254302eb1e"},
+ {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c41d960babf951e01a49c9746f92c5a7e0d939d1652d7ba30f6b3090f27e412"},
+ {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1fafabe50a6977ac70dfe829b2d5735fd54e190ab55259ec8aea4aaea412fa0b"},
+ {file = "Pillow-10.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3b834f4b16173e5b92ab6566f0473bfb09f939ba14b23b8da1f54fa63e4b623f"},
+ {file = "Pillow-10.1.0.tar.gz", hash = "sha256:e6bf8de6c36ed96c86ea3b6e1d5273c53f46ef518a062464cd7ef5dd2cf92e38"},
+]
+
+[package.extras]
+docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"]
+tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "pkgutil-resolve-name"
+version = "1.3.10"
+description = "Resolve a name to an object."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
+ {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
+]
+
+[[package]]
+name = "platformdirs"
+version = "4.1.0"
+description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "platformdirs-4.1.0-py3-none-any.whl", hash = "sha256:11c8f37bcca40db96d8144522d925583bdb7a31f7b0e37e3ed4318400a8e2380"},
+ {file = "platformdirs-4.1.0.tar.gz", hash = "sha256:906d548203468492d432bcb294d4bc2fff751bf84971fbb2c10918cc206ee420"},
+]
+
+[package.extras]
+docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
+test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
+
+[[package]]
+name = "pluggy"
+version = "1.3.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
+ {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "praw"
+version = "7.7.1"
+description = "PRAW, an acronym for \"Python Reddit API Wrapper\", is a Python package that allows for simple access to Reddit's API."
+optional = true
+python-versions = "~=3.7"
+files = [
+ {file = "praw-7.7.1-py3-none-any.whl", hash = "sha256:9ec5dc943db00c175bc6a53f4e089ce625f3fdfb27305564b616747b767d38ef"},
+ {file = "praw-7.7.1.tar.gz", hash = "sha256:f1d7eef414cafe28080dda12ed09253a095a69933d5c8132eca11d4dc8a070bf"},
+]
+
+[package.dependencies]
+prawcore = ">=2.1,<3"
+update-checker = ">=0.18"
+websocket-client = ">=0.54.0"
+
+[package.extras]
+ci = ["coveralls"]
+dev = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "furo", "packaging", "pre-commit", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "sphinx", "urllib3 (==1.26.*)"]
+lint = ["furo", "pre-commit", "sphinx"]
+readthedocs = ["furo", "sphinx"]
+test = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "urllib3 (==1.26.*)"]
+
+[[package]]
+name = "prawcore"
+version = "2.4.0"
+description = "\"Low-level communication layer for PRAW 4+."
+optional = true
+python-versions = "~=3.8"
+files = [
+ {file = "prawcore-2.4.0-py3-none-any.whl", hash = "sha256:29af5da58d85704b439ad3c820873ad541f4535e00bb98c66f0fbcc8c603065a"},
+ {file = "prawcore-2.4.0.tar.gz", hash = "sha256:b7b2b5a1d04406e086ab4e79988dc794df16059862f329f4c6a43ed09986c335"},
+]
+
+[package.dependencies]
+requests = ">=2.6.0,<3.0"
+
+[package.extras]
+ci = ["coveralls"]
+dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
+lint = ["pre-commit", "ruff (>=0.0.291)"]
+test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]
+
+[[package]]
+name = "prometheus-client"
+version = "0.19.0"
+description = "Python client for the Prometheus monitoring system."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"},
+ {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"},
+]
+
+[package.extras]
+twisted = ["twisted"]
+
+[[package]]
+name = "prompt-toolkit"
+version = "3.0.41"
+description = "Library for building powerful interactive command lines in Python"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"},
+ {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"},
+]
+
+[package.dependencies]
+wcwidth = "*"
+
+[[package]]
+name = "proto-plus"
+version = "1.22.3"
+description = "Beautiful, Pythonic protocol buffers."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "proto-plus-1.22.3.tar.gz", hash = "sha256:fdcd09713cbd42480740d2fe29c990f7fbd885a67efc328aa8be6ee3e9f76a6b"},
+ {file = "proto_plus-1.22.3-py3-none-any.whl", hash = "sha256:a49cd903bc0b6ab41f76bf65510439d56ca76f868adf0274e738bfdd096894df"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.0,<5.0.0dev"
+
+[package.extras]
+testing = ["google-api-core[grpc] (>=1.31.5)"]
+
+[[package]]
+name = "protobuf"
+version = "3.19.6"
+description = "Protocol Buffers"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
+ {file = "protobuf-3.19.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11478547958c2dfea921920617eb457bc26867b0d1aa065ab05f35080c5d9eb6"},
+ {file = "protobuf-3.19.6-cp310-cp310-win32.whl", hash = "sha256:559670e006e3173308c9254d63facb2c03865818f22204037ab76f7a0ff70b5f"},
+ {file = "protobuf-3.19.6-cp310-cp310-win_amd64.whl", hash = "sha256:347b393d4dd06fb93a77620781e11c058b3b0a5289262f094379ada2920a3730"},
+ {file = "protobuf-3.19.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a8ce5ae0de28b51dff886fb922012dad885e66176663950cb2344c0439ecb473"},
+ {file = "protobuf-3.19.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90b0d02163c4e67279ddb6dc25e063db0130fc299aefabb5d481053509fae5c8"},
+ {file = "protobuf-3.19.6-cp36-cp36m-win32.whl", hash = "sha256:30f5370d50295b246eaa0296533403961f7e64b03ea12265d6dfce3a391d8992"},
+ {file = "protobuf-3.19.6-cp36-cp36m-win_amd64.whl", hash = "sha256:0c0714b025ec057b5a7600cb66ce7c693815f897cfda6d6efb58201c472e3437"},
+ {file = "protobuf-3.19.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5057c64052a1f1dd7d4450e9aac25af6bf36cfbfb3a1cd89d16393a036c49157"},
+ {file = "protobuf-3.19.6-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bb6776bd18f01ffe9920e78e03a8676530a5d6c5911934c6a1ac6eb78973ecb6"},
+ {file = "protobuf-3.19.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84a04134866861b11556a82dd91ea6daf1f4925746b992f277b84013a7cc1229"},
+ {file = "protobuf-3.19.6-cp37-cp37m-win32.whl", hash = "sha256:4bc98de3cdccfb5cd769620d5785b92c662b6bfad03a202b83799b6ed3fa1fa7"},
+ {file = "protobuf-3.19.6-cp37-cp37m-win_amd64.whl", hash = "sha256:aa3b82ca1f24ab5326dcf4ea00fcbda703e986b22f3d27541654f749564d778b"},
+ {file = "protobuf-3.19.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2b2d2913bcda0e0ec9a784d194bc490f5dc3d9d71d322d070b11a0ade32ff6ba"},
+ {file = "protobuf-3.19.6-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:d0b635cefebd7a8a0f92020562dead912f81f401af7e71f16bf9506ff3bdbb38"},
+ {file = "protobuf-3.19.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a552af4dc34793803f4e735aabe97ffc45962dfd3a237bdde242bff5a3de684"},
+ {file = "protobuf-3.19.6-cp38-cp38-win32.whl", hash = "sha256:0469bc66160180165e4e29de7f445e57a34ab68f49357392c5b2f54c656ab25e"},
+ {file = "protobuf-3.19.6-cp38-cp38-win_amd64.whl", hash = "sha256:91d5f1e139ff92c37e0ff07f391101df77e55ebb97f46bbc1535298d72019462"},
+ {file = "protobuf-3.19.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c0ccd3f940fe7f3b35a261b1dd1b4fc850c8fde9f74207015431f174be5976b3"},
+ {file = "protobuf-3.19.6-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:30a15015d86b9c3b8d6bf78d5b8c7749f2512c29f168ca259c9d7727604d0e39"},
+ {file = "protobuf-3.19.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:878b4cd080a21ddda6ac6d1e163403ec6eea2e206cf225982ae04567d39be7b0"},
+ {file = "protobuf-3.19.6-cp39-cp39-win32.whl", hash = "sha256:5a0d7539a1b1fb7e76bf5faa0b44b30f812758e989e59c40f77a7dab320e79b9"},
+ {file = "protobuf-3.19.6-cp39-cp39-win_amd64.whl", hash = "sha256:bbf5cea5048272e1c60d235c7bd12ce1b14b8a16e76917f371c718bd3005f045"},
+ {file = "protobuf-3.19.6-py2.py3-none-any.whl", hash = "sha256:14082457dc02be946f60b15aad35e9f5c69e738f80ebbc0900a19bc83734a5a4"},
+ {file = "protobuf-3.19.6.tar.gz", hash = "sha256:5f5540d57a43042389e87661c6eaa50f47c19c6176e8cf1c4f287aeefeccb5c4"},
+]
+
+[[package]]
+name = "psutil"
+version = "5.9.6"
+description = "Cross-platform lib for process and system monitoring in Python."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+files = [
+ {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"},
+ {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"},
+ {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28"},
+ {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017"},
+ {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c"},
+ {file = "psutil-5.9.6-cp27-none-win32.whl", hash = "sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9"},
+ {file = "psutil-5.9.6-cp27-none-win_amd64.whl", hash = "sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac"},
+ {file = "psutil-5.9.6-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a"},
+ {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c"},
+ {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4"},
+ {file = "psutil-5.9.6-cp36-cp36m-win32.whl", hash = "sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602"},
+ {file = "psutil-5.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa"},
+ {file = "psutil-5.9.6-cp37-abi3-win32.whl", hash = "sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c"},
+ {file = "psutil-5.9.6-cp37-abi3-win_amd64.whl", hash = "sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a"},
+ {file = "psutil-5.9.6-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57"},
+ {file = "psutil-5.9.6.tar.gz", hash = "sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a"},
+]
+
+[package.extras]
+test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
+
+[[package]]
+name = "psychicapi"
+version = "0.8.4"
+description = "Psychic.dev is an open-source data integration platform for LLMs. This is the Python client for Psychic"
+optional = true
+python-versions = "*"
+files = [
+ {file = "psychicapi-0.8.4-py3-none-any.whl", hash = "sha256:bf0a0ea858a79c8d443565d0d1ae8d7f8c63095bf4fd2bd7723241e46b59bbd4"},
+ {file = "psychicapi-0.8.4.tar.gz", hash = "sha256:18dc3f2e4ab4dbbf6002c39f4ce680fbd7b86253d92403a5e6530ddf07064224"},
+]
+
+[package.dependencies]
+requests = "*"
+
+[[package]]
+name = "psycopg2"
+version = "2.9.9"
+description = "psycopg2 - Python-PostgreSQL Database Adapter"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"},
+ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
+ {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
+ {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
+ {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
+ {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
+ {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
+ {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
+ {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
+ {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"},
+ {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"},
+ {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"},
+ {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"},
+]
+
+[[package]]
+name = "psycopg2-binary"
+version = "2.9.9"
+description = "psycopg2 - Python-PostgreSQL Database Adapter"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "psycopg2-binary-2.9.9.tar.gz", hash = "sha256:7f01846810177d829c7692f1f5ada8096762d9172af1b1a28d4ab5b77c923c1c"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c2470da5418b76232f02a2fcd2229537bb2d5a7096674ce61859c3229f2eb202"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c6af2a6d4b7ee9615cbb162b0738f6e1fd1f5c3eda7e5da17861eacf4c717ea7"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75723c3c0fbbf34350b46a3199eb50638ab22a0228f93fb472ef4d9becc2382b"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83791a65b51ad6ee6cf0845634859d69a038ea9b03d7b26e703f94c7e93dbcf9"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ef4854e82c09e84cc63084a9e4ccd6d9b154f1dbdd283efb92ecd0b5e2b8c84"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed1184ab8f113e8d660ce49a56390ca181f2981066acc27cf637d5c1e10ce46e"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d2997c458c690ec2bc6b0b7ecbafd02b029b7b4283078d3b32a852a7ce3ddd98"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b58b4710c7f4161b5e9dcbe73bb7c62d65670a87df7bcce9e1faaad43e715245"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0c009475ee389757e6e34611d75f6e4f05f0cf5ebb76c6037508318e1a1e0d7e"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8dbf6d1bc73f1d04ec1734bae3b4fb0ee3cb2a493d35ede9badbeb901fb40f6f"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-win32.whl", hash = "sha256:3f78fd71c4f43a13d342be74ebbc0666fe1f555b8837eb113cb7416856c79682"},
+ {file = "psycopg2_binary-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:876801744b0dee379e4e3c38b76fc89f88834bb15bf92ee07d94acd06ec890a0"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee825e70b1a209475622f7f7b776785bd68f34af6e7a46e2e42f27b659b5bc26"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1ea665f8ce695bcc37a90ee52de7a7980be5161375d42a0b6c6abedbf0d81f0f"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:143072318f793f53819048fdfe30c321890af0c3ec7cb1dfc9cc87aa88241de2"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c332c8d69fb64979ebf76613c66b985414927a40f8defa16cf1bc028b7b0a7b0"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7fc5a5acafb7d6ccca13bfa8c90f8c51f13d8fb87d95656d3950f0158d3ce53"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977646e05232579d2e7b9c59e21dbe5261f403a88417f6a6512e70d3f8a046be"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b6356793b84728d9d50ead16ab43c187673831e9d4019013f1402c41b1db9b27"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bc7bb56d04601d443f24094e9e31ae6deec9ccb23581f75343feebaf30423359"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:77853062a2c45be16fd6b8d6de2a99278ee1d985a7bd8b103e97e41c034006d2"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:78151aa3ec21dccd5cdef6c74c3e73386dcdfaf19bced944169697d7ac7482fc"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
+ {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e6f98446430fdf41bd36d4faa6cb409f5140c1c2cf58ce0bbdaf16af7d3f119"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c77e3d1862452565875eb31bdb45ac62502feabbd53429fdc39a1cc341d681ba"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
+ {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8359bf4791968c5a78c56103702000105501adb557f3cf772b2c207284273984"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:275ff571376626195ab95a746e6a04c7df8ea34638b99fc11160de91f2fef503"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f9b5571d33660d5009a8b3c25dc1db560206e2d2f89d3df1cb32d72c0d117d52"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:420f9bbf47a02616e8554e825208cb947969451978dceb77f95ad09c37791dae"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4154ad09dac630a0f13f37b583eae260c6aa885d67dfbccb5b02c33f31a6d420"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a148c5d507bb9b4f2030a2025c545fccb0e1ef317393eaba42e7eabd28eb6041"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:68fc1f1ba168724771e38bee37d940d2865cb0f562380a1fb1ffb428b75cb692"},
+ {file = "psycopg2_binary-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:281309265596e388ef483250db3640e5f414168c5a67e9c665cafce9492eda2f"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:60989127da422b74a04345096c10d416c2b41bd7bf2a380eb541059e4e999980"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:246b123cc54bb5361588acc54218c8c9fb73068bf227a4a531d8ed56fa3ca7d6"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34eccd14566f8fe14b2b95bb13b11572f7c7d5c36da61caf414d23b91fcc5d94"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18d0ef97766055fec15b5de2c06dd8e7654705ce3e5e5eed3b6651a1d2a9a152"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3f82c171b4ccd83bbaf35aa05e44e690113bd4f3b7b6cc54d2219b132f3ae55"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ead20f7913a9c1e894aebe47cccf9dc834e1618b7aa96155d2091a626e59c972"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ca49a8119c6cbd77375ae303b0cfd8c11f011abbbd64601167ecca18a87e7cdd"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:323ba25b92454adb36fa425dc5cf6f8f19f78948cbad2e7bc6cdf7b0d7982e59"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:1236ed0952fbd919c100bc839eaa4a39ebc397ed1c08a97fc45fee2a595aa1b3"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:729177eaf0aefca0994ce4cffe96ad3c75e377c7b6f4efa59ebf003b6d398716"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-win32.whl", hash = "sha256:804d99b24ad523a1fe18cc707bf741670332f7c7412e9d49cb5eab67e886b9b5"},
+ {file = "psycopg2_binary-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:a6cdcc3ede532f4a4b96000b6362099591ab4a3e913d70bcbac2b56c872446f7"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:72dffbd8b4194858d0941062a9766f8297e8868e1dd07a7b36212aaa90f49472"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:30dcc86377618a4c8f3b72418df92e77be4254d8f89f14b8e8f57d6d43603c0f"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31a34c508c003a4347d389a9e6fcc2307cc2150eb516462a7a17512130de109e"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15208be1c50b99203fe88d15695f22a5bed95ab3f84354c494bcb1d08557df67"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1873aade94b74715be2246321c8650cabf5a0d098a95bab81145ffffa4c13876"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a58c98a7e9c021f357348867f537017057c2ed7f77337fd914d0bedb35dace7"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4686818798f9194d03c9129a4d9a702d9e113a89cb03bffe08c6cf799e053291"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ebdc36bea43063116f0486869652cb2ed7032dbc59fbcb4445c4862b5c1ecf7f"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ca08decd2697fdea0aea364b370b1249d47336aec935f87b8bbfd7da5b2ee9c1"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac05fb791acf5e1a3e39402641827780fe44d27e72567a000412c648a85ba860"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-win32.whl", hash = "sha256:9dba73be7305b399924709b91682299794887cbbd88e38226ed9f6712eabee90"},
+ {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"},
+]
+
+[[package]]
+name = "ptyprocess"
+version = "0.7.0"
+description = "Run a subprocess in a pseudo terminal"
+optional = false
+python-versions = "*"
+files = [
+ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
+ {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
+]
+
+[[package]]
+name = "pure-eval"
+version = "0.2.2"
+description = "Safely evaluate AST nodes without side effects"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"},
+ {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"},
+]
+
+[package.extras]
+tests = ["pytest"]
+
+[[package]]
+name = "py-trello"
+version = "0.19.0"
+description = "Python wrapper around the Trello API"
+optional = true
+python-versions = "*"
+files = [
+ {file = "py-trello-0.19.0.tar.gz", hash = "sha256:f4a8c05db61fad0ef5fa35d62c29806c75d9d2b797358d9cf77275e2cbf23020"},
+]
+
+[package.dependencies]
+python-dateutil = "*"
+pytz = "*"
+requests = "*"
+requests-oauthlib = ">=0.4.1"
+
+[[package]]
+name = "py4j"
+version = "0.10.9.7"
+description = "Enables Python programs to dynamically access arbitrary Java objects"
+optional = true
+python-versions = "*"
+files = [
+ {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
+ {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
+]
+
+[[package]]
+name = "pyaes"
+version = "1.6.1"
+description = "Pure-Python Implementation of the AES block-cipher and common modes of operation"
+optional = true
+python-versions = "*"
+files = [
+ {file = "pyaes-1.6.1.tar.gz", hash = "sha256:02c1b1405c38d3c370b085fb952dd8bea3fadcee6411ad99f312cc129c536d8f"},
+]
+
+[[package]]
+name = "pyarrow"
+version = "14.0.1"
+description = "Python library for Apache Arrow"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"},
+ {file = "pyarrow-14.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a8ae88c0038d1bc362a682320112ee6774f006134cd5afc291591ee4bc06505"},
+ {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f6f053cb66dc24091f5511e5920e45c83107f954a21032feadc7b9e3a8e7851"},
+ {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:906b0dc25f2be12e95975722f1e60e162437023f490dbd80d0deb7375baf3171"},
+ {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:78d4a77a46a7de9388b653af1c4ce539350726cd9af62e0831e4f2bd0c95a2f4"},
+ {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06ca79080ef89d6529bb8e5074d4b4f6086143b2520494fcb7cf8a99079cde93"},
+ {file = "pyarrow-14.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:32542164d905002c42dff896efdac79b3bdd7291b1b74aa292fac8450d0e4dcd"},
+ {file = "pyarrow-14.0.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:c7331b4ed3401b7ee56f22c980608cf273f0380f77d0f73dd3c185f78f5a6220"},
+ {file = "pyarrow-14.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:922e8b49b88da8633d6cac0e1b5a690311b6758d6f5d7c2be71acb0f1e14cd61"},
+ {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58c889851ca33f992ea916b48b8540735055201b177cb0dcf0596a495a667b00"},
+ {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30d8494870d9916bb53b2a4384948491444741cb9a38253c590e21f836b01222"},
+ {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:be28e1a07f20391bb0b15ea03dcac3aade29fc773c5eb4bee2838e9b2cdde0cb"},
+ {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:981670b4ce0110d8dcb3246410a4aabf5714db5d8ea63b15686bce1c914b1f83"},
+ {file = "pyarrow-14.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:4756a2b373a28f6166c42711240643fb8bd6322467e9aacabd26b488fa41ec23"},
+ {file = "pyarrow-14.0.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:cf87e2cec65dd5cf1aa4aba918d523ef56ef95597b545bbaad01e6433851aa10"},
+ {file = "pyarrow-14.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:470ae0194fbfdfbf4a6b65b4f9e0f6e1fa0ea5b90c1ee6b65b38aecee53508c8"},
+ {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6263cffd0c3721c1e348062997babdf0151301f7353010c9c9a8ed47448f82ab"},
+ {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8089d7e77d1455d529dbd7cff08898bbb2666ee48bc4085203af1d826a33cc"},
+ {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fada8396bc739d958d0b81d291cfd201126ed5e7913cb73de6bc606befc30226"},
+ {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a145dab9ed7849fc1101bf03bcdc69913547f10513fdf70fc3ab6c0a50c7eee"},
+ {file = "pyarrow-14.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:05fe7994745b634c5fb16ce5717e39a1ac1fac3e2b0795232841660aa76647cd"},
+ {file = "pyarrow-14.0.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:a8eeef015ae69d104c4c3117a6011e7e3ecd1abec79dc87fd2fac6e442f666ee"},
+ {file = "pyarrow-14.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3c76807540989fe8fcd02285dd15e4f2a3da0b09d27781abec3adc265ddbeba1"},
+ {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:450e4605e3c20e558485f9161a79280a61c55efe585d51513c014de9ae8d393f"},
+ {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:323cbe60210173ffd7db78bfd50b80bdd792c4c9daca8843ef3cd70b186649db"},
+ {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0140c7e2b740e08c5a459439d87acd26b747fc408bde0a8806096ee0baaa0c15"},
+ {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:e592e482edd9f1ab32f18cd6a716c45b2c0f2403dc2af782f4e9674952e6dd27"},
+ {file = "pyarrow-14.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:d264ad13605b61959f2ae7c1d25b1a5b8505b112715c961418c8396433f213ad"},
+ {file = "pyarrow-14.0.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:01e44de9749cddc486169cb632f3c99962318e9dacac7778315a110f4bf8a450"},
+ {file = "pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0351fecf0e26e152542bc164c22ea2a8e8c682726fce160ce4d459ea802d69c"},
+ {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1f6110c386464fd2e5e4ea3624466055bbe681ff185fd6c9daa98f30a3f9a"},
+ {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11e045dfa09855b6d3e7705a37c42e2dc2c71d608fab34d3c23df2e02df9aec3"},
+ {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:097828b55321897db0e1dbfc606e3ff8101ae5725673498cbfa7754ee0da80e4"},
+ {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1daab52050a1c48506c029e6fa0944a7b2436334d7e44221c16f6f1b2cc9c510"},
+ {file = "pyarrow-14.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:3f6d5faf4f1b0d5a7f97be987cf9e9f8cd39902611e818fe134588ee99bf0283"},
+ {file = "pyarrow-14.0.1.tar.gz", hash = "sha256:b8b3f4fe8d4ec15e1ef9b599b94683c5216adaed78d5cb4c606180546d1e2ee1"},
+]
+
+[package.dependencies]
+numpy = ">=1.16.6"
+
+[[package]]
+name = "pyarrow-hotfix"
+version = "0.6"
+description = ""
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"},
+ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
+]
+
+[[package]]
+name = "pyasn1"
+version = "0.5.1"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+ {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"},
+ {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"},
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.3.0"
+description = "A collection of ASN.1-based protocols modules"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"},
+ {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.6.0"
+
+[[package]]
+name = "pycares"
+version = "4.4.0"
+description = "Python interface for c-ares"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pycares-4.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:24da119850841d16996713d9c3374ca28a21deee056d609fbbed29065d17e1f6"},
+ {file = "pycares-4.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8f64cb58729689d4d0e78f0bfb4c25ce2f851d0274c0273ac751795c04b8798a"},
+ {file = "pycares-4.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33e2a1120887e89075f7f814ec144f66a6ce06a54f5722ccefc62fbeda83cff"},
+ {file = "pycares-4.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c680fef1b502ee680f8f0b95a41af4ec2c234e50e16c0af5bbda31999d3584bd"},
+ {file = "pycares-4.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fff16b09042ba077f7b8aa5868d1d22456f0002574d0ba43462b10a009331677"},
+ {file = "pycares-4.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:229a1675eb33bc9afb1fc463e73ee334950ccc485bc83a43f6ae5839fb4d5fa3"},
+ {file = "pycares-4.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3aebc73e5ad70464f998f77f2da2063aa617cbd8d3e8174dd7c5b4518f967153"},
+ {file = "pycares-4.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6ef64649eba56448f65e26546d85c860709844d2fc22ef14d324fe0b27f761a9"},
+ {file = "pycares-4.4.0-cp310-cp310-win32.whl", hash = "sha256:4afc2644423f4eef97857a9fd61be9758ce5e336b4b0bd3d591238bb4b8b03e0"},
+ {file = "pycares-4.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:5ed4e04af4012f875b78219d34434a6d08a67175150ac1b79eb70ab585d4ba8c"},
+ {file = "pycares-4.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bce8db2fc6f3174bd39b81405210b9b88d7b607d33e56a970c34a0c190da0490"},
+ {file = "pycares-4.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9a0303428d013ccf5c51de59c83f9127aba6200adb7fd4be57eddb432a1edd2a"},
+ {file = "pycares-4.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afb91792f1556f97be7f7acb57dc7756d89c5a87bd8b90363a77dbf9ea653817"},
+ {file = "pycares-4.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b61579cecf1f4d616e5ea31a6e423a16680ab0d3a24a2ffe7bb1d4ee162477ff"},
+ {file = "pycares-4.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7af06968cbf6851566e806bf3e72825b0e6671832a2cbe840be1d2d65350710"},
+ {file = "pycares-4.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ceb12974367b0a68a05d52f4162b29f575d241bd53de155efe632bf2c943c7f6"},
+ {file = "pycares-4.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:2eeec144bcf6a7b6f2d74d6e70cbba7886a84dd373c886f06cb137a07de4954c"},
+ {file = "pycares-4.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e3a6f7cfdfd11eb5493d6d632e582408c8f3b429f295f8799c584c108b28db6f"},
+ {file = "pycares-4.4.0-cp311-cp311-win32.whl", hash = "sha256:34736a2ffaa9c08ca9c707011a2d7b69074bbf82d645d8138bba771479b2362f"},
+ {file = "pycares-4.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:eb66c30eb11e877976b7ead13632082a8621df648c408b8e15cdb91a452dd502"},
+ {file = "pycares-4.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fd644505a8cfd7f6584d33a9066d4e3d47700f050ef1490230c962de5dfb28c6"},
+ {file = "pycares-4.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52084961262232ec04bd75f5043aed7e5d8d9695e542ff691dfef0110209f2d4"},
+ {file = "pycares-4.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0c5368206057884cde18602580083aeaad9b860e2eac14fd253543158ce1e93"},
+ {file = "pycares-4.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:112a4979c695b1c86f6782163d7dec58d57a3b9510536dcf4826550f9053dd9a"},
+ {file = "pycares-4.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d186dafccdaa3409194c0f94db93c1a5d191145a275f19da6591f9499b8e7b8"},
+ {file = "pycares-4.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:64965dc19c578a683ea73487a215a8897276224e004d50eeb21f0bc7a0b63c88"},
+ {file = "pycares-4.4.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:ed2a38e34bec6f2586435f6ff0bc5fe11d14bebd7ed492cf739a424e81681540"},
+ {file = "pycares-4.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:94d6962db81541eb0396d2f0dfcbb18cdb8c8b251d165efc2d974ae652c547d4"},
+ {file = "pycares-4.4.0-cp312-cp312-win32.whl", hash = "sha256:1168a48a834813aa80f412be2df4abaf630528a58d15c704857448b20b1675c0"},
+ {file = "pycares-4.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:db24c4e7fea4a052c6e869cbf387dd85d53b9736cfe1ef5d8d568d1ca925e977"},
+ {file = "pycares-4.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:21a5a0468861ec7df7befa69050f952da13db5427ae41ffe4713bc96291d1d95"},
+ {file = "pycares-4.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:22c00bf659a9fa44d7b405cf1cd69b68b9d37537899898d8cbe5dffa4016b273"},
+ {file = "pycares-4.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23aa3993a352491a47fcf17867f61472f32f874df4adcbb486294bd9fbe8abee"},
+ {file = "pycares-4.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:813d661cbe2e37d87da2d16b7110a6860e93ddb11735c6919c8a3545c7b9c8d8"},
+ {file = "pycares-4.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:77cf5a2fd5583c670de41a7f4a7b46e5cbabe7180d8029f728571f4d2e864084"},
+ {file = "pycares-4.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3eaa6681c0a3e3f3868c77aca14b7760fed35fdfda2fe587e15c701950e7bc69"},
+ {file = "pycares-4.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ad58e284a658a8a6a84af2e0b62f2f961f303cedfe551854d7bd40c3cbb61912"},
+ {file = "pycares-4.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bfb89ca9e3d0a9b5332deeb666b2ede9d3469107742158f4aeda5ce032d003f4"},
+ {file = "pycares-4.4.0-cp38-cp38-win32.whl", hash = "sha256:f36bdc1562142e3695555d2f4ac0cb69af165eddcefa98efc1c79495b533481f"},
+ {file = "pycares-4.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:902461a92b6a80fd5041a2ec5235680c7cc35e43615639ec2a40e63fca2dfb51"},
+ {file = "pycares-4.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7bddc6adba8f699728f7fc1c9ce8cef359817ad78e2ed52b9502cb5f8dc7f741"},
+ {file = "pycares-4.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cb49d5805cd347c404f928c5ae7c35e86ba0c58ffa701dbe905365e77ce7d641"},
+ {file = "pycares-4.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56cf3349fa3a2e67ed387a7974c11d233734636fe19facfcda261b411af14d80"},
+ {file = "pycares-4.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bf2eaa83a5987e48fa63302f0fe7ce3275cfda87b34d40fef9ce703fb3ac002"},
+ {file = "pycares-4.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82bba2ab77eb5addbf9758d514d9bdef3c1bfe7d1649a47bd9a0d55a23ef478b"},
+ {file = "pycares-4.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c6a8bde63106f162fca736e842a916853cad3c8d9d137e11c9ffa37efa818b02"},
+ {file = "pycares-4.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f5f646eec041db6ffdbcaf3e0756fb92018f7af3266138c756bb09d2b5baadec"},
+ {file = "pycares-4.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9dc04c54c6ea615210c1b9e803d0e2d2255f87a3d5d119b6482c8f0dfa15b26b"},
+ {file = "pycares-4.4.0-cp39-cp39-win32.whl", hash = "sha256:97892cced5794d721fb4ff8765764aa4ea48fe8b2c3820677505b96b83d4ef47"},
+ {file = "pycares-4.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:917f08f0b5d9324e9a34211e68d27447c552b50ab967044776bbab7e42a553a2"},
+ {file = "pycares-4.4.0.tar.gz", hash = "sha256:f47579d508f2f56eddd16ce72045782ad3b1b3b678098699e2b6a1b30733e1c2"},
+]
+
+[package.dependencies]
+cffi = ">=1.5.0"
+
+[package.extras]
+idna = ["idna (>=2.1)"]
+
+[[package]]
+name = "pyclipper"
+version = "1.3.0.post5"
+description = "Cython wrapper for the C++ translation of the Angus Johnson's Clipper library (ver. 6.4.2)"
+optional = true
+python-versions = "*"
+files = [
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3c45f99b8180dd4df4c86642657ca92b7d5289a5e3724521822e0f9461961fe2"},
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:567ffd419a0bdc3727fa4562cfa1f18484691817a2bc0bc675750aa28ed98bd4"},
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:59c8c75661a6d87e98b1655851578a2917d3c8859912c9a4f1956b9830940fd9"},
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a496efa146d2d88b59350021739e4685e439dc569b6654e9e6d5e42e9a0b1666"},
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-win32.whl", hash = "sha256:02a98d09af9b60bcf8e9480d153c0839e20b92689f5602f87242a4933842fecd"},
+ {file = "pyclipper-1.3.0.post5-cp310-cp310-win_amd64.whl", hash = "sha256:847f1e2fc3994bb498fe675f55c98129b95dc26a5c92304ba4cf0ab40721ea3d"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b7a983ae019932bfa0a1971a2dc8c856704add5f3d567bed8fac02dbc0e7f0bf"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8760075c395b924f894aa16ee06e8c040c6f9b63e0903e49de3cc8d82d9e637"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4ea61ca5899d3346c614951342c506f119601ed0a1f4889a9cc236558afec6b"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46499b361ae067662b22578401d83d57716f3cc0071d592feb07d504b439fea7"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-win32.whl", hash = "sha256:d5c77e39ab05a6cf277c819639968b21e6959e996ea1a074afc24236541708ff"},
+ {file = "pyclipper-1.3.0.post5-cp311-cp311-win_amd64.whl", hash = "sha256:0f78a1c18ff4f9276f78d9353d6ed4309c3886a9d0172437e48328aef499165e"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5237282f906049c307e6c90333c7d56f6b8712bf087ef97b141830c40b09ca0a"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aca8635573646b65c054399433fb3493637f1445db942de8a52fca9ef493ba3d"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1158a2b13d59bdfab33d1d928f7b72c8c7fb8a76e7d2283839cb45d7c0ff2140"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a041f1a7982b17cf92fd3be349ec41ff1901792149c166bf283f469567b52d6"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-win32.whl", hash = "sha256:bf3a2ccd6e4e078250b0a31a12c519b0be6d1bc160acfceee62407dbd68558f6"},
+ {file = "pyclipper-1.3.0.post5-cp312-cp312-win_amd64.whl", hash = "sha256:2ce6e0a6ab32182c26537965cf521822cd11a28a7ffcef48635a94c6ca8559ef"},
+ {file = "pyclipper-1.3.0.post5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:010ee13d40d924341cc41b6d9901d763175040c68753939f140bc0cc714f18bb"},
+ {file = "pyclipper-1.3.0.post5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee1c4797b1dc982ae9d60333269536ea03ddc0baa1c3383a6d5b741dbbb12675"},
+ {file = "pyclipper-1.3.0.post5-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ba692cf11873886085a0445dcfc362b24ca35bcb997ad9e9b5685854a290d8ff"},
+ {file = "pyclipper-1.3.0.post5-cp36-cp36m-win32.whl", hash = "sha256:f0b84fcf5230aca2de06ddb7920459daa858853835f8774739ca30dd516e7d37"},
+ {file = "pyclipper-1.3.0.post5-cp36-cp36m-win_amd64.whl", hash = "sha256:741910bfd7b0bd40f027869f4bf86bdd9678ae7f74e8dabcf62d170269f6191d"},
+ {file = "pyclipper-1.3.0.post5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5f3484b4dffa64f0e3a43b63165a5c0f507c5850e70b9cc2eaa82474d7746393"},
+ {file = "pyclipper-1.3.0.post5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87efec9795744cef786f2f8cab17d6dc07f57dfce5e3b7f3be96eb79a4ce5794"},
+ {file = "pyclipper-1.3.0.post5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:5f445a2d03690faa23a1b90e32dfb4352a60b23437323de87388c6c611d3d1e3"},
+ {file = "pyclipper-1.3.0.post5-cp37-cp37m-win32.whl", hash = "sha256:eb9d1cb2999bc1ea8ad1c3a031ba33b0a89a5ace25d33df7529d3ff18c16604c"},
+ {file = "pyclipper-1.3.0.post5-cp37-cp37m-win_amd64.whl", hash = "sha256:ead0f3ecd1961005f61d50c896e33442138b4e7c9e0c035784d3525068dd2b10"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:39ccd920b192a4f8096589a2a1f8faaf6aaaadb7a163b5ce913d03faac2449bb"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e346e7adba43e40f5f5f293b6b6a45de5a6a3bdc74e437dedd948c5d74de9405"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2fb22927c3ac3191e555efd335c6efa819aa1ff4d0901979673ab5a18eb740"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a678999d728023f1f3988a14a2e6d89d6f1ed4d0786d5992c1bffb4c1ab30318"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-win32.whl", hash = "sha256:36d456fdf32a6410a87bd7af8ebc4c01f19b4e3b839104b3072558cad0d8bf4c"},
+ {file = "pyclipper-1.3.0.post5-cp38-cp38-win_amd64.whl", hash = "sha256:c9c1fdf4ecae6b55033ede3f4e931156ffc969334300f44f8bf1b356ec0a3d63"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8bb9cd95fd4bd88fb1590d1763a52e3ea6a1095e11b3e885ff164da1313aae79"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0f516fd69aa61a9698a3ce3ba2f7edda5ac6aafc8d964ee3bc60897906947fcb"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e36f018303656ea4a629d2fba0d0d4c74960eacec7119fe2ab3c658ce84c494b"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:dd3c4b312a931e668a7a291d4bd5b10bacb0687bd163220a9f0418c7e23169e2"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-win32.whl", hash = "sha256:cfea42972e90954b3c89da9216993373a2270a5103d4916fd543a1109528ed4c"},
+ {file = "pyclipper-1.3.0.post5-cp39-cp39-win_amd64.whl", hash = "sha256:85ca06f382f999903d809380e4c01ec127d3eb26431402e9b3f01facaec68b80"},
+ {file = "pyclipper-1.3.0.post5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:da30e59c684eea198f6e19244e9a41e855a23a416cc708821fd4eb8f5f18626c"},
+ {file = "pyclipper-1.3.0.post5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d8a9e3e46aa50e4c3667db9a816d59ae4f9c62b05f997abb8a9b3f3afe6d94a4"},
+ {file = "pyclipper-1.3.0.post5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0589b80f2da1ad322345a93c053b5d46dc692def5a188351be01f34bcf041218"},
+ {file = "pyclipper-1.3.0.post5.tar.gz", hash = "sha256:c0239f928e0bf78a3efc2f2f615a10bfcdb9f33012d46d64c8d1225b4bde7096"},
+]
+
+[[package]]
+name = "pycparser"
+version = "2.21"
+description = "C parser in Python"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
+ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+]
+
+[[package]]
+name = "pydantic"
+version = "1.10.13"
+description = "Data validation and settings management using python type hints"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pydantic-1.10.13-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:efff03cc7a4f29d9009d1c96ceb1e7a70a65cfe86e89d34e4a5f2ab1e5693737"},
+ {file = "pydantic-1.10.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3ecea2b9d80e5333303eeb77e180b90e95eea8f765d08c3d278cd56b00345d01"},
+ {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1740068fd8e2ef6eb27a20e5651df000978edce6da6803c2bef0bc74540f9548"},
+ {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84bafe2e60b5e78bc64a2941b4c071a4b7404c5c907f5f5a99b0139781e69ed8"},
+ {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bc0898c12f8e9c97f6cd44c0ed70d55749eaf783716896960b4ecce2edfd2d69"},
+ {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:654db58ae399fe6434e55325a2c3e959836bd17a6f6a0b6ca8107ea0571d2e17"},
+ {file = "pydantic-1.10.13-cp310-cp310-win_amd64.whl", hash = "sha256:75ac15385a3534d887a99c713aa3da88a30fbd6204a5cd0dc4dab3d770b9bd2f"},
+ {file = "pydantic-1.10.13-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c553f6a156deb868ba38a23cf0df886c63492e9257f60a79c0fd8e7173537653"},
+ {file = "pydantic-1.10.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5e08865bc6464df8c7d61439ef4439829e3ab62ab1669cddea8dd00cd74b9ffe"},
+ {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e31647d85a2013d926ce60b84f9dd5300d44535a9941fe825dc349ae1f760df9"},
+ {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:210ce042e8f6f7c01168b2d84d4c9eb2b009fe7bf572c2266e235edf14bacd80"},
+ {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8ae5dd6b721459bfa30805f4c25880e0dd78fc5b5879f9f7a692196ddcb5a580"},
+ {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f8e81fc5fb17dae698f52bdd1c4f18b6ca674d7068242b2aff075f588301bbb0"},
+ {file = "pydantic-1.10.13-cp311-cp311-win_amd64.whl", hash = "sha256:61d9dce220447fb74f45e73d7ff3b530e25db30192ad8d425166d43c5deb6df0"},
+ {file = "pydantic-1.10.13-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4b03e42ec20286f052490423682016fd80fda830d8e4119f8ab13ec7464c0132"},
+ {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f59ef915cac80275245824e9d771ee939133be38215555e9dc90c6cb148aaeb5"},
+ {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a1f9f747851338933942db7af7b6ee8268568ef2ed86c4185c6ef4402e80ba8"},
+ {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:97cce3ae7341f7620a0ba5ef6cf043975cd9d2b81f3aa5f4ea37928269bc1b87"},
+ {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:854223752ba81e3abf663d685f105c64150873cc6f5d0c01d3e3220bcff7d36f"},
+ {file = "pydantic-1.10.13-cp37-cp37m-win_amd64.whl", hash = "sha256:b97c1fac8c49be29486df85968682b0afa77e1b809aff74b83081cc115e52f33"},
+ {file = "pydantic-1.10.13-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c958d053453a1c4b1c2062b05cd42d9d5c8eb67537b8d5a7e3c3032943ecd261"},
+ {file = "pydantic-1.10.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c5370a7edaac06daee3af1c8b1192e305bc102abcbf2a92374b5bc793818599"},
+ {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6f6e7305244bddb4414ba7094ce910560c907bdfa3501e9db1a7fd7eaea127"},
+ {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3a3c792a58e1622667a2837512099eac62490cdfd63bd407993aaf200a4cf1f"},
+ {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c636925f38b8db208e09d344c7aa4f29a86bb9947495dd6b6d376ad10334fb78"},
+ {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:678bcf5591b63cc917100dc50ab6caebe597ac67e8c9ccb75e698f66038ea953"},
+ {file = "pydantic-1.10.13-cp38-cp38-win_amd64.whl", hash = "sha256:6cf25c1a65c27923a17b3da28a0bdb99f62ee04230c931d83e888012851f4e7f"},
+ {file = "pydantic-1.10.13-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8ef467901d7a41fa0ca6db9ae3ec0021e3f657ce2c208e98cd511f3161c762c6"},
+ {file = "pydantic-1.10.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:968ac42970f57b8344ee08837b62f6ee6f53c33f603547a55571c954a4225691"},
+ {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9849f031cf8a2f0a928fe885e5a04b08006d6d41876b8bbd2fc68a18f9f2e3fd"},
+ {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56e3ff861c3b9c6857579de282ce8baabf443f42ffba355bf070770ed63e11e1"},
+ {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f00790179497767aae6bcdc36355792c79e7bbb20b145ff449700eb076c5f96"},
+ {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:75b297827b59bc229cac1a23a2f7a4ac0031068e5be0ce385be1462e7e17a35d"},
+ {file = "pydantic-1.10.13-cp39-cp39-win_amd64.whl", hash = "sha256:e70ca129d2053fb8b728ee7d1af8e553a928d7e301a311094b8a0501adc8763d"},
+ {file = "pydantic-1.10.13-py3-none-any.whl", hash = "sha256:b87326822e71bd5f313e7d3bfdc77ac3247035ac10b0c0618bd99dcf95b1e687"},
+ {file = "pydantic-1.10.13.tar.gz", hash = "sha256:32c8b48dcd3b2ac4e78b0ba4af3a2c2eb6048cb75202f0ea7b34feb740efc340"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.2.0"
+
+[package.extras]
+dotenv = ["python-dotenv (>=0.10.4)"]
+email = ["email-validator (>=1.0.3)"]
+
+[[package]]
+name = "pydeck"
+version = "0.8.0"
+description = "Widget for deck.gl maps"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "pydeck-0.8.0-py2.py3-none-any.whl", hash = "sha256:a8fa7757c6f24bba033af39db3147cb020eef44012ba7e60d954de187f9ed4d5"},
+ {file = "pydeck-0.8.0.tar.gz", hash = "sha256:07edde833f7cfcef6749124351195aa7dcd24663d4909fd7898dbd0b6fbc01ec"},
+]
+
+[package.dependencies]
+jinja2 = ">=2.10.1"
+numpy = ">=1.16.4"
+
+[package.extras]
+carto = ["pydeck-carto"]
+jupyter = ["ipykernel (>=5.1.2)", "ipython (>=5.8.0)", "ipywidgets (>=7,<8)", "traitlets (>=4.3.2)"]
+
+[[package]]
+name = "pygments"
+version = "2.17.2"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"},
+ {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"},
+]
+
+[package.extras]
+plugins = ["importlib-metadata"]
+windows-terminal = ["colorama (>=0.4.6)"]
+
+[[package]]
+name = "pyjwt"
+version = "2.8.0"
+description = "JSON Web Token implementation in Python"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
+ {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
+]
+
+[package.dependencies]
+cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
+
+[package.extras]
+crypto = ["cryptography (>=3.4.0)"]
+dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
+
+[[package]]
+name = "pymongo"
+version = "4.6.1"
+description = "Python driver for MongoDB "
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "pymongo-4.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4344c30025210b9fa80ec257b0e0aab5aa1d5cca91daa70d82ab97b482cc038e"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux1_i686.whl", hash = "sha256:1c5654bb8bb2bdb10e7a0bc3c193dd8b49a960b9eebc4381ff5a2043f4c3c441"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:eaf2f65190c506def2581219572b9c70b8250615dc918b3b7c218361a51ec42e"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux2014_i686.whl", hash = "sha256:262356ea5fcb13d35fb2ab6009d3927bafb9504ef02339338634fffd8a9f1ae4"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux2014_ppc64le.whl", hash = "sha256:2dd2f6960ee3c9360bed7fb3c678be0ca2d00f877068556785ec2eb6b73d2414"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux2014_s390x.whl", hash = "sha256:ff925f1cca42e933376d09ddc254598f8c5fcd36efc5cac0118bb36c36217c41"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3cadf7f4c8e94d8a77874b54a63c80af01f4d48c4b669c8b6867f86a07ba994f"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55dac73316e7e8c2616ba2e6f62b750918e9e0ae0b2053699d66ca27a7790105"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:154b361dcb358ad377d5d40df41ee35f1cc14c8691b50511547c12404f89b5cb"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2940aa20e9cc328e8ddeacea8b9a6f5ddafe0b087fedad928912e787c65b4909"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:010bc9aa90fd06e5cc52c8fac2c2fd4ef1b5f990d9638548dde178005770a5e8"},
+ {file = "pymongo-4.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e470fa4bace5f50076c32f4b3cc182b31303b4fefb9b87f990144515d572820b"},
+ {file = "pymongo-4.6.1-cp310-cp310-win32.whl", hash = "sha256:da08ea09eefa6b960c2dd9a68ec47949235485c623621eb1d6c02b46765322ac"},
+ {file = "pymongo-4.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:13d613c866f9f07d51180f9a7da54ef491d130f169e999c27e7633abe8619ec9"},
+ {file = "pymongo-4.6.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6a0ae7a48a6ef82ceb98a366948874834b86c84e288dbd55600c1abfc3ac1d88"},
+ {file = "pymongo-4.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bd94c503271e79917b27c6e77f7c5474da6930b3fb9e70a12e68c2dff386b9a"},
+ {file = "pymongo-4.6.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d4ccac3053b84a09251da8f5350bb684cbbf8c8c01eda6b5418417d0a8ab198"},
+ {file = "pymongo-4.6.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:349093675a2d3759e4fb42b596afffa2b2518c890492563d7905fac503b20daa"},
+ {file = "pymongo-4.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88beb444fb438385e53dc9110852910ec2a22f0eab7dd489e827038fdc19ed8d"},
+ {file = "pymongo-4.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8e62d06e90f60ea2a3d463ae51401475568b995bafaffd81767d208d84d7bb1"},
+ {file = "pymongo-4.6.1-cp311-cp311-win32.whl", hash = "sha256:5556e306713e2522e460287615d26c0af0fe5ed9d4f431dad35c6624c5d277e9"},
+ {file = "pymongo-4.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:b10d8cda9fc2fcdcfa4a000aa10413a2bf8b575852cd07cb8a595ed09689ca98"},
+ {file = "pymongo-4.6.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b435b13bb8e36be11b75f7384a34eefe487fe87a6267172964628e2b14ecf0a7"},
+ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e438417ce1dc5b758742e12661d800482200b042d03512a8f31f6aaa9137ad40"},
+ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8b47ebd89e69fbf33d1c2df79759d7162fc80c7652dacfec136dae1c9b3afac7"},
+ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bbed8cccebe1169d45cedf00461b2842652d476d2897fd1c42cf41b635d88746"},
+ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c30a9e06041fbd7a7590693ec5e407aa8737ad91912a1e70176aff92e5c99d20"},
+ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
+ {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
+ {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:026a24a36394dc8930cbcb1d19d5eb35205ef3c838a7e619e04bd170713972e7"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:3b287e814a01deddb59b88549c1e0c87cefacd798d4afc0c8bd6042d1c3d48aa"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:9a710c184ba845afb05a6f876edac8f27783ba70e52d5eaf939f121fc13b2f59"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:30b2c9caf3e55c2e323565d1f3b7e7881ab87db16997dc0cbca7c52885ed2347"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff62ba8ff70f01ab4fe0ae36b2cb0b5d1f42e73dfc81ddf0758cd9f77331ad25"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:547dc5d7f834b1deefda51aedb11a7af9c51c45e689e44e14aa85d44147c7657"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1de3c6faf948f3edd4e738abdb4b76572b4f4fdfc1fed4dad02427e70c5a6219"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2831e05ce0a4df10c4ac5399ef50b9a621f90894c2a4d2945dc5658765514ed"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144a31391a39a390efce0c5ebcaf4bf112114af4384c90163f402cec5ede476b"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33bb16a07d3cc4e0aea37b242097cd5f7a156312012455c2fa8ca396953b11c4"},
+ {file = "pymongo-4.6.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b7b1a83ce514700276a46af3d9e481ec381f05b64939effc9065afe18456a6b9"},
+ {file = "pymongo-4.6.1-cp37-cp37m-win32.whl", hash = "sha256:3071ec998cc3d7b4944377e5f1217c2c44b811fae16f9a495c7a1ce9b42fb038"},
+ {file = "pymongo-4.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2346450a075625c4d6166b40a013b605a38b6b6168ce2232b192a37fb200d588"},
+ {file = "pymongo-4.6.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:061598cbc6abe2f382ab64c9caa83faa2f4c51256f732cdd890bcc6e63bfb67e"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:d483793a384c550c2d12cb794ede294d303b42beff75f3b3081f57196660edaf"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f9756f1d25454ba6a3c2f1ef8b7ddec23e5cdeae3dc3c3377243ae37a383db00"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:1ed23b0e2dac6f84f44c8494fbceefe6eb5c35db5c1099f56ab78fc0d94ab3af"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:3d18a9b9b858ee140c15c5bfcb3e66e47e2a70a03272c2e72adda2482f76a6ad"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:c258dbacfff1224f13576147df16ce3c02024a0d792fd0323ac01bed5d3c545d"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:f7acc03a4f1154ba2643edeb13658d08598fe6e490c3dd96a241b94f09801626"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:76013fef1c9cd1cd00d55efde516c154aa169f2bf059b197c263a255ba8a9ddf"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f0e6a6c807fa887a0c51cc24fe7ea51bb9e496fe88f00d7930063372c3664c3"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd1fa413f8b9ba30140de198e4f408ffbba6396864c7554e0867aa7363eb58b2"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d219b4508f71d762368caec1fc180960569766049bbc4d38174f05e8ef2fe5b"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27b81ecf18031998ad7db53b960d1347f8f29e8b7cb5ea7b4394726468e4295e"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56816e43c92c2fa8c11dc2a686f0ca248bea7902f4a067fa6cbc77853b0f041e"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef801027629c5b511cf2ba13b9be29bfee36ae834b2d95d9877818479cdc99ea"},
+ {file = "pymongo-4.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d4c2be9760b112b1caf649b4977b81b69893d75aa86caf4f0f398447be871f3c"},
+ {file = "pymongo-4.6.1-cp38-cp38-win32.whl", hash = "sha256:39d77d8bbb392fa443831e6d4ae534237b1f4eee6aa186f0cdb4e334ba89536e"},
+ {file = "pymongo-4.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:4497d49d785482cc1a44a0ddf8830b036a468c088e72a05217f5b60a9e025012"},
+ {file = "pymongo-4.6.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:69247f7a2835fc0984bbf0892e6022e9a36aec70e187fcfe6cae6a373eb8c4de"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:7bb0e9049e81def6829d09558ad12d16d0454c26cabe6efc3658e544460688d9"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6a1810c2cbde714decf40f811d1edc0dae45506eb37298fd9d4247b8801509fe"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e2aced6fb2f5261b47d267cb40060b73b6527e64afe54f6497844c9affed5fd0"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:d0355cff58a4ed6d5e5f6b9c3693f52de0784aa0c17119394e2a8e376ce489d4"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:3c74f4725485f0a7a3862cfd374cc1b740cebe4c133e0c1425984bcdcce0f4bb"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:9c79d597fb3a7c93d7c26924db7497eba06d58f88f58e586aa69b2ad89fee0f8"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:8ec75f35f62571a43e31e7bd11749d974c1b5cd5ea4a8388725d579263c0fdf6"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5e641f931c5cd95b376fd3c59db52770e17bec2bf86ef16cc83b3906c054845"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9aafd036f6f2e5ad109aec92f8dbfcbe76cff16bad683eb6dd18013739c0b3ae"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f2b856518bfcfa316c8dae3d7b412aecacf2e8ba30b149f5eb3b63128d703b9"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec31adc2e988fd7db3ab509954791bbc5a452a03c85e45b804b4bfc31fa221d"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9167e735379ec43d8eafa3fd675bfbb12e2c0464f98960586e9447d2cf2c7a83"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1461199b07903fc1424709efafe379205bf5f738144b1a50a08b0396357b5abf"},
+ {file = "pymongo-4.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3094c7d2f820eecabadae76bfec02669567bbdd1730eabce10a5764778564f7b"},
+ {file = "pymongo-4.6.1-cp39-cp39-win32.whl", hash = "sha256:c91ea3915425bd4111cb1b74511cdc56d1d16a683a48bf2a5a96b6a6c0f297f7"},
+ {file = "pymongo-4.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef102a67ede70e1721fe27f75073b5314911dbb9bc27cde0a1c402a11531e7bd"},
+ {file = "pymongo-4.6.1.tar.gz", hash = "sha256:31dab1f3e1d0cdd57e8df01b645f52d43cc1b653ed3afd535d2891f4fc4f9712"},
+]
+
+[package.dependencies]
+dnspython = ">=1.16.0,<3.0.0"
+
+[package.extras]
+aws = ["pymongo-auth-aws (<2.0.0)"]
+encryption = ["certifi", "pymongo[aws]", "pymongocrypt (>=1.6.0,<2.0.0)"]
+gssapi = ["pykerberos", "winkerberos (>=0.5.0)"]
+ocsp = ["certifi", "cryptography (>=2.5)", "pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)"]
+snappy = ["python-snappy"]
+test = ["pytest (>=7)"]
+zstd = ["zstandard"]
+
+[[package]]
+name = "pympler"
+version = "1.0.1"
+description = "A development tool to measure, monitor and analyze the memory behavior of Python objects."
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "Pympler-1.0.1-py3-none-any.whl", hash = "sha256:d260dda9ae781e1eab6ea15bacb84015849833ba5555f141d2d9b7b7473b307d"},
+ {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"},
+]
+
+[[package]]
+name = "pymupdf"
+version = "1.23.7"
+description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "PyMuPDF-1.23.7-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:74982a3e0186f5525c2f090863f75b593994fd7e0e10c4f2605159b800f3ca0b"},
+ {file = "PyMuPDF-1.23.7-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c996b11e015027638296d5923d53559de0493f146b3ca0bab76b3ee0db0bc6eb"},
+ {file = "PyMuPDF-1.23.7-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:3a6c0cfe9686edfe96e885bdec995588b6cdf78e69b1588a5d61e60756cfe824"},
+ {file = "PyMuPDF-1.23.7-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:a164a71e5d02bb31f136e5afbf4048c47c93125f8fac0eedf6a868bc5f40c675"},
+ {file = "PyMuPDF-1.23.7-cp310-none-win32.whl", hash = "sha256:a5eaf107d23c4b1281cfbe189dae634d2e749c20ca3d3bf3d162cf2357c5024b"},
+ {file = "PyMuPDF-1.23.7-cp310-none-win_amd64.whl", hash = "sha256:bdb2423bccb07218a42b1dcdfeb91a96ce001e872263fb545132000cd087bda0"},
+ {file = "PyMuPDF-1.23.7-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:ea9b4ed48509faa6b6f8834401c586c3780f4dcd28cdc3013a3d12bcb2153aa0"},
+ {file = "PyMuPDF-1.23.7-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:c27f5916623704da30608ac54c7880fe44b8f9f7a9c9fc6332e216599c536db9"},
+ {file = "PyMuPDF-1.23.7-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:ce5f0d37a5086d7601c775919963f1677342c644d7ad00e92f6b56b8ec48b667"},
+ {file = "PyMuPDF-1.23.7-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:dfca3952b285747f9d84b57f97902e2a33fd80f0556557c55b2da358da38e48c"},
+ {file = "PyMuPDF-1.23.7-cp311-none-win32.whl", hash = "sha256:bb302a798332260870cc6540bab28530b2ecd57447b3ce464da1b501dc1813b8"},
+ {file = "PyMuPDF-1.23.7-cp311-none-win_amd64.whl", hash = "sha256:8afbfe6c771cec7f28cdf8f460b92d973d233a42712a87e24cee225d88aaf1f5"},
+ {file = "PyMuPDF-1.23.7-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:cc18057a83e06871e6242ad39bcede65c53aa8d135d267edb05711ffee9e669a"},
+ {file = "PyMuPDF-1.23.7-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ad743d866749a399ef61086b2b6985d3212bd985fd972d55a288e9b53a73dd98"},
+ {file = "PyMuPDF-1.23.7-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:a4e145914f929338dd9648f03b8cf9a8baba86c00410e5874dce8282fbd6b6ed"},
+ {file = "PyMuPDF-1.23.7-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:e4c1c3c75297d986da7266c6c39aee7b30783445468f58cb1b9659872f905cd8"},
+ {file = "PyMuPDF-1.23.7-cp312-none-win32.whl", hash = "sha256:7abb49faee62ddacb8b6dc4bbab3e9a3cb35d8782f2c461b42d178ff4af63da2"},
+ {file = "PyMuPDF-1.23.7-cp312-none-win_amd64.whl", hash = "sha256:59fe0f0c1d2e8d9ab678cf4c937e64bcaf551602ee7d8c80dc489c92ddb3cfe2"},
+ {file = "PyMuPDF-1.23.7-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:b487d49fc79a45e005cd06840f9c5f348b1aa85329d9e35c4eb924d7ae19c9b2"},
+ {file = "PyMuPDF-1.23.7-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c13cbb6bd7549814877cc5e4b0063090b9e4029063dd90e68b43541205508fe2"},
+ {file = "PyMuPDF-1.23.7-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:4fa6028040028be35bedadc18c16892e4d298319f8c7f071d5305b0ab84a0121"},
+ {file = "PyMuPDF-1.23.7-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7ba9293c5c828cc1c10bee07f375ec3d032950cf0dba3514a4a93bae347d83b"},
+ {file = "PyMuPDF-1.23.7-cp38-none-win32.whl", hash = "sha256:d53c1d06989b32e5fce62d55dee59c6e534d5ed289fee37f5af3e0b009b63677"},
+ {file = "PyMuPDF-1.23.7-cp38-none-win_amd64.whl", hash = "sha256:43f00c7713124e36db2feca737ad9228d283d5b2ca3e01643b40af636a095cc9"},
+ {file = "PyMuPDF-1.23.7-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:130ae62ba55ced20eb35088968fd158651b66a510b60b25fcd8d62b58633dd02"},
+ {file = "PyMuPDF-1.23.7-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:835d9f922c3a6612cd202aaa6387ef83741f6ce1bb1c50b814298b27072fea69"},
+ {file = "PyMuPDF-1.23.7-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:c82093000ae12b5c6e9334b272da37f280968e33b4fcd122169af6f9abb71b0e"},
+ {file = "PyMuPDF-1.23.7-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:9b4024876ea72d1d3172c6adfc5cf69eb397ed8f773f36d0b7ed9fb88b134ace"},
+ {file = "PyMuPDF-1.23.7-cp39-none-win32.whl", hash = "sha256:96284e9d5a28ed3125355d129fe6a20c2223da861bc8527188e55608f06cbdf0"},
+ {file = "PyMuPDF-1.23.7-cp39-none-win_amd64.whl", hash = "sha256:365f772d7e32ff1f7bb3ee4cb502d71d5919566b61c3d9c350d1a61c5c5b3073"},
+ {file = "PyMuPDF-1.23.7.tar.gz", hash = "sha256:53b7c03a2f179943fadcb723440ef5832b5f60aa39fc1505ff37cafa209c63ea"},
+]
+
+[package.dependencies]
+PyMuPDFb = "1.23.7"
+
+[[package]]
+name = "pymupdfb"
+version = "1.23.7"
+description = "MuPDF shared libraries for PyMuPDF."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "PyMuPDFb-1.23.7-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:3fddd302121a2109c31d0b2d554ef4afc426b67baa60221daf1bc277951ae4ef"},
+ {file = "PyMuPDFb-1.23.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:aef672f303691904c8951f811f5de3e2ba09d1804571a7f002145ed535cedbdd"},
+ {file = "PyMuPDFb-1.23.7-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ca5a93fac4777f1d2de61fec2e0b96cf649c75bd60bc44f6b6547f8aaccb8a70"},
+ {file = "PyMuPDFb-1.23.7-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adb43972f75500fae50279919d589a49b91ed7a74ec03e811c5000727dd63cea"},
+ {file = "PyMuPDFb-1.23.7-py3-none-win32.whl", hash = "sha256:f65e6dbf48daa2348ae708d76ed8310cc5eb9fc78eb335c5cade5dcaa3d52979"},
+ {file = "PyMuPDFb-1.23.7-py3-none-win_amd64.whl", hash = "sha256:7552793efa6976574b8b7840fd0091773c410e6048bc7cbf4b2eb3ed92d0b7a5"},
+]
+
+[[package]]
+name = "pyparsing"
+version = "3.1.1"
+description = "pyparsing module - Classes and methods to define and execute parsing grammars"
+optional = true
+python-versions = ">=3.6.8"
+files = [
+ {file = "pyparsing-3.1.1-py3-none-any.whl", hash = "sha256:32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb"},
+ {file = "pyparsing-3.1.1.tar.gz", hash = "sha256:ede28a1a32462f5a9705e07aea48001a08f7cf81a021585011deba701581a0db"},
+]
+
+[package.extras]
+diagrams = ["jinja2", "railroad-diagrams"]
+
+[[package]]
+name = "pypdf"
+version = "3.17.1"
+description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "pypdf-3.17.1-py3-none-any.whl", hash = "sha256:df3a7e90f1d3e4c9fe88a6b45c2ae58e61fe48a0fe0bc6de1544596e479a3f97"},
+ {file = "pypdf-3.17.1.tar.gz", hash = "sha256:c79ad4db16c9a86071a3556fb5d619022b36b8880ba3ef416558ea95fbec4cb9"},
+]
+
+[package.dependencies]
+typing_extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.10\""}
+
+[package.extras]
+crypto = ["PyCryptodome", "cryptography"]
+dev = ["black", "flit", "pip-tools", "pre-commit (<2.18.0)", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"]
+docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"]
+full = ["Pillow (>=8.0.0)", "PyCryptodome", "cryptography"]
+image = ["Pillow (>=8.0.0)"]
+
+[[package]]
+name = "pypdfium2"
+version = "4.24.0"
+description = "Python bindings to PDFium"
+optional = true
+python-versions = ">= 3.6"
+files = [
+ {file = "pypdfium2-4.24.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:94b722c4dddbd858d62fe4df3192651f9376f1c99e7c2bc74d7d8c8d06362bf3"},
+ {file = "pypdfium2-4.24.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c2891aa6059acf9bdabccb7aa193f111ebf96fabae3fb968f04ec925d710ec95"},
+ {file = "pypdfium2-4.24.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:147d101686af8828fbaeb3ac3fd82114f0900d58a24e80eff96496fd89fd9d2d"},
+ {file = "pypdfium2-4.24.0-py3-none-manylinux_2_17_armv7l.whl", hash = "sha256:60c7d9c442aff40d30dbf044ffb67cdc5eb56acca59ac640bc3adad77fc4d781"},
+ {file = "pypdfium2-4.24.0-py3-none-manylinux_2_17_i686.whl", hash = "sha256:025553c8b3633b32e2ef0e9ec9ee07be4a4fda76519889607ad3283090eef7f1"},
+ {file = "pypdfium2-4.24.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:3b26ad59ebef92edfcb44400838edce2e299a9709fe472742a4800251b30e5c9"},
+ {file = "pypdfium2-4.24.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:4034e6b4bde7cb6d281898c43ccb9a5522e25edb1e24689bf89fc7eb2a0c9a15"},
+ {file = "pypdfium2-4.24.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:a1776dde55b55d81e18026cf746274c1e2959bc8ed2f502a997401e1f0e7c3c1"},
+ {file = "pypdfium2-4.24.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:39a1e1cc02bc39233f742b8cdf60a81e5b4868bfee28ec79457e37e3d41304e6"},
+ {file = "pypdfium2-4.24.0-py3-none-win32.whl", hash = "sha256:7556801f2b42c91590e3f862034ab61e30e732b09e1487b0cf1a3c5250cb29d4"},
+ {file = "pypdfium2-4.24.0-py3-none-win_amd64.whl", hash = "sha256:fa65834fbc6540114ceaebc5e9ca90c5455b0ebedaaaf6c2c8351c851ada366b"},
+ {file = "pypdfium2-4.24.0-py3-none-win_arm64.whl", hash = "sha256:9333304e289fa727fbeae6dab793a9bacb68375184e14ad3d38a65d9a7490be1"},
+ {file = "pypdfium2-4.24.0.tar.gz", hash = "sha256:62706c06bc5be39aa7a2531af802420429b6c4c47498eebd2521af7e988d0848"},
+]
+
+[[package]]
+name = "pyproj"
+version = "3.5.0"
+description = "Python interface to PROJ (cartographic projections and coordinate transformations library)"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyproj-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6475ce653880938468a1a1b7321267243909e34b972ba9e53d5982c41d555918"},
+ {file = "pyproj-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:61e4ad57d89b03a7b173793b31bca8ee110112cde1937ef0f42a70b9120c827d"},
+ {file = "pyproj-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bdd2021bb6f7f346bfe1d2a358aa109da017d22c4704af2d994e7c7ee0a7a53"},
+ {file = "pyproj-3.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5674923351e76222e2c10c58b5e1ac119d7a46b270d822c463035971b06f724b"},
+ {file = "pyproj-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5e2b6aa255023c4acd0b977590f1f7cc801ba21b4d806fcf6dfac3474ebb83"},
+ {file = "pyproj-3.5.0-cp310-cp310-win32.whl", hash = "sha256:6f316a66031a14e9c5a88c91f8b77aa97f5454895674541ed6ab630b682be35d"},
+ {file = "pyproj-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:f7c2f4d9681e810cf40239caaca00079930a6d9ee6591139b88d592d36051d82"},
+ {file = "pyproj-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7572983134e310e0ca809c63f1722557a040fe9443df5f247bf11ba887eb1229"},
+ {file = "pyproj-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:eccb417b91d0be27805dfc97550bfb8b7db94e9fe1db5ebedb98f5b88d601323"},
+ {file = "pyproj-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:621d78a9d8bf4d06e08bef2471021fbcb1a65aa629ad4a20c22e521ce729cc20"},
+ {file = "pyproj-3.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d9a024370e917c899bff9171f03ea6079deecdc7482a146a2c565f3b9df134ea"},
+ {file = "pyproj-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b7c2113c4d11184a238077ec85e31eda1dcc58ffeb9a4429830e0a7036e787d"},
+ {file = "pyproj-3.5.0-cp311-cp311-win32.whl", hash = "sha256:a730f5b4c98c8a0f312437873e6e34dbd4cc6dc23d5afd91a6691c62724b1f68"},
+ {file = "pyproj-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:e97573de0ab3bbbcb4c7748bc41f4ceb6da10b45d35b1a294b5820701e7c25f0"},
+ {file = "pyproj-3.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2b708fd43453b985642b737d4a6e7f1d6a0ab1677ffa4e14cc258537b49224b0"},
+ {file = "pyproj-3.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b60d93a200639e8367c6542a964fd0aa2dbd152f256c1831dc18cd5aa470fb8a"},
+ {file = "pyproj-3.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38862fe07316ae12b79d82d298e390973a4f00b684f3c2d037238e20e00610ba"},
+ {file = "pyproj-3.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71b65f2a38cd9e16883dbb0f8ae82bdf8f6b79b1b02975c78483ab8428dbbf2f"},
+ {file = "pyproj-3.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b752b7d9c4b08181c7e8c0d9c7f277cbefff42227f34d3310696a87c863d9dd3"},
+ {file = "pyproj-3.5.0-cp38-cp38-win32.whl", hash = "sha256:b937215bfbaf404ec8f03ca741fc3f9f2c4c2c5590a02ccddddd820ae3c71331"},
+ {file = "pyproj-3.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:97ed199033c2c770e7eea2ef80ff5e6413426ec2d7ec985b869792f04ab95d05"},
+ {file = "pyproj-3.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:052c49fce8b5d55943a35c36ccecb87350c68b48ba95bc02a789770c374ef819"},
+ {file = "pyproj-3.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1507138ea28bf2134d31797675380791cc1a7156a3aeda484e65a78a4aba9b62"},
+ {file = "pyproj-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c02742ef3d846401861a878a61ef7ad911ea7539d6cc4619ddb52dbdf7b45aee"},
+ {file = "pyproj-3.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:385b0341861d3ebc8cad98337a738821dcb548d465576527399f4955ca24b6ed"},
+ {file = "pyproj-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fe6bb1b68a35d07378d38be77b5b2f8dd2bea5910c957bfcc7bee55988d3910"},
+ {file = "pyproj-3.5.0-cp39-cp39-win32.whl", hash = "sha256:5c4b85ac10d733c42d73a2e6261c8d6745bf52433a31848dd1b6561c9a382da3"},
+ {file = "pyproj-3.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:1798ff7d65d9057ebb2d017ffe8403268b8452f24d0428b2140018c25c7fa1bc"},
+ {file = "pyproj-3.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d711517a8487ef3245b08dc82f781a906df9abb3b6cb0ce0486f0eeb823ca570"},
+ {file = "pyproj-3.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:788a5dadb532644a64efe0f5f01bf508c821eb7e984f13a677d56002f1e8a67a"},
+ {file = "pyproj-3.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73f7960a97225812f9b1d7aeda5fb83812f38de9441e3476fcc8abb3e2b2f4de"},
+ {file = "pyproj-3.5.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fde5ece4d2436b5a57c8f5f97b49b5de06a856d03959f836c957d3e609f2de7e"},
+ {file = "pyproj-3.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e08db25b61cf024648d55973cc3d1c3f1d0818fabf594d5f5a8e2318103d2aa0"},
+ {file = "pyproj-3.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a87b419a2a352413fbf759ecb66da9da50bd19861c8f26db6a25439125b27b9"},
+ {file = "pyproj-3.5.0.tar.gz", hash = "sha256:9859d1591c1863414d875ae0759e72c2cffc01ab989dc64137fbac572cc81bf6"},
+]
+
+[package.dependencies]
+certifi = "*"
+
+[[package]]
+name = "pyreadline3"
+version = "3.4.1"
+description = "A python implementation of GNU readline."
+optional = true
+python-versions = "*"
+files = [
+ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"},
+ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
+]
+
+[[package]]
+name = "pyspark"
+version = "3.5.0"
+description = "Apache Spark Python API"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyspark-3.5.0.tar.gz", hash = "sha256:d41a9b76bd2aca370a6100d075c029e22ba44c5940927877e9435a3a9c566558"},
+]
+
+[package.dependencies]
+py4j = "0.10.9.7"
+
+[package.extras]
+connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.56.0)", "grpcio-status (>=1.56.0)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+ml = ["numpy (>=1.15)"]
+mllib = ["numpy (>=1.15)"]
+pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+
+[[package]]
+name = "pytest"
+version = "7.4.3"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"},
+ {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=0.12,<2.0"
+tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pytest-asyncio"
+version = "0.20.3"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-asyncio-0.20.3.tar.gz", hash = "sha256:83cbf01169ce3e8eb71c6c278ccb0574d1a7a3bb8eaaf5e50e0ad342afb33b36"},
+ {file = "pytest_asyncio-0.20.3-py3-none-any.whl", hash = "sha256:f129998b209d04fcc65c96fc85c11e5316738358909a8399e93be553d7656442"},
+]
+
+[package.dependencies]
+pytest = ">=6.1.0"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+
+[[package]]
+name = "pytest-cov"
+version = "4.1.0"
+description = "Pytest plugin for measuring coverage."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
+ {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+]
+
+[package.dependencies]
+coverage = {version = ">=5.2.1", extras = ["toml"]}
+pytest = ">=4.6"
+
+[package.extras]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+
+[[package]]
+name = "pytest-dotenv"
+version = "0.5.2"
+description = "A py.test plugin that parses environment files before running tests"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"},
+ {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"},
+]
+
+[package.dependencies]
+pytest = ">=5.0.0"
+python-dotenv = ">=0.9.1"
+
+[[package]]
+name = "pytest-mock"
+version = "3.12.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"},
+ {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"},
+]
+
+[package.dependencies]
+pytest = ">=5.0"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
+[[package]]
+name = "pytest-socket"
+version = "0.6.0"
+description = "Pytest Plugin to disable socket calls during tests"
+optional = false
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "pytest_socket-0.6.0-py3-none-any.whl", hash = "sha256:cca72f134ff01e0023c402e78d31b32e68da3efdf3493bf7788f8eba86a6824c"},
+ {file = "pytest_socket-0.6.0.tar.gz", hash = "sha256:363c1d67228315d4fc7912f1aabfd570de29d0e3db6217d61db5728adacd7138"},
+]
+
+[package.dependencies]
+pytest = ">=3.6.3"
+
+[[package]]
+name = "pytest-vcr"
+version = "1.0.2"
+description = "Plugin for managing VCR.py cassettes"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-vcr-1.0.2.tar.gz", hash = "sha256:23ee51b75abbcc43d926272773aae4f39f93aceb75ed56852d0bf618f92e1896"},
+ {file = "pytest_vcr-1.0.2-py2.py3-none-any.whl", hash = "sha256:2f316e0539399bea0296e8b8401145c62b6f85e9066af7e57b6151481b0d6d9c"},
+]
+
+[package.dependencies]
+pytest = ">=3.6.0"
+vcrpy = "*"
+
+[[package]]
+name = "pytest-watcher"
+version = "0.2.6"
+description = "Continiously runs pytest on changes in *.py files"
+optional = false
+python-versions = ">=3.7.0,<4.0.0"
+files = [
+ {file = "pytest-watcher-0.2.6.tar.gz", hash = "sha256:351dfb3477366030ff275bfbfc9f29bee35cd07f16a3355b38bf92766886bae4"},
+ {file = "pytest_watcher-0.2.6-py3-none-any.whl", hash = "sha256:0a507159d051c9461790363e0f9b2827c1d82ad2ae8966319598695e485b1dd5"},
+]
+
+[package.dependencies]
+watchdog = ">=2.0.0"
+
+[[package]]
+name = "python-dateutil"
+version = "2.8.2"
+description = "Extensions to the standard Python datetime module"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+files = [
+ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
+ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+]
+
+[package.dependencies]
+six = ">=1.5"
+
+[[package]]
+name = "python-dotenv"
+version = "1.0.0"
+description = "Read key-value pairs from a .env file and set them as environment variables"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"},
+ {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"},
+]
+
+[package.extras]
+cli = ["click (>=5.0)"]
+
+[[package]]
+name = "python-json-logger"
+version = "2.0.7"
+description = "A python library adding a json log formatter"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"},
+ {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"},
+]
+
+[[package]]
+name = "pytz"
+version = "2023.3.post1"
+description = "World timezone definitions, modern and historical"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"},
+ {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
+]
+
+[[package]]
+name = "pywin32"
+version = "306"
+description = "Python for Window Extensions"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"},
+ {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"},
+ {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"},
+ {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"},
+ {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"},
+ {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"},
+ {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"},
+ {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"},
+ {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"},
+ {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"},
+ {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"},
+ {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"},
+ {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"},
+ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"},
+]
+
+[[package]]
+name = "pywinpty"
+version = "2.0.12"
+description = "Pseudo terminal support for Windows from Python."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"},
+ {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"},
+ {file = "pywinpty-2.0.12-cp312-none-win_amd64.whl", hash = "sha256:1617b729999eb6713590e17665052b1a6ae0ad76ee31e60b444147c5b6a35dca"},
+ {file = "pywinpty-2.0.12-cp38-none-win_amd64.whl", hash = "sha256:189380469ca143d06e19e19ff3fba0fcefe8b4a8cc942140a6b863aed7eebb2d"},
+ {file = "pywinpty-2.0.12-cp39-none-win_amd64.whl", hash = "sha256:7520575b6546db23e693cbd865db2764097bd6d4ef5dc18c92555904cd62c3d4"},
+ {file = "pywinpty-2.0.12.tar.gz", hash = "sha256:8197de460ae8ebb7f5d1701dfa1b5df45b157bb832e92acba316305e18ca00dd"},
+]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.1"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
+ {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+ {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
+[[package]]
+name = "pyzmq"
+version = "25.1.1"
+description = "Python bindings for 0MQ"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"},
+ {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"},
+ {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"},
+ {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"},
+ {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"},
+ {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"},
+ {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"},
+ {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"},
+ {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"},
+ {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"},
+ {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"},
+ {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"},
+ {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"},
+ {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"},
+ {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"},
+ {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"},
+ {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"},
+ {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"},
+ {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"},
+ {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"},
+ {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"},
+]
+
+[package.dependencies]
+cffi = {version = "*", markers = "implementation_name == \"pypy\""}
+
+[[package]]
+name = "qtconsole"
+version = "5.5.1"
+description = "Jupyter Qt console"
+optional = false
+python-versions = ">= 3.8"
+files = [
+ {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"},
+ {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"},
+]
+
+[package.dependencies]
+ipykernel = ">=4.1"
+jupyter-client = ">=4.1"
+jupyter-core = "*"
+packaging = "*"
+pygments = "*"
+pyzmq = ">=17.1"
+qtpy = ">=2.4.0"
+traitlets = "<5.2.1 || >5.2.1,<5.2.2 || >5.2.2"
+
+[package.extras]
+doc = ["Sphinx (>=1.3)"]
+test = ["flaky", "pytest", "pytest-qt"]
+
+[[package]]
+name = "qtpy"
+version = "2.4.1"
+description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"},
+ {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"},
+]
+
+[package.dependencies]
+packaging = "*"
+
+[package.extras]
+test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"]
+
+[[package]]
+name = "rank-bm25"
+version = "0.2.2"
+description = "Various BM25 algorithms for document ranking"
+optional = true
+python-versions = "*"
+files = [
+ {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"},
+ {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"},
+]
+
+[package.dependencies]
+numpy = "*"
+
+[package.extras]
+dev = ["pytest"]
+
+[[package]]
+name = "rapidfuzz"
+version = "3.5.2"
+description = "rapid fuzzy string matching"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "rapidfuzz-3.5.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1a047d6e58833919d742bbc0dfa66d1de4f79e8562ee195007d3eae96635df39"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:22877c027c492b7dc7e3387a576a33ed5aad891104aa90da2e0844c83c5493ef"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e0f448b0eacbcc416feb634e1232a48d1cbde5e60f269c84e4fb0912f7bbb001"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05146497672f869baf41147d5ec1222788c70e5b8b0cfcd6e95597c75b5b96b"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f2df3968738a38d2a0058b5e721753f5d3d602346a1027b0dde31b0476418f3"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5afc1fcf1830f9bb87d3b490ba03691081b9948a794ea851befd2643069a30c1"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84be69ea65f64fa01e5c4976be9826a5aa949f037508887add42da07420d65d6"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8658c1045766e87e0038323aa38b4a9f49b7f366563271f973c8890a98aa24b5"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:852b3f93c15fce58b8dc668bd54123713bfdbbb0796ba905ea5df99cfd083132"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:12424a06ad9bd0cbf5f7cea1015e78d924a0034a0e75a5a7b39c0703dcd94095"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b4e9ded8e80530bd7205a7a2b01802f934a4695ca9e9fbe1ce9644f5e0697864"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:affb8fe36157c2dc8a7bc45b6a1875eb03e2c49167a1d52789144bdcb7ab3b8c"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1d33a622572d384f4c90b5f7a139328246ab5600141e90032b521c2127bd605"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-win32.whl", hash = "sha256:2cf9f2ed4a97b388cffd48d534452a564c2491f68f4fd5bc140306f774ceb63a"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:6541ffb70097885f7302cd73e2efd77be99841103023c2f9408551f27f45f7a5"},
+ {file = "rapidfuzz-3.5.2-cp310-cp310-win_arm64.whl", hash = "sha256:1dd2542e5103fb8ca46500a979ae14d1609dcba11d2f9fe01e99eec03420e193"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bff7d3127ebc5cd908f3a72f6517f31f5247b84666137556a8fcc5177c560939"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fdfdb3685b631d8efbb6d6d3d86eb631be2b408d9adafcadc11e63e3f9c96dec"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:97b043fe8185ec53bb3ff0e59deb89425c0fc6ece6e118939963aab473505801"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a4a7832737f87583f3863dc62e6f56dd4a9fefc5f04a7bdcb4c433a0f36bb1b"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d876dba9a11fcf60dcf1562c5a84ef559db14c2ceb41e1ad2d93cd1dc085889"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa4c0612893716bbb6595066ca9ecb517c982355abe39ba9d1f4ab834ace91ad"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:120316824333e376b88b284724cfd394c6ccfcb9818519eab5d58a502e5533f0"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9cdbe8e80cc186d55f748a34393533a052d855357d5398a1ccb71a5021b58e8d"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1062425c8358a547ae5ebad148f2e0f02417716a571b803b0c68e4d552e99d32"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:66be181965aff13301dd5f9b94b646ce39d99c7fe2fd5de1656f4ca7fafcb38c"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:53df7aea3cf301633cfa2b4b2c2d2441a87dfc878ef810e5b4eddcd3e68723ad"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:76639dca5eb0afc6424ac5f42d43d3bd342ac710e06f38a8c877d5b96de09589"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:27689361c747b5f7b8a26056bc60979875323f1c3dcaaa9e2fec88f03b20a365"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-win32.whl", hash = "sha256:99c9fc5265566fb94731dc6826f43c5109e797078264e6389a36d47814473692"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:666928ee735562a909d81bd2f63207b3214afd4ca41f790ab3025d066975c814"},
+ {file = "rapidfuzz-3.5.2-cp311-cp311-win_arm64.whl", hash = "sha256:d55de67c48f06b7772541e8d4c062a2679205799ce904236e2836cb04c106442"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:04e1e02b182283c43c866e215317735e91d22f5d34e65400121c04d5ed7ed859"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:365e544aba3ac13acf1a62cb2e5909ad2ba078d0bfc7d69b1f801dfd673b9782"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b61f77d834f94b0099fa9ed35c189b7829759d4e9c2743697a130dd7ba62259f"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43fb368998b9703fa8c63db292a8ab9e988bf6da0c8a635754be8e69da1e7c1d"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:25510b5d142c47786dbd27cfd9da7cae5bdea28d458379377a3644d8460a3404"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf3093443751e5a419834162af358d1e31dec75f84747a91dbbc47b2c04fc085"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fbaf546f15a924613f89d609ff66b85b4f4c2307ac14d93b80fe1025b713138"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32d580df0e130ed85400ff77e1c32d965e9bc7be29ac4072ab637f57e26d29fb"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:358a0fbc49343de20fee8ebdb33c7fa8f55a9ff93ff42d1ffe097d2caa248f1b"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fb379ac0ddfc86c5542a225d194f76ed468b071b6f79ff57c4b72e635605ad7d"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7fb21e182dc6d83617e88dea002963d5cf99cf5eabbdbf04094f503d8fe8d723"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:c04f9f1310ce414ab00bdcbf26d0906755094bfc59402cb66a7722c6f06d70b2"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f6da61cc38c1a95efc5edcedf258759e6dbab73191651a28c5719587f32a56ad"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-win32.whl", hash = "sha256:f823fd1977071486739f484e27092765d693da6beedaceece54edce1dfeec9b2"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:a8162d81486de85ab1606e48e076431b66d44cf431b2b678e9cae458832e7147"},
+ {file = "rapidfuzz-3.5.2-cp312-cp312-win_arm64.whl", hash = "sha256:dfc63fabb7d8da8483ca836bae7e55766fe39c63253571e103c034ba8ea80950"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:df8fae2515a1e4936affccac3e7d506dd904de5ff82bc0b1433b4574a51b9bfb"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dd6384780c2a16097d47588844cd677316a90e0f41ef96ff485b62d58de79dcf"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:467a4d730ae3bade87dba6bd769e837ab97e176968ce20591fe8f7bf819115b1"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54576669c1502b751b534bd76a4aeaaf838ed88b30af5d5c1b7d0a3ca5d4f7b5"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abafeb82f85a651a9d6d642a33dc021606bc459c33e250925b25d6b9e7105a2e"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73e14617a520c0f1bc15eb78c215383477e5ca70922ecaff1d29c63c060e04ca"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7cdf92116e9dfe40da17f921cdbfa0039dde9eb158914fa5f01b1e67a20b19cb"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1962d5ccf8602589dbf8e85246a0ee2b4050d82fade1568fb76f8a4419257704"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:db45028eae2fda7a24759c69ebeb2a7fbcc1a326606556448ed43ee480237a3c"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b685abb8b6d97989f6c69556d7934e0e533aa8822f50b9517ff2da06a1d29f23"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:40139552961018216b8cd88f6df4ecbbe984f907a62a5c823ccd907132c29a14"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:0fef4705459842ef8f79746d6f6a0b5d2b6a61a145d7d8bbe10b2e756ea337c8"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6b2ad5516f7068c7d9cbcda8ac5906c589e99bc427df2e1050282ee2d8bc2d58"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-win32.whl", hash = "sha256:2da3a24c2f7dfca7f26ba04966b848e3bbeb93e54d899908ff88dfe3e1def9dc"},
+ {file = "rapidfuzz-3.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:e3f2be79d4114d01f383096dbee51b57df141cb8b209c19d0cf65f23a24e75ba"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:089a7e96e5032821af5964d8457fcb38877cc321cdd06ad7c5d6e3d852264cb9"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75d8a52bf8d1aa2ac968ae4b21b83b94fc7e5ea3dfbab34811fc60f32df505b2"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2bacce6bbc0362f0789253424269cc742b1f45e982430387db3abe1d0496e371"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5fd627e604ddc02db2ddb9ddc4a91dd92b7a6d6378fcf30bb37b49229072b89"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2e8b369f23f00678f6e673572209a5d3b0832f4991888e3df97af7b8b9decf3"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c29958265e4c2b937269e804b8a160c027ee1c2627d6152655008a8b8083630e"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00be97f9219355945c46f37ac9fa447046e6f7930f7c901e5d881120d1695458"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0d8d57e0f556ef38c24fee71bfe8d0db29c678bff2acd1819fc1b74f331c2"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:de89585268ed8ee44e80126814cae63ff6b00d08416481f31b784570ef07ec59"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:908ff2de9c442b379143d1da3c886c63119d4eba22986806e2533cee603fe64b"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:54f0061028723c026020f5bb20649c22bc8a0d9f5363c283bdc5901d4d3bff01"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b581107ec0c610cdea48b25f52030770be390db4a9a73ca58b8d70fa8a5ec32e"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1d5a686ea258931aaa38019204bdc670bbe14b389a230b1363d84d6cf4b9dc38"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-win32.whl", hash = "sha256:97f811ca7709c6ee8c0b55830f63b3d87086f4abbcbb189b4067e1cd7014db7b"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:58ee34350f8c292dd24a050186c0e18301d80da904ef572cf5fda7be6a954929"},
+ {file = "rapidfuzz-3.5.2-cp39-cp39-win_arm64.whl", hash = "sha256:c5075ce7b9286624cafcf36720ef1cfb2946d75430b87cb4d1f006e82cd71244"},
+ {file = "rapidfuzz-3.5.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:af5221e4f7800db3e84c46b79dba4112e3b3cc2678f808bdff4fcd2487073846"},
+ {file = "rapidfuzz-3.5.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8501d7875b176930e6ed9dbc1bc35adb37ef312f6106bd6bb5c204adb90160ac"},
+ {file = "rapidfuzz-3.5.2-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e414e1ca40386deda4291aa2d45062fea0fbaa14f95015738f8bb75c4d27f862"},
+ {file = "rapidfuzz-3.5.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2059cd73b7ea779a9307d7a78ed743f0e3d33b88ccdcd84569abd2953cd859f"},
+ {file = "rapidfuzz-3.5.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:58e3e21f6f13a7cca265cce492bc797425bd4cb2025fdd161a9e86a824ad65ce"},
+ {file = "rapidfuzz-3.5.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b847a49377e64e92e11ef3d0a793de75451526c83af015bdafdd5d04de8a058a"},
+ {file = "rapidfuzz-3.5.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a42c7a8c62b29c4810e39da22b42524295fcb793f41c395c2cb07c126b729e83"},
+ {file = "rapidfuzz-3.5.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51b5166be86e09e011e92d9862b1fe64c4c7b9385f443fb535024e646d890460"},
+ {file = "rapidfuzz-3.5.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f808dcb0088a7a496cc9895e66a7b8de55ffea0eb9b547c75dfb216dd5f76ed"},
+ {file = "rapidfuzz-3.5.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d4b05a8f4ab7e7344459394094587b033fe259eea3a8720035e8ba30e79ab39b"},
+ {file = "rapidfuzz-3.5.2.tar.gz", hash = "sha256:9e9b395743e12c36a3167a3a9fd1b4e11d92fb0aa21ec98017ee6df639ed385e"},
+]
+
+[package.extras]
+full = ["numpy"]
+
+[[package]]
+name = "rapidocr-onnxruntime"
+version = "1.3.7"
+description = "A cross platform OCR Library based on OnnxRuntime."
+optional = true
+python-versions = ">=3.6,<3.12"
+files = [
+ {file = "rapidocr_onnxruntime-1.3.7-py3-none-any.whl", hash = "sha256:9d061786f6255c57a98f04a2f7624eacabc1d0dede2a69707c99a6dd9024e6fa"},
+]
+
+[package.dependencies]
+numpy = ">=1.19.5"
+onnxruntime = ">=1.7.0"
+opencv-python = ">=4.5.1.48"
+Pillow = "*"
+pyclipper = ">=1.2.0"
+PyYAML = "*"
+Shapely = ">=1.7.1"
+six = ">=1.15.0"
+
+[[package]]
+name = "referencing"
+version = "0.31.1"
+description = "JSON Referencing + Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "referencing-0.31.1-py3-none-any.whl", hash = "sha256:c19c4d006f1757e3dd75c4f784d38f8698d87b649c54f9ace14e5e8c9667c01d"},
+ {file = "referencing-0.31.1.tar.gz", hash = "sha256:81a1471c68c9d5e3831c30ad1dd9815c45b558e596653db751a2bfdd17b3b9ec"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+rpds-py = ">=0.7.0"
+
+[[package]]
+name = "regex"
+version = "2023.10.3"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "regex-2023.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4c34d4f73ea738223a094d8e0ffd6d2c1a1b4c175da34d6b0de3d8d69bee6bcc"},
+ {file = "regex-2023.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8f4e49fc3ce020f65411432183e6775f24e02dff617281094ba6ab079ef0915"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cd1bccf99d3ef1ab6ba835308ad85be040e6a11b0977ef7ea8c8005f01a3c29"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81dce2ddc9f6e8f543d94b05d56e70d03a0774d32f6cca53e978dc01e4fc75b8"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c6b4d23c04831e3ab61717a707a5d763b300213db49ca680edf8bf13ab5d91b"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c15ad0aee158a15e17e0495e1e18741573d04eb6da06d8b84af726cfc1ed02ee"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6239d4e2e0b52c8bd38c51b760cd870069f0bdf99700a62cd509d7a031749a55"},
+ {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4a8bf76e3182797c6b1afa5b822d1d5802ff30284abe4599e1247be4fd6b03be"},
+ {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d9c727bbcf0065cbb20f39d2b4f932f8fa1631c3e01fcedc979bd4f51fe051c5"},
+ {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3ccf2716add72f80714b9a63899b67fa711b654be3fcdd34fa391d2d274ce767"},
+ {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:107ac60d1bfdc3edb53be75e2a52aff7481b92817cfdddd9b4519ccf0e54a6ff"},
+ {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:00ba3c9818e33f1fa974693fb55d24cdc8ebafcb2e4207680669d8f8d7cca79a"},
+ {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0a47efb1dbef13af9c9a54a94a0b814902e547b7f21acb29434504d18f36e3a"},
+ {file = "regex-2023.10.3-cp310-cp310-win32.whl", hash = "sha256:36362386b813fa6c9146da6149a001b7bd063dabc4d49522a1f7aa65b725c7ec"},
+ {file = "regex-2023.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:c65a3b5330b54103e7d21cac3f6bf3900d46f6d50138d73343d9e5b2900b2353"},
+ {file = "regex-2023.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:90a79bce019c442604662d17bf69df99090e24cdc6ad95b18b6725c2988a490e"},
+ {file = "regex-2023.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7964c2183c3e6cce3f497e3a9f49d182e969f2dc3aeeadfa18945ff7bdd7051"},
+ {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ef80829117a8061f974b2fda8ec799717242353bff55f8a29411794d635d964"},
+ {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5addc9d0209a9afca5fc070f93b726bf7003bd63a427f65ef797a931782e7edc"},
+ {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c148bec483cc4b421562b4bcedb8e28a3b84fcc8f0aa4418e10898f3c2c0eb9b"},
+ {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d1f21af4c1539051049796a0f50aa342f9a27cde57318f2fc41ed50b0dbc4ac"},
+ {file = "regex-2023.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b9ac09853b2a3e0d0082104036579809679e7715671cfbf89d83c1cb2a30f58"},
+ {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ebedc192abbc7fd13c5ee800e83a6df252bec691eb2c4bedc9f8b2e2903f5e2a"},
+ {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d8a993c0a0ffd5f2d3bda23d0cd75e7086736f8f8268de8a82fbc4bd0ac6791e"},
+ {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:be6b7b8d42d3090b6c80793524fa66c57ad7ee3fe9722b258aec6d0672543fd0"},
+ {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4023e2efc35a30e66e938de5aef42b520c20e7eda7bb5fb12c35e5d09a4c43f6"},
+ {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0d47840dc05e0ba04fe2e26f15126de7c755496d5a8aae4a08bda4dd8d646c54"},
+ {file = "regex-2023.10.3-cp311-cp311-win32.whl", hash = "sha256:9145f092b5d1977ec8c0ab46e7b3381b2fd069957b9862a43bd383e5c01d18c2"},
+ {file = "regex-2023.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:b6104f9a46bd8743e4f738afef69b153c4b8b592d35ae46db07fc28ae3d5fb7c"},
+ {file = "regex-2023.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff507ae210371d4b1fe316d03433ac099f184d570a1a611e541923f78f05037"},
+ {file = "regex-2023.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be5e22bbb67924dea15039c3282fa4cc6cdfbe0cbbd1c0515f9223186fc2ec5f"},
+ {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a992f702c9be9c72fa46f01ca6e18d131906a7180950958f766c2aa294d4b41"},
+ {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7434a61b158be563c1362d9071358f8ab91b8d928728cd2882af060481244c9e"},
+ {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2169b2dcabf4e608416f7f9468737583ce5f0a6e8677c4efbf795ce81109d7c"},
+ {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9e908ef5889cda4de038892b9accc36d33d72fb3e12c747e2799a0e806ec841"},
+ {file = "regex-2023.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12bd4bc2c632742c7ce20db48e0d99afdc05e03f0b4c1af90542e05b809a03d9"},
+ {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bc72c231f5449d86d6c7d9cc7cd819b6eb30134bb770b8cfdc0765e48ef9c420"},
+ {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bce8814b076f0ce5766dc87d5a056b0e9437b8e0cd351b9a6c4e1134a7dfbda9"},
+ {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ba7cd6dc4d585ea544c1412019921570ebd8a597fabf475acc4528210d7c4a6f"},
+ {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b0c7d2f698e83f15228ba41c135501cfe7d5740181d5903e250e47f617eb4292"},
+ {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5a8f91c64f390ecee09ff793319f30a0f32492e99f5dc1c72bc361f23ccd0a9a"},
+ {file = "regex-2023.10.3-cp312-cp312-win32.whl", hash = "sha256:ad08a69728ff3c79866d729b095872afe1e0557251da4abb2c5faff15a91d19a"},
+ {file = "regex-2023.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:39cdf8d141d6d44e8d5a12a8569d5a227f645c87df4f92179bd06e2e2705e76b"},
+ {file = "regex-2023.10.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4a3ee019a9befe84fa3e917a2dd378807e423d013377a884c1970a3c2792d293"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76066d7ff61ba6bf3cb5efe2428fc82aac91802844c022d849a1f0f53820502d"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe50b61bab1b1ec260fa7cd91106fa9fece57e6beba05630afe27c71259c59b"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fd88f373cb71e6b59b7fa597e47e518282455c2734fd4306a05ca219a1991b0"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ab05a182c7937fb374f7e946f04fb23a0c0699c0450e9fb02ef567412d2fa3"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dac37cf08fcf2094159922edc7a2784cfcc5c70f8354469f79ed085f0328ebdf"},
+ {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e54ddd0bb8fb626aa1f9ba7b36629564544954fff9669b15da3610c22b9a0991"},
+ {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3367007ad1951fde612bf65b0dffc8fd681a4ab98ac86957d16491400d661302"},
+ {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:16f8740eb6dbacc7113e3097b0a36065a02e37b47c936b551805d40340fb9971"},
+ {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:f4f2ca6df64cbdd27f27b34f35adb640b5d2d77264228554e68deda54456eb11"},
+ {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:39807cbcbe406efca2a233884e169d056c35aa7e9f343d4e78665246a332f597"},
+ {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7eece6fbd3eae4a92d7c748ae825cbc1ee41a89bb1c3db05b5578ed3cfcfd7cb"},
+ {file = "regex-2023.10.3-cp37-cp37m-win32.whl", hash = "sha256:ce615c92d90df8373d9e13acddd154152645c0dc060871abf6bd43809673d20a"},
+ {file = "regex-2023.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f649fa32fe734c4abdfd4edbb8381c74abf5f34bc0b3271ce687b23729299ed"},
+ {file = "regex-2023.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9b98b7681a9437262947f41c7fac567c7e1f6eddd94b0483596d320092004533"},
+ {file = "regex-2023.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:91dc1d531f80c862441d7b66c4505cd6ea9d312f01fb2f4654f40c6fdf5cc37a"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82fcc1f1cc3ff1ab8a57ba619b149b907072e750815c5ba63e7aa2e1163384a4"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7979b834ec7a33aafae34a90aad9f914c41fd6eaa8474e66953f3f6f7cbd4368"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef71561f82a89af6cfcbee47f0fabfdb6e63788a9258e913955d89fdd96902ab"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd829712de97753367153ed84f2de752b86cd1f7a88b55a3a775eb52eafe8a94"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00e871d83a45eee2f8688d7e6849609c2ca2a04a6d48fba3dff4deef35d14f07"},
+ {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:706e7b739fdd17cb89e1fbf712d9dc21311fc2333f6d435eac2d4ee81985098c"},
+ {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cc3f1c053b73f20c7ad88b0d1d23be7e7b3901229ce89f5000a8399746a6e039"},
+ {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f85739e80d13644b981a88f529d79c5bdf646b460ba190bffcaf6d57b2a9863"},
+ {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:741ba2f511cc9626b7561a440f87d658aabb3d6b744a86a3c025f866b4d19e7f"},
+ {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e77c90ab5997e85901da85131fd36acd0ed2221368199b65f0d11bca44549711"},
+ {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:979c24cbefaf2420c4e377ecd1f165ea08cc3d1fbb44bdc51bccbbf7c66a2cb4"},
+ {file = "regex-2023.10.3-cp38-cp38-win32.whl", hash = "sha256:58837f9d221744d4c92d2cf7201c6acd19623b50c643b56992cbd2b745485d3d"},
+ {file = "regex-2023.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:c55853684fe08d4897c37dfc5faeff70607a5f1806c8be148f1695be4a63414b"},
+ {file = "regex-2023.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2c54e23836650bdf2c18222c87f6f840d4943944146ca479858404fedeb9f9af"},
+ {file = "regex-2023.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69c0771ca5653c7d4b65203cbfc5e66db9375f1078689459fe196fe08b7b4930"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ac965a998e1388e6ff2e9781f499ad1eaa41e962a40d11c7823c9952c77123e"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c0e8fae5b27caa34177bdfa5a960c46ff2f78ee2d45c6db15ae3f64ecadde14"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c56c3d47da04f921b73ff9415fbaa939f684d47293f071aa9cbb13c94afc17d"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef1e014eed78ab650bef9a6a9cbe50b052c0aebe553fb2881e0453717573f52"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d29338556a59423d9ff7b6eb0cb89ead2b0875e08fe522f3e068b955c3e7b59b"},
+ {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9c6d0ced3c06d0f183b73d3c5920727268d2201aa0fe6d55c60d68c792ff3588"},
+ {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:994645a46c6a740ee8ce8df7911d4aee458d9b1bc5639bc968226763d07f00fa"},
+ {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:66e2fe786ef28da2b28e222c89502b2af984858091675044d93cb50e6f46d7af"},
+ {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:11175910f62b2b8c055f2b089e0fedd694fe2be3941b3e2633653bc51064c528"},
+ {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:06e9abc0e4c9ab4779c74ad99c3fc10d3967d03114449acc2c2762ad4472b8ca"},
+ {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fb02e4257376ae25c6dd95a5aec377f9b18c09be6ebdefa7ad209b9137b73d48"},
+ {file = "regex-2023.10.3-cp39-cp39-win32.whl", hash = "sha256:3b2c3502603fab52d7619b882c25a6850b766ebd1b18de3df23b2f939360e1bd"},
+ {file = "regex-2023.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:adbccd17dcaff65704c856bd29951c58a1bd4b2b0f8ad6b826dbd543fe740988"},
+ {file = "regex-2023.10.3.tar.gz", hash = "sha256:3fef4f844d2290ee0ba57addcec17eec9e3df73f10a2748485dfd6a3a188cc0f"},
+]
+
+[[package]]
+name = "requests"
+version = "2.31.0"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
+ {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "requests-file"
+version = "1.5.1"
+description = "File transport adapter for Requests"
+optional = true
+python-versions = "*"
+files = [
+ {file = "requests-file-1.5.1.tar.gz", hash = "sha256:07d74208d3389d01c38ab89ef403af0cfec63957d53a0081d8eca738d0247d8e"},
+ {file = "requests_file-1.5.1-py2.py3-none-any.whl", hash = "sha256:dfe5dae75c12481f68ba353183c53a65e6044c923e64c24b2209f6c7570ca953"},
+]
+
+[package.dependencies]
+requests = ">=1.0.0"
+six = "*"
+
+[[package]]
+name = "requests-mock"
+version = "1.11.0"
+description = "Mock out responses from the requests package"
+optional = false
+python-versions = "*"
+files = [
+ {file = "requests-mock-1.11.0.tar.gz", hash = "sha256:ef10b572b489a5f28e09b708697208c4a3b2b89ef80a9f01584340ea357ec3c4"},
+ {file = "requests_mock-1.11.0-py2.py3-none-any.whl", hash = "sha256:f7fae383f228633f6bececebdab236c478ace2284d6292c6e7e2867b9ab74d15"},
+]
+
+[package.dependencies]
+requests = ">=2.3,<3"
+six = "*"
+
+[package.extras]
+fixture = ["fixtures"]
+test = ["fixtures", "mock", "purl", "pytest", "requests-futures", "sphinx", "testtools"]
+
+[[package]]
+name = "requests-oauthlib"
+version = "1.3.1"
+description = "OAuthlib authentication support for Requests."
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
+ {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
+]
+
+[package.dependencies]
+oauthlib = ">=3.0.0"
+requests = ">=2.0.0"
+
+[package.extras]
+rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
+
+[[package]]
+name = "requests-toolbelt"
+version = "1.0.0"
+description = "A utility belt for advanced users of python-requests"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
+ {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
+]
+
+[package.dependencies]
+requests = ">=2.0.1,<3.0.0"
+
+[[package]]
+name = "responses"
+version = "0.22.0"
+description = "A utility library for mocking out the `requests` Python library."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "responses-0.22.0-py3-none-any.whl", hash = "sha256:dcf294d204d14c436fddcc74caefdbc5764795a40ff4e6a7740ed8ddbf3294be"},
+ {file = "responses-0.22.0.tar.gz", hash = "sha256:396acb2a13d25297789a5866b4881cf4e46ffd49cc26c43ab1117f40b973102e"},
+]
+
+[package.dependencies]
+requests = ">=2.22.0,<3.0"
+toml = "*"
+types-toml = "*"
+urllib3 = ">=1.25.10"
+
+[package.extras]
+tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "types-requests"]
+
+[[package]]
+name = "rfc3339-validator"
+version = "0.1.4"
+description = "A pure python RFC3339 validator"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa"},
+ {file = "rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b"},
+]
+
+[package.dependencies]
+six = "*"
+
+[[package]]
+name = "rfc3986-validator"
+version = "0.1.1"
+description = "Pure python rfc3986 validator"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9"},
+ {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"},
+]
+
+[[package]]
+name = "rich"
+version = "13.7.0"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = true
+python-versions = ">=3.7.0"
+files = [
+ {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"},
+ {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
+[[package]]
+name = "rpds-py"
+version = "0.13.2"
+description = "Python bindings to Rust's persistent data structures (rpds)"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "rpds_py-0.13.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1ceebd0ae4f3e9b2b6b553b51971921853ae4eebf3f54086be0565d59291e53d"},
+ {file = "rpds_py-0.13.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46e1ed994a0920f350a4547a38471217eb86f57377e9314fbaaa329b71b7dfe3"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee353bb51f648924926ed05e0122b6a0b1ae709396a80eb583449d5d477fcdf7"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:530190eb0cd778363bbb7596612ded0bb9fef662daa98e9d92a0419ab27ae914"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d311e44dd16d2434d5506d57ef4d7036544fc3c25c14b6992ef41f541b10fb"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e72f750048b32d39e87fc85c225c50b2a6715034848dbb196bf3348aa761fa1"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db09b98c7540df69d4b47218da3fbd7cb466db0fb932e971c321f1c76f155266"},
+ {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2ac26f50736324beb0282c819668328d53fc38543fa61eeea2c32ea8ea6eab8d"},
+ {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12ecf89bd54734c3c2c79898ae2021dca42750c7bcfb67f8fb3315453738ac8f"},
+ {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a44c8440183b43167fd1a0819e8356692bf5db1ad14ce140dbd40a1485f2dea"},
+ {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcef4f2d3dc603150421de85c916da19471f24d838c3c62a4f04c1eb511642c1"},
+ {file = "rpds_py-0.13.2-cp310-none-win32.whl", hash = "sha256:ee6faebb265e28920a6f23a7d4c362414b3f4bb30607141d718b991669e49ddc"},
+ {file = "rpds_py-0.13.2-cp310-none-win_amd64.whl", hash = "sha256:ac96d67b37f28e4b6ecf507c3405f52a40658c0a806dffde624a8fcb0314d5fd"},
+ {file = "rpds_py-0.13.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:b5f6328e8e2ae8238fc767703ab7b95785521c42bb2b8790984e3477d7fa71ad"},
+ {file = "rpds_py-0.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:729408136ef8d45a28ee9a7411917c9e3459cf266c7e23c2f7d4bb8ef9e0da42"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65cfed9c807c27dee76407e8bb29e6f4e391e436774bcc769a037ff25ad8646e"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aefbdc934115d2f9278f153952003ac52cd2650e7313750390b334518c589568"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48db29bd47814671afdd76c7652aefacc25cf96aad6daefa82d738ee87461e2"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c55d7f2d817183d43220738270efd3ce4e7a7b7cbdaefa6d551ed3d6ed89190"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6aadae3042f8e6db3376d9e91f194c606c9a45273c170621d46128f35aef7cd0"},
+ {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5feae2f9aa7270e2c071f488fab256d768e88e01b958f123a690f1cc3061a09c"},
+ {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51967a67ea0d7b9b5cd86036878e2d82c0b6183616961c26d825b8c994d4f2c8"},
+ {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d0c10d803549427f427085ed7aebc39832f6e818a011dcd8785e9c6a1ba9b3e"},
+ {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:603d5868f7419081d616dab7ac3cfa285296735e7350f7b1e4f548f6f953ee7d"},
+ {file = "rpds_py-0.13.2-cp311-none-win32.whl", hash = "sha256:b8996ffb60c69f677245f5abdbcc623e9442bcc91ed81b6cd6187129ad1fa3e7"},
+ {file = "rpds_py-0.13.2-cp311-none-win_amd64.whl", hash = "sha256:5379e49d7e80dca9811b36894493d1c1ecb4c57de05c36f5d0dd09982af20211"},
+ {file = "rpds_py-0.13.2-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8a776a29b77fe0cc28fedfd87277b0d0f7aa930174b7e504d764e0b43a05f381"},
+ {file = "rpds_py-0.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1472956c5bcc49fb0252b965239bffe801acc9394f8b7c1014ae9258e4572b"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f252dfb4852a527987a9156cbcae3022a30f86c9d26f4f17b8c967d7580d65d2"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0d320e70b6b2300ff6029e234e79fe44e9dbbfc7b98597ba28e054bd6606a57"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ade2ccb937060c299ab0dfb2dea3d2ddf7e098ed63ee3d651ebfc2c8d1e8632a"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9d121be0217787a7d59a5c6195b0842d3f701007333426e5154bf72346aa658"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fa6bd071ec6d90f6e7baa66ae25820d57a8ab1b0a3c6d3edf1834d4b26fafa2"},
+ {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c918621ee0a3d1fe61c313f2489464f2ae3d13633e60f520a8002a5e910982ee"},
+ {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:25b28b3d33ec0a78e944aaaed7e5e2a94ac811bcd68b557ca48a0c30f87497d2"},
+ {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:31e220a040b89a01505128c2f8a59ee74732f666439a03e65ccbf3824cdddae7"},
+ {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:15253fff410873ebf3cfba1cc686a37711efcd9b8cb30ea21bb14a973e393f60"},
+ {file = "rpds_py-0.13.2-cp312-none-win32.whl", hash = "sha256:b981a370f8f41c4024c170b42fbe9e691ae2dbc19d1d99151a69e2c84a0d194d"},
+ {file = "rpds_py-0.13.2-cp312-none-win_amd64.whl", hash = "sha256:4c4e314d36d4f31236a545696a480aa04ea170a0b021e9a59ab1ed94d4c3ef27"},
+ {file = "rpds_py-0.13.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:80e5acb81cb49fd9f2d5c08f8b74ffff14ee73b10ca88297ab4619e946bcb1e1"},
+ {file = "rpds_py-0.13.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:efe093acc43e869348f6f2224df7f452eab63a2c60a6c6cd6b50fd35c4e075ba"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c2a61c0e4811012b0ba9f6cdcb4437865df5d29eab5d6018ba13cee1c3064a0"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:751758d9dd04d548ec679224cc00e3591f5ebf1ff159ed0d4aba6a0746352452"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ba8858933f0c1a979781272a5f65646fca8c18c93c99c6ddb5513ad96fa54b1"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bfdfbe6a36bc3059fff845d64c42f2644cf875c65f5005db54f90cdfdf1df815"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0379c1935c44053c98826bc99ac95f3a5355675a297ac9ce0dfad0ce2d50ca"},
+ {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5593855b5b2b73dd8413c3fdfa5d95b99d657658f947ba2c4318591e745d083"},
+ {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2a7bef6977043673750a88da064fd513f89505111014b4e00fbdd13329cd4e9a"},
+ {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:3ab96754d23372009638a402a1ed12a27711598dd49d8316a22597141962fe66"},
+ {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e06cfea0ece444571d24c18ed465bc93afb8c8d8d74422eb7026662f3d3f779b"},
+ {file = "rpds_py-0.13.2-cp38-none-win32.whl", hash = "sha256:5493569f861fb7b05af6d048d00d773c6162415ae521b7010197c98810a14cab"},
+ {file = "rpds_py-0.13.2-cp38-none-win_amd64.whl", hash = "sha256:b07501b720cf060c5856f7b5626e75b8e353b5f98b9b354a21eb4bfa47e421b1"},
+ {file = "rpds_py-0.13.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:881df98f0a8404d32b6de0fd33e91c1b90ed1516a80d4d6dc69d414b8850474c"},
+ {file = "rpds_py-0.13.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d79c159adea0f1f4617f54aa156568ac69968f9ef4d1e5fefffc0a180830308e"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38d4f822ee2f338febcc85aaa2547eb5ba31ba6ff68d10b8ec988929d23bb6b4"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5d75d6d220d55cdced2f32cc22f599475dbe881229aeddba6c79c2e9df35a2b3"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d97e9ae94fb96df1ee3cb09ca376c34e8a122f36927230f4c8a97f469994bff"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67a429520e97621a763cf9b3ba27574779c4e96e49a27ff8a1aa99ee70beb28a"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:188435794405c7f0573311747c85a96b63c954a5f2111b1df8018979eca0f2f0"},
+ {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:38f9bf2ad754b4a45b8210a6c732fe876b8a14e14d5992a8c4b7c1ef78740f53"},
+ {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a6ba2cb7d676e9415b9e9ac7e2aae401dc1b1e666943d1f7bc66223d3d73467b"},
+ {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:eaffbd8814bb1b5dc3ea156a4c5928081ba50419f9175f4fc95269e040eff8f0"},
+ {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5a4c1058cdae6237d97af272b326e5f78ee7ee3bbffa6b24b09db4d828810468"},
+ {file = "rpds_py-0.13.2-cp39-none-win32.whl", hash = "sha256:b5267feb19070bef34b8dea27e2b504ebd9d31748e3ecacb3a4101da6fcb255c"},
+ {file = "rpds_py-0.13.2-cp39-none-win_amd64.whl", hash = "sha256:ddf23960cb42b69bce13045d5bc66f18c7d53774c66c13f24cf1b9c144ba3141"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:97163a1ab265a1073a6372eca9f4eeb9f8c6327457a0b22ddfc4a17dcd613e74"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:25ea41635d22b2eb6326f58e608550e55d01df51b8a580ea7e75396bafbb28e9"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d59d4d451ba77f08cb4cd9268dec07be5bc65f73666302dbb5061989b17198"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7c564c58cf8f248fe859a4f0fe501b050663f3d7fbc342172f259124fb59933"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61dbc1e01dc0c5875da2f7ae36d6e918dc1b8d2ce04e871793976594aad8a57a"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdb82eb60d31b0c033a8e8ee9f3fc7dfbaa042211131c29da29aea8531b4f18f"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d204957169f0b3511fb95395a9da7d4490fb361763a9f8b32b345a7fe119cb45"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c45008ca79bad237cbc03c72bc5205e8c6f66403773929b1b50f7d84ef9e4d07"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:79bf58c08f0756adba691d480b5a20e4ad23f33e1ae121584cf3a21717c36dfa"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:e86593bf8637659e6a6ed58854b6c87ec4e9e45ee8a4adfd936831cef55c2d21"},
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:d329896c40d9e1e5c7715c98529e4a188a1f2df51212fd65102b32465612b5dc"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4a5375c5fff13f209527cd886dc75394f040c7d1ecad0a2cb0627f13ebe78a12"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d218e4464d31301e943b65b2c6919318ea6f69703a351961e1baaf60347276"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1f41d32a2ddc5a94df4b829b395916a4b7f103350fa76ba6de625fcb9e773ac"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6bc568b05e02cd612be53900c88aaa55012e744930ba2eeb56279db4c6676eb3"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d94d78418203904730585efa71002286ac4c8ac0689d0eb61e3c465f9e608ff"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bed0252c85e21cf73d2d033643c945b460d6a02fc4a7d644e3b2d6f5f2956c64"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:244e173bb6d8f3b2f0c4d7370a1aa341f35da3e57ffd1798e5b2917b91731fd3"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7f55cd9cf1564b7b03f238e4c017ca4794c05b01a783e9291065cb2858d86ce4"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:f03a1b3a4c03e3e0161642ac5367f08479ab29972ea0ffcd4fa18f729cd2be0a"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:f5f4424cb87a20b016bfdc157ff48757b89d2cc426256961643d443c6c277007"},
+ {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c82bbf7e03748417c3a88c1b0b291288ce3e4887a795a3addaa7a1cfd9e7153e"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:c0095b8aa3e432e32d372e9a7737e65b58d5ed23b9620fea7cb81f17672f1fa1"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4c2d26aa03d877c9730bf005621c92da263523a1e99247590abbbe252ccb7824"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96f2975fb14f39c5fe75203f33dd3010fe37d1c4e33177feef1107b5ced750e3"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4dcc5ee1d0275cb78d443fdebd0241e58772a354a6d518b1d7af1580bbd2c4e8"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61d42d2b08430854485135504f672c14d4fc644dd243a9c17e7c4e0faf5ed07e"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d3a61e928feddc458a55110f42f626a2a20bea942ccedb6fb4cee70b4830ed41"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7de12b69d95072394998c622cfd7e8cea8f560db5fca6a62a148f902a1029f8b"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87a90f5545fd61f6964e65eebde4dc3fa8660bb7d87adb01d4cf17e0a2b484ad"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:9c95a1a290f9acf7a8f2ebbdd183e99215d491beea52d61aa2a7a7d2c618ddc6"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:35f53c76a712e323c779ca39b9a81b13f219a8e3bc15f106ed1e1462d56fcfe9"},
+ {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:96fb0899bb2ab353f42e5374c8f0789f54e0a94ef2f02b9ac7149c56622eaf31"},
+ {file = "rpds_py-0.13.2.tar.gz", hash = "sha256:f8eae66a1304de7368932b42d801c67969fd090ddb1a7a24f27b435ed4bed68f"},
+]
+
+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = false
+python-versions = ">=3.6,<4"
+files = [
+ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+ {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
+[[package]]
+name = "rspace-client"
+version = "2.5.0"
+description = "A client for calling RSpace ELN and Inventory APIs"
+optional = true
+python-versions = ">=3.7.11,<4.0.0"
+files = [
+ {file = "rspace-client-2.5.0.tar.gz", hash = "sha256:101abc83d094051d2babcaa133fa1a47221b3d5953d72eef3c331ef7084071a1"},
+ {file = "rspace_client-2.5.0-py3-none-any.whl", hash = "sha256:b1072df88dfa8f068f3137584d20cf135493b0521a9809c2f6ddec6b378a9cc3"},
+]
+
+[package.dependencies]
+beautifulsoup4 = ">=4.9.3,<5.0.0"
+requests = ">=2.25.1,<3.0.0"
+
+[[package]]
+name = "ruff"
+version = "0.1.7"
+description = "An extremely fast Python linter and code formatter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7f80496854fdc65b6659c271d2c26e90d4d401e6a4a31908e7e334fab4645aac"},
+ {file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:1ea109bdb23c2a4413f397ebd8ac32cb498bee234d4191ae1a310af760e5d287"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b0c2de9dd9daf5e07624c24add25c3a490dbf74b0e9bca4145c632457b3b42a"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:69a4bed13bc1d5dabf3902522b5a2aadfebe28226c6269694283c3b0cecb45fd"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de02ca331f2143195a712983a57137c5ec0f10acc4aa81f7c1f86519e52b92a1"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45b38c3f8788a65e6a2cab02e0f7adfa88872696839d9882c13b7e2f35d64c5f"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c64cb67b2025b1ac6d58e5ffca8f7b3f7fd921f35e78198411237e4f0db8e73"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9dcc6bb2f4df59cb5b4b40ff14be7d57012179d69c6565c1da0d1f013d29951b"},
+ {file = "ruff-0.1.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df2bb4bb6bbe921f6b4f5b6fdd8d8468c940731cb9406f274ae8c5ed7a78c478"},
+ {file = "ruff-0.1.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:276a89bcb149b3d8c1b11d91aa81898fe698900ed553a08129b38d9d6570e717"},
+ {file = "ruff-0.1.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:90c958fe950735041f1c80d21b42184f1072cc3975d05e736e8d66fc377119ea"},
+ {file = "ruff-0.1.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b05e3b123f93bb4146a761b7a7d57af8cb7384ccb2502d29d736eaade0db519"},
+ {file = "ruff-0.1.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:290ecab680dce94affebefe0bbca2322a6277e83d4f29234627e0f8f6b4fa9ce"},
+ {file = "ruff-0.1.7-py3-none-win32.whl", hash = "sha256:416dfd0bd45d1a2baa3b1b07b1b9758e7d993c256d3e51dc6e03a5e7901c7d80"},
+ {file = "ruff-0.1.7-py3-none-win_amd64.whl", hash = "sha256:4af95fd1d3b001fc41325064336db36e3d27d2004cdb6d21fd617d45a172dd96"},
+ {file = "ruff-0.1.7-py3-none-win_arm64.whl", hash = "sha256:0683b7bfbb95e6df3c7c04fe9d78f631f8e8ba4868dfc932d43d690698057e2e"},
+ {file = "ruff-0.1.7.tar.gz", hash = "sha256:dffd699d07abf54833e5f6cc50b85a6ff043715da8788c4a79bcd4ab4734d306"},
+]
+
+[[package]]
+name = "s3transfer"
+version = "0.8.2"
+description = "An Amazon S3 Transfer Manager"
+optional = false
+python-versions = ">= 3.7"
+files = [
+ {file = "s3transfer-0.8.2-py3-none-any.whl", hash = "sha256:c9e56cbe88b28d8e197cf841f1f0c130f246595e77ae5b5a05b69fe7cb83de76"},
+ {file = "s3transfer-0.8.2.tar.gz", hash = "sha256:368ac6876a9e9ed91f6bc86581e319be08188dc60d50e0d56308ed5765446283"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.2,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
+
+[[package]]
+name = "scikit-learn"
+version = "1.3.2"
+description = "A set of python modules for machine learning and data mining"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"},
+ {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"},
+ {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"},
+ {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"},
+ {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"},
+ {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"},
+ {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"},
+ {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"},
+ {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"},
+ {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"},
+ {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"},
+ {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"},
+ {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"},
+ {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"},
+ {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"},
+ {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"},
+ {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"},
+ {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"},
+ {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"},
+ {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"},
+ {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"},
+ {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"},
+ {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"},
+ {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"},
+ {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"},
+ {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"},
+]
+
+[package.dependencies]
+joblib = ">=1.1.1"
+numpy = ">=1.17.3,<2.0"
+scipy = ">=1.5.0"
+threadpoolctl = ">=2.0.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"]
+examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"]
+tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.9.3"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"},
+ {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"},
+ {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b"},
+ {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9"},
+ {file = "scipy-1.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523"},
+ {file = "scipy-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096"},
+ {file = "scipy-1.9.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c"},
+ {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab"},
+ {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb"},
+ {file = "scipy-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31"},
+ {file = "scipy-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840"},
+ {file = "scipy-1.9.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5"},
+ {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108"},
+ {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc"},
+ {file = "scipy-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e"},
+ {file = "scipy-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c"},
+ {file = "scipy-1.9.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95"},
+ {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e"},
+ {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0"},
+ {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"},
+ {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"},
+]
+
+[package.dependencies]
+numpy = ">=1.18.5,<1.26.0"
+
+[package.extras]
+dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"]
+doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"]
+test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[[package]]
+name = "send2trash"
+version = "1.8.2"
+description = "Send file to trash natively under Mac OS X, Windows and Linux"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+files = [
+ {file = "Send2Trash-1.8.2-py3-none-any.whl", hash = "sha256:a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679"},
+ {file = "Send2Trash-1.8.2.tar.gz", hash = "sha256:c132d59fa44b9ca2b1699af5c86f57ce9f4c5eb56629d5d55fbb7a35f84e2312"},
+]
+
+[package.extras]
+nativelib = ["pyobjc-framework-Cocoa", "pywin32"]
+objc = ["pyobjc-framework-Cocoa"]
+win32 = ["pywin32"]
+
+[[package]]
+name = "setuptools"
+version = "67.8.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "setuptools-67.8.0-py3-none-any.whl", hash = "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f"},
+ {file = "setuptools-67.8.0.tar.gz", hash = "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+
+[[package]]
+name = "sgmllib3k"
+version = "1.0.0"
+description = "Py3k port of sgmllib."
+optional = true
+python-versions = "*"
+files = [
+ {file = "sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9"},
+]
+
+[[package]]
+name = "shapely"
+version = "2.0.2"
+description = "Manipulation and analysis of geometric objects"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "shapely-2.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6ca8cffbe84ddde8f52b297b53f8e0687bd31141abb2c373fd8a9f032df415d6"},
+ {file = "shapely-2.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:baa14fc27771e180c06b499a0a7ba697c7988c7b2b6cba9a929a19a4d2762de3"},
+ {file = "shapely-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:36480e32c434d168cdf2f5e9862c84aaf4d714a43a8465ae3ce8ff327f0affb7"},
+ {file = "shapely-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ef753200cbffd4f652efb2c528c5474e5a14341a473994d90ad0606522a46a2"},
+ {file = "shapely-2.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9a41ff4323fc9d6257759c26eb1cf3a61ebc7e611e024e6091f42977303fd3a"},
+ {file = "shapely-2.0.2-cp310-cp310-win32.whl", hash = "sha256:72b5997272ae8c25f0fd5b3b967b3237e87fab7978b8d6cd5fa748770f0c5d68"},
+ {file = "shapely-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:34eac2337cbd67650248761b140d2535855d21b969d76d76123317882d3a0c1a"},
+ {file = "shapely-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5b0c052709c8a257c93b0d4943b0b7a3035f87e2d6a8ac9407b6a992d206422f"},
+ {file = "shapely-2.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2d217e56ae067e87b4e1731d0dc62eebe887ced729ba5c2d4590e9e3e9fdbd88"},
+ {file = "shapely-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94ac128ae2ab4edd0bffcd4e566411ea7bdc738aeaf92c32a8a836abad725f9f"},
+ {file = "shapely-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa3ee28f5e63a130ec5af4dc3c4cb9c21c5788bb13c15e89190d163b14f9fb89"},
+ {file = "shapely-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:737dba15011e5a9b54a8302f1748b62daa207c9bc06f820cd0ad32a041f1c6f2"},
+ {file = "shapely-2.0.2-cp311-cp311-win32.whl", hash = "sha256:45ac6906cff0765455a7b49c1670af6e230c419507c13e2f75db638c8fc6f3bd"},
+ {file = "shapely-2.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:dc9342fc82e374130db86a955c3c4525bfbf315a248af8277a913f30911bed9e"},
+ {file = "shapely-2.0.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:06f193091a7c6112fc08dfd195a1e3846a64306f890b151fa8c63b3e3624202c"},
+ {file = "shapely-2.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eebe544df5c018134f3c23b6515877f7e4cd72851f88a8d0c18464f414d141a2"},
+ {file = "shapely-2.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7e92e7c255f89f5cdf777690313311f422aa8ada9a3205b187113274e0135cd8"},
+ {file = "shapely-2.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be46d5509b9251dd9087768eaf35a71360de6afac82ce87c636990a0871aa18b"},
+ {file = "shapely-2.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5533a925d8e211d07636ffc2fdd9a7f9f13d54686d00577eeb11d16f00be9c4"},
+ {file = "shapely-2.0.2-cp312-cp312-win32.whl", hash = "sha256:084b023dae8ad3d5b98acee9d3bf098fdf688eb0bb9b1401e8b075f6a627b611"},
+ {file = "shapely-2.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:ea84d1cdbcf31e619d672b53c4532f06253894185ee7acb8ceb78f5f33cbe033"},
+ {file = "shapely-2.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ed1e99702125e7baccf401830a3b94d810d5c70b329b765fe93451fe14cf565b"},
+ {file = "shapely-2.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7d897e6bdc6bc64f7f65155dbbb30e49acaabbd0d9266b9b4041f87d6e52b3a"},
+ {file = "shapely-2.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0521d76d1e8af01e712db71da9096b484f081e539d4f4a8c97342e7971d5e1b4"},
+ {file = "shapely-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:5324be299d4c533ecfcfd43424dfd12f9428fd6f12cda38a4316da001d6ef0ea"},
+ {file = "shapely-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:78128357a0cee573257a0c2c388d4b7bf13cb7dbe5b3fe5d26d45ebbe2a39e25"},
+ {file = "shapely-2.0.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:87dc2be34ac3a3a4a319b963c507ac06682978a5e6c93d71917618b14f13066e"},
+ {file = "shapely-2.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:42997ac806e4583dad51c80a32d38570fd9a3d4778f5e2c98f9090aa7db0fe91"},
+ {file = "shapely-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ccfd5fa10a37e67dbafc601c1ddbcbbfef70d34c3f6b0efc866ddbdb55893a6c"},
+ {file = "shapely-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7c95d3379ae3abb74058938a9fcbc478c6b2e28d20dace38f8b5c587dde90aa"},
+ {file = "shapely-2.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a21353d28209fb0d8cc083e08ca53c52666e0d8a1f9bbe23b6063967d89ed24"},
+ {file = "shapely-2.0.2-cp38-cp38-win32.whl", hash = "sha256:03e63a99dfe6bd3beb8d5f41ec2086585bb969991d603f9aeac335ad396a06d4"},
+ {file = "shapely-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:c6fd29fbd9cd76350bd5cc14c49de394a31770aed02d74203e23b928f3d2f1aa"},
+ {file = "shapely-2.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1f217d28ecb48e593beae20a0082a95bd9898d82d14b8fcb497edf6bff9a44d7"},
+ {file = "shapely-2.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:394e5085b49334fd5b94fa89c086edfb39c3ecab7f669e8b2a4298b9d523b3a5"},
+ {file = "shapely-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fd3ad17b64466a033848c26cb5b509625c87d07dcf39a1541461cacdb8f7e91c"},
+ {file = "shapely-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d41a116fcad58048d7143ddb01285e1a8780df6dc1f56c3b1e1b7f12ed296651"},
+ {file = "shapely-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dea9a0651333cf96ef5bb2035044e3ad6a54f87d90e50fe4c2636debf1b77abc"},
+ {file = "shapely-2.0.2-cp39-cp39-win32.whl", hash = "sha256:b8eb0a92f7b8c74f9d8fdd1b40d395113f59bd8132ca1348ebcc1f5aece94b96"},
+ {file = "shapely-2.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:794affd80ca0f2c536fc948a3afa90bd8fb61ebe37fe873483ae818e7f21def4"},
+ {file = "shapely-2.0.2.tar.gz", hash = "sha256:1713cc04c171baffc5b259ba8531c58acc2a301707b7f021d88a15ed090649e7"},
+]
+
+[package.dependencies]
+numpy = ">=1.14"
+
+[package.extras]
+docs = ["matplotlib", "numpydoc (==1.1.*)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"]
+test = ["pytest", "pytest-cov"]
+
+[[package]]
+name = "six"
+version = "1.16.0"
+description = "Python 2 and 3 compatibility utilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
+ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
+]
+
+[[package]]
+name = "smmap"
+version = "5.0.1"
+description = "A pure Python implementation of a sliding window memory map manager"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
+ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
+]
+
+[[package]]
+name = "sniffio"
+version = "1.3.0"
+description = "Sniff out which async library your code is running under"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
+ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
+]
+
+[[package]]
+name = "soupsieve"
+version = "2.5"
+description = "A modern CSS selector implementation for Beautiful Soup."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"},
+ {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"},
+]
+
+[[package]]
+name = "sqlalchemy"
+version = "2.0.23"
+description = "Database Abstraction Library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"},
+ {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"},
+ {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = "sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"},
+ {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:14aebfe28b99f24f8a4c1346c48bc3d63705b1f919a24c27471136d2f219f02d"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e983fa42164577d073778d06d2cc5d020322425a509a08119bdcee70ad856bf"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0dc9031baa46ad0dd5a269cb7a92a73284d1309228be1d5935dac8fb3cae24"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5f94aeb99f43729960638e7468d4688f6efccb837a858b34574e01143cf11f89"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:63bfc3acc970776036f6d1d0e65faa7473be9f3135d37a463c5eba5efcdb24c8"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-win32.whl", hash = "sha256:f48ed89dd11c3c586f45e9eec1e437b355b3b6f6884ea4a4c3111a3358fd0c18"},
+ {file = "SQLAlchemy-2.0.23-cp37-cp37m-win_amd64.whl", hash = "sha256:1e018aba8363adb0599e745af245306cb8c46b9ad0a6fc0a86745b6ff7d940fc"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"},
+ {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"},
+ {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"},
+ {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"},
+ {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"},
+]
+
+[package.dependencies]
+greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
+typing-extensions = ">=4.2.0"
+
+[package.extras]
+aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"]
+aioodbc = ["aioodbc", "greenlet (!=0.4.17)"]
+aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"]
+asyncio = ["greenlet (!=0.4.17)"]
+asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"]
+mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"]
+mssql = ["pyodbc"]
+mssql-pymssql = ["pymssql"]
+mssql-pyodbc = ["pyodbc"]
+mypy = ["mypy (>=0.910)"]
+mysql = ["mysqlclient (>=1.4.0)"]
+mysql-connector = ["mysql-connector-python"]
+oracle = ["cx-oracle (>=8)"]
+oracle-oracledb = ["oracledb (>=1.0.1)"]
+postgresql = ["psycopg2 (>=2.7)"]
+postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"]
+postgresql-pg8000 = ["pg8000 (>=1.29.1)"]
+postgresql-psycopg = ["psycopg (>=3.0.7)"]
+postgresql-psycopg2binary = ["psycopg2-binary"]
+postgresql-psycopg2cffi = ["psycopg2cffi"]
+postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
+pymysql = ["pymysql"]
+sqlcipher = ["sqlcipher3-binary"]
+
+[[package]]
+name = "sqlite-vss"
+version = "0.1.2"
+description = ""
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "sqlite_vss-0.1.2-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:9eefa4207f8b522e32b2747fce44422c773e36710bf807613795218c7ba125f0"},
+ {file = "sqlite_vss-0.1.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:84994eaf7fe700218b258422358c4536a6aca39b96026c308b28630967f954c4"},
+ {file = "sqlite_vss-0.1.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:e44f03bc4cb214bb77b206519abfb623e3e4795967a569218e288927a7715806"},
+]
+
+[package.extras]
+test = ["pytest"]
+
+[[package]]
+name = "sqlparse"
+version = "0.4.4"
+description = "A non-validating SQL parser."
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"},
+ {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"},
+]
+
+[package.extras]
+dev = ["build", "flake8"]
+doc = ["sphinx"]
+test = ["pytest", "pytest-cov"]
+
+[[package]]
+name = "stack-data"
+version = "0.6.3"
+description = "Extract data from python stack frames and tracebacks for informative displays"
+optional = false
+python-versions = "*"
+files = [
+ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"},
+ {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"},
+]
+
+[package.dependencies]
+asttokens = ">=2.1.0"
+executing = ">=1.2.0"
+pure-eval = "*"
+
+[package.extras]
+tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
+
+[[package]]
+name = "streamlit"
+version = "1.22.0"
+description = "A faster way to build and share data apps"
+optional = true
+python-versions = ">=3.7, !=3.9.7"
+files = [
+ {file = "streamlit-1.22.0-py2.py3-none-any.whl", hash = "sha256:520dd9b9e6efb559b5a9a22feadb48b1e6f0340ec83da3514810059fdecd4167"},
+ {file = "streamlit-1.22.0.tar.gz", hash = "sha256:5bef9bf8deef32814d9565c9df48331e6357eb0b90dabc3ec4f53c44fb34fc73"},
+]
+
+[package.dependencies]
+altair = ">=3.2.0,<5"
+blinker = ">=1.0.0"
+cachetools = ">=4.0"
+click = ">=7.0"
+gitpython = "!=3.1.19"
+importlib-metadata = ">=1.4"
+numpy = "*"
+packaging = ">=14.1"
+pandas = ">=0.25,<3"
+pillow = ">=6.2.0"
+protobuf = ">=3.12,<4"
+pyarrow = ">=4.0"
+pydeck = ">=0.1.dev5"
+pympler = ">=0.9"
+python-dateutil = "*"
+requests = ">=2.4"
+rich = ">=10.11.0"
+tenacity = ">=8.0.0,<9"
+toml = "*"
+tornado = ">=6.0.3"
+typing-extensions = ">=3.10.0.0"
+tzlocal = ">=1.1"
+validators = ">=0.2"
+watchdog = {version = "*", markers = "platform_system != \"Darwin\""}
+
+[package.extras]
+snowflake = ["snowflake-snowpark-python"]
+
+[[package]]
+name = "sympy"
+version = "1.12"
+description = "Computer algebra system (CAS) in Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
+ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+]
+
+[package.dependencies]
+mpmath = ">=0.19"
+
+[[package]]
+name = "syrupy"
+version = "4.6.0"
+description = "Pytest Snapshot Test Utility"
+optional = false
+python-versions = ">=3.8.1,<4"
+files = [
+ {file = "syrupy-4.6.0-py3-none-any.whl", hash = "sha256:747aae1bcf3cb3249e33b1e6d81097874d23615982d5686ebe637875b0775a1b"},
+ {file = "syrupy-4.6.0.tar.gz", hash = "sha256:231b1f5d00f1f85048ba81676c79448076189c4aef4d33f21ae32f3b4c565a54"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0,<8.0.0"
+
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+ {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
+[[package]]
+name = "telethon"
+version = "1.33.0"
+description = "Full-featured Telegram client library for Python 3"
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "Telethon-1.33.0.tar.gz", hash = "sha256:9e515bac70fc5bc58cbace0e193a42bb42511f16bb7a319707f860476e8ce164"},
+]
+
+[package.dependencies]
+pyaes = "*"
+rsa = "*"
+
+[package.extras]
+cryptg = ["cryptg"]
+
+[[package]]
+name = "tenacity"
+version = "8.2.3"
+description = "Retry code until it succeeds"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
+ {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
+]
+
+[package.extras]
+doc = ["reno", "sphinx", "tornado (>=4.5)"]
+
+[[package]]
+name = "terminado"
+version = "0.18.0"
+description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"},
+ {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"},
+]
+
+[package.dependencies]
+ptyprocess = {version = "*", markers = "os_name != \"nt\""}
+pywinpty = {version = ">=1.1.0", markers = "os_name == \"nt\""}
+tornado = ">=6.1.0"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
+test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"]
+typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.2.0"
+description = "threadpoolctl"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"},
+ {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"},
+]
+
+[[package]]
+name = "tiktoken"
+version = "0.3.3"
+description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "tiktoken-0.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1f37fa75ba70c1bc7806641e8ccea1fba667d23e6341a1591ea333914c226a9"},
+ {file = "tiktoken-0.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3d7296c38392a943c2ccc0b61323086b8550cef08dcf6855de9949890dbc1fd3"},
+ {file = "tiktoken-0.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c84491965e139a905280ac28b74baaa13445b3678e07f96767089ad1ef5ee7b"},
+ {file = "tiktoken-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65970d77ea85ce6c7fce45131da9258cd58a802ffb29ead8f5552e331c025b2b"},
+ {file = "tiktoken-0.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bd3f72d0ba7312c25c1652292121a24c8f1711207b63c6d8dab21afe4be0bf04"},
+ {file = "tiktoken-0.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:719c9e13432602dc496b24f13e3c3ad3ec0d2fbdb9aace84abfb95e9c3a425a4"},
+ {file = "tiktoken-0.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:dc00772284c94e65045b984ed7e9f95d000034f6b2411df252011b069bd36217"},
+ {file = "tiktoken-0.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db2c40f79f8f7a21a9fdbf1c6dee32dea77b0d7402355dc584a3083251d2e15"},
+ {file = "tiktoken-0.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3c0f2231aa3829a1a431a882201dc27858634fd9989898e0f7d991dbc6bcc9d"},
+ {file = "tiktoken-0.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48c13186a479de16cfa2c72bb0631fa9c518350a5b7569e4d77590f7fee96be9"},
+ {file = "tiktoken-0.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6674e4e37ab225020135cd66a392589623d5164c6456ba28cc27505abed10d9e"},
+ {file = "tiktoken-0.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4a0c1357f6191211c544f935d5aa3cb9d7abd118c8f3c7124196d5ecd029b4af"},
+ {file = "tiktoken-0.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2e948d167fc3b04483cbc33426766fd742e7cefe5346cd62b0cbd7279ef59539"},
+ {file = "tiktoken-0.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:5dca434c8680b987eacde2dbc449e9ea4526574dbf9f3d8938665f638095be82"},
+ {file = "tiktoken-0.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:984758ebc07cd8c557345697c234f1f221bd730b388f4340dd08dffa50213a01"},
+ {file = "tiktoken-0.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:891012f29e159a989541ae47259234fb29ff88c22e1097567316e27ad33a3734"},
+ {file = "tiktoken-0.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:210f8602228e4c5d706deeb389da5a152b214966a5aa558eec87b57a1969ced5"},
+ {file = "tiktoken-0.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd783564f80d4dc44ff0a64b13756ded8390ed2548549aefadbe156af9188307"},
+ {file = "tiktoken-0.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:03f64bde9b4eb8338bf49c8532bfb4c3578f6a9a6979fc176d939f9e6f68b408"},
+ {file = "tiktoken-0.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1ac369367b6f5e5bd80e8f9a7766ac2a9c65eda2aa856d5f3c556d924ff82986"},
+ {file = "tiktoken-0.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:94600798891f78db780e5aa9321456cf355e54a4719fbd554147a628de1f163f"},
+ {file = "tiktoken-0.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e59db6fca8d5ccea302fe2888917364446d6f4201a25272a1a1c44975c65406a"},
+ {file = "tiktoken-0.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:19340d8ba4d6fd729b2e3a096a547ded85f71012843008f97475f9db484869ee"},
+ {file = "tiktoken-0.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:542686cbc9225540e3a10f472f82fa2e1bebafce2233a211dee8459e95821cfd"},
+ {file = "tiktoken-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a43612b2a09f4787c050163a216bf51123851859e9ab128ad03d2729826cde9"},
+ {file = "tiktoken-0.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a11674f0275fa75fb59941b703650998bd4acb295adbd16fc8af17051aaed19d"},
+ {file = "tiktoken-0.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:65fc0a449630bab28c30b4adec257442a4706d79cffc2337c1d9df3e91825cdd"},
+ {file = "tiktoken-0.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:0b9a7a9a8b781a50ee9289e85e28771d7e113cc0c656eadfb6fc6d3a106ff9bb"},
+ {file = "tiktoken-0.3.3.tar.gz", hash = "sha256:97b58b7bfda945791ec855e53d166e8ec20c6378942b93851a6c919ddf9d0496"},
+]
+
+[package.dependencies]
+regex = ">=2022.1.18"
+requests = ">=2.26.0"
+
+[package.extras]
+blobfile = ["blobfile (>=2)"]
+
+[[package]]
+name = "timescale-vector"
+version = "0.0.1"
+description = "Python library for storing vector data in Postgres"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "timescale-vector-0.0.1.tar.gz", hash = "sha256:420d088b1d45e98f5b9770c76ddf826521aa6e813cb4997d24355eaeda1a7775"},
+ {file = "timescale_vector-0.0.1-py3-none-any.whl", hash = "sha256:81283e8f359387bacd2bd092431a288f34c211968c53b3fed7f3fed1979f39eb"},
+]
+
+[package.dependencies]
+asyncpg = "*"
+pgvector = "*"
+psycopg2 = "*"
+
+[package.extras]
+dev = ["python-dotenv"]
+
+[[package]]
+name = "tinycss2"
+version = "1.2.1"
+description = "A tiny CSS parser"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"},
+ {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"},
+]
+
+[package.dependencies]
+webencodings = ">=0.4"
+
+[package.extras]
+doc = ["sphinx", "sphinx_rtd_theme"]
+test = ["flake8", "isort", "pytest"]
+
+[[package]]
+name = "tinysegmenter"
+version = "0.3"
+description = "Very compact Japanese tokenizer"
+optional = true
+python-versions = "*"
+files = [
+ {file = "tinysegmenter-0.3.tar.gz", hash = "sha256:ed1f6d2e806a4758a73be589754384cbadadc7e1a414c81a166fc9adf2d40c6d"},
+]
+
+[[package]]
+name = "tldextract"
+version = "5.1.1"
+description = "Accurately separates a URL's subdomain, domain, and public suffix, using the Public Suffix List (PSL). By default, this includes the public ICANN TLDs and their exceptions. You can optionally support the Public Suffix List's private domains as well."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "tldextract-5.1.1-py3-none-any.whl", hash = "sha256:b9c4510a8766d377033b6bace7e9f1f17a891383ced3c5d50c150f181e9e1cc2"},
+ {file = "tldextract-5.1.1.tar.gz", hash = "sha256:9b6dbf803cb5636397f0203d48541c0da8ba53babaf0e8a6feda2d88746813d4"},
+]
+
+[package.dependencies]
+filelock = ">=3.0.8"
+idna = "*"
+requests = ">=2.1.0"
+requests-file = ">=1.4"
+
+[package.extras]
+testing = ["black", "mypy", "pytest", "pytest-gitignore", "pytest-mock", "responses", "ruff", "tox", "types-filelock", "types-requests"]
+
+[[package]]
+name = "tokenizers"
+version = "0.15.0"
+description = ""
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"},
+ {file = "tokenizers-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a922c492c721744ee175f15b91704be2d305569d25f0547c77cd6c9f210f9dc"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:331dd786d02fc38698f835fff61c99480f98b73ce75a4c65bd110c9af5e4609a"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88dd0961c437d413ab027f8b115350c121d49902cfbadf08bb8f634b15fa1814"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6fdcc55339df7761cd52e1fbe8185d3b3963bc9e3f3545faa6c84f9e8818259a"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1480b0051d8ab5408e8e4db2dc832f7082ea24aa0722c427bde2418c6f3bd07"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9855e6c258918f9cf62792d4f6ddfa6c56dccd8c8118640f867f6393ecaf8bd7"},
+ {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9529fe75efcd54ba8d516aa725e1851df9199f0669b665c55e90df08f5af86"},
+ {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8edcc90a36eab0705fe9121d6c77c6e42eeef25c7399864fd57dfb27173060bf"},
+ {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ae17884aafb3e94f34fb7cfedc29054f5f54e142475ebf8a265a4e388fee3f8b"},
+ {file = "tokenizers-0.15.0-cp310-none-win32.whl", hash = "sha256:9a3241acdc9b44cff6e95c4a55b9be943ef3658f8edb3686034d353734adba05"},
+ {file = "tokenizers-0.15.0-cp310-none-win_amd64.whl", hash = "sha256:4b31807cb393d6ea31926b307911c89a1209d5e27629aa79553d1599c8ffdefe"},
+ {file = "tokenizers-0.15.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:af7e9be8c05d30bb137b9fd20f9d99354816599e5fd3d58a4b1e28ba3b36171f"},
+ {file = "tokenizers-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3d7343fa562ea29661783344a2d83662db0d3d17a6fa6a403cac8e512d2d9fd"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:32371008788aeeb0309a9244809a23e4c0259625e6b74a103700f6421373f395"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9db64c7c9954fbae698884c5bb089764edc549731e5f9b7fa1dd4e4d78d77f"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dbed5944c31195514669cf6381a0d8d47f164943000d10f93d6d02f0d45c25e0"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aab16c4a26d351d63e965b0c792f5da7227a37b69a6dc6d922ff70aa595b1b0c"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c2b60b12fdd310bf85ce5d7d3f823456b9b65eed30f5438dd7761879c495983"},
+ {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0344d6602740e44054a9e5bbe9775a5e149c4dddaff15959bb07dcce95a5a859"},
+ {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4525f6997d81d9b6d9140088f4f5131f6627e4c960c2c87d0695ae7304233fc3"},
+ {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:65975094fef8cc68919644936764efd2ce98cf1bacbe8db2687155d2b0625bee"},
+ {file = "tokenizers-0.15.0-cp311-none-win32.whl", hash = "sha256:ff5d2159c5d93015f5a4542aac6c315506df31853123aa39042672031768c301"},
+ {file = "tokenizers-0.15.0-cp311-none-win_amd64.whl", hash = "sha256:2dd681b53cf615e60a31a115a3fda3980e543d25ca183797f797a6c3600788a3"},
+ {file = "tokenizers-0.15.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c9cce6ee149a3d703f86877bc2a6d997e34874b2d5a2d7839e36b2273f31d3d9"},
+ {file = "tokenizers-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a0a94bc3370e6f1cc8a07a8ae867ce13b7c1b4291432a773931a61f256d44ea"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:309cfcccfc7e502cb1f1de2c9c1c94680082a65bfd3a912d5a5b2c90c677eb60"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8413e994dd7d875ab13009127fc85633916c71213917daf64962bafd488f15dc"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d0ebf9430f901dbdc3dcb06b493ff24a3644c9f88c08e6a1d6d0ae2228b9b818"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10361e9c7864b22dd791ec5126327f6c9292fb1d23481d4895780688d5e298ac"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:babe42635b8a604c594bdc56d205755f73414fce17ba8479d142a963a6c25cbc"},
+ {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3768829861e964c7a4556f5f23307fce6a23872c2ebf030eb9822dbbbf7e9b2a"},
+ {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9c91588a630adc88065e1c03ac6831e3e2112558869b9ebcb2b8afd8a14c944d"},
+ {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:77606994e793ca54ecf3a3619adc8a906a28ca223d9354b38df41cb8766a0ed6"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:6fe143939f3b596681922b2df12a591a5b010e7dcfbee2202482cd0c1c2f2459"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:b7bee0f1795e3e3561e9a557061b1539e5255b8221e3f928f58100282407e090"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5d37e7f4439b4c46192ab4f2ff38ab815e4420f153caa13dec9272ef14403d34"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caadf255cf7f951b38d10097836d1f3bcff4aeaaffadfdf748bab780bf5bff95"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05accb9162bf711a941b1460b743d62fec61c160daf25e53c5eea52c74d77814"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26a2ef890740127cb115ee5260878f4a677e36a12831795fd7e85887c53b430b"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e54c5f26df14913620046b33e822cb3bcd091a332a55230c0e63cc77135e2169"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669b8ed653a578bcff919566631156f5da3aab84c66f3c0b11a6281e8b4731c7"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0ea480d943297df26f06f508dab6e012b07f42bf3dffdd36e70799368a5f5229"},
+ {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc80a0a565ebfc7cd89de7dd581da8c2b3238addfca6280572d27d763f135f2f"},
+ {file = "tokenizers-0.15.0-cp37-none-win32.whl", hash = "sha256:cdd945e678bbdf4517d5d8de66578a5030aeefecdb46f5320b034de9cad8d4dd"},
+ {file = "tokenizers-0.15.0-cp37-none-win_amd64.whl", hash = "sha256:1ab96ab7dc706e002c32b2ea211a94c1c04b4f4de48354728c3a6e22401af322"},
+ {file = "tokenizers-0.15.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:f21c9eb71c9a671e2a42f18b456a3d118e50c7f0fc4dd9fa8f4eb727fea529bf"},
+ {file = "tokenizers-0.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a5f4543a35889679fc3052086e69e81880b2a5a28ff2a52c5a604be94b77a3f"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f8aa81afec893e952bd39692b2d9ef60575ed8c86fce1fd876a06d2e73e82dca"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1574a5a4af22c3def93fe8fe4adcc90a39bf5797ed01686a4c46d1c3bc677d2f"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c7982fd0ec9e9122d03b209dac48cebfea3de0479335100ef379a9a959b9a5a"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d16b647032df2ce2c1f9097236e046ea9fedd969b25637b9d5d734d78aa53b"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3cdf29e6f9653da330515dc8fa414be5a93aae79e57f8acc50d4028dd843edf"},
+ {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7286f3df10de840867372e3e64b99ef58c677210e3ceb653cd0e740a5c53fe78"},
+ {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aabc83028baa5a36ce7a94e7659250f0309c47fa4a639e5c2c38e6d5ea0de564"},
+ {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:72f78b0e0e276b1fc14a672fa73f3acca034ba8db4e782124a2996734a9ba9cf"},
+ {file = "tokenizers-0.15.0-cp38-none-win32.whl", hash = "sha256:9680b0ecc26e7e42f16680c1aa62e924d58d1c2dd992707081cc10a374896ea2"},
+ {file = "tokenizers-0.15.0-cp38-none-win_amd64.whl", hash = "sha256:f17cbd88dab695911cbdd385a5a7e3709cc61dff982351f5d1b5939f074a2466"},
+ {file = "tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3661862df7382c5eb23ac4fbf7c75e69b02dc4f5784e4c5a734db406b5b24596"},
+ {file = "tokenizers-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3045d191dad49647f5a5039738ecf1c77087945c7a295f7bcf051c37067e883"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fcaad9ab0801f14457d7c820d9f246b5ab590c407fc6b073819b1573097aa7"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79f17027f24fe9485701c8dbb269b9c713954ec3bdc1e7075a66086c0c0cd3c"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:01a3aa332abc4bee7640563949fcfedca4de8f52691b3b70f2fc6ca71bfc0f4e"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05b83896a893cdfedad8785250daa3ba9f0504848323471524d4783d7291661e"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbbf2489fcf25d809731ba2744ff278dd07d9eb3f8b7482726bd6cae607073a4"},
+ {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab806ad521a5e9de38078b7add97589c313915f6f5fec6b2f9f289d14d607bd6"},
+ {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4a522612d5c88a41563e3463226af64e2fa00629f65cdcc501d1995dd25d23f5"},
+ {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e58a38c4e6075810bdfb861d9c005236a72a152ebc7005941cc90d1bbf16aca9"},
+ {file = "tokenizers-0.15.0-cp39-none-win32.whl", hash = "sha256:b8034f1041fd2bd2b84ff9f4dc4ae2e1c3b71606820a9cd5c562ebd291a396d1"},
+ {file = "tokenizers-0.15.0-cp39-none-win_amd64.whl", hash = "sha256:edde9aa964145d528d0e0dbf14f244b8a85ebf276fb76869bc02e2530fa37a96"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:309445d10d442b7521b98083dc9f0b5df14eca69dbbfebeb98d781ee2cef5d30"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d3125a6499226d4d48efc54f7498886b94c418e93a205b673bc59364eecf0804"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ed56ddf0d54877bb9c6d885177db79b41576e61b5ef6defeb579dcb803c04ad5"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b22cd714706cc5b18992a232b023f736e539495f5cc61d2d28d176e55046f6c"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2719b1e9bc8e8e7f6599b99d0a8e24f33d023eb8ef644c0366a596f0aa926"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:85ddae17570ec7e5bfaf51ffa78d044f444a8693e1316e1087ee6150596897ee"},
+ {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:76f1bed992e396bf6f83e3df97b64ff47885e45e8365f8983afed8556a0bc51f"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3bb0f4df6dce41a1c7482087b60d18c372ef4463cb99aa8195100fcd41e0fd64"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:22c27672c27a059a5f39ff4e49feed8c7f2e1525577c8a7e3978bd428eb5869d"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78104f5d035c9991f92831fc0efe9e64a05d4032194f2a69f67aaa05a4d75bbb"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40b73dc19d82c3e3ffb40abdaacca8fbc95eeb26c66b7f9f860aebc07a73998"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d801d1368188c74552cd779b1286e67cb9fd96f4c57a9f9a2a09b6def9e1ab37"},
+ {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82641ffb13a4da1293fcc9f437d457647e60ed0385a9216cd135953778b3f0a1"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:160f9d1810f2c18fffa94aa98bf17632f6bd2dabc67fcb01a698ca80c37d52ee"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d7d6eea831ed435fdeeb9bcd26476226401d7309d115a710c65da4088841948"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f6456bec6c557d63d8ec0023758c32f589e1889ed03c055702e84ce275488bed"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eef39a502fad3bf104b9e1906b4fb0cee20e44e755e51df9a98f8922c3bf6d4"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1e4664c5b797e093c19b794bbecc19d2367e782b4a577d8b7c1821db5dc150d"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca003fb5f3995ff5cf676db6681b8ea5d54d3b30bea36af1120e78ee1a4a4cdf"},
+ {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7f17363141eb0c53752c89e10650b85ef059a52765d0802ba9613dbd2d21d425"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:8a765db05581c7d7e1280170f2888cda351760d196cc059c37ea96f121125799"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2a0dd641a72604486cd7302dd8f87a12c8a9b45e1755e47d2682733f097c1af5"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0a1a3c973e4dc97797fc19e9f11546c95278ffc55c4492acb742f69e035490bc"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4fab75642aae4e604e729d6f78e0addb9d7e7d49e28c8f4d16b24da278e5263"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65f80be77f6327a86d8fd35a4467adcfe6174c159b4ab52a1a8dd4c6f2d7d9e1"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8da7533dbe66b88afd430c56a2f2ce1fd82e2681868f857da38eeb3191d7498"},
+ {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa8eb4584fc6cbe6a84d7a7864be3ed28e23e9fd2146aa8ef1814d579df91958"},
+ {file = "tokenizers-0.15.0.tar.gz", hash = "sha256:10c7e6e7b4cabd757da59e93f5f8d1126291d16f8b54f28510825ef56a3e5d0e"},
+]
+
+[package.dependencies]
+huggingface_hub = ">=0.16.4,<1.0"
+
+[package.extras]
+dev = ["tokenizers[testing]"]
+docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
+
+[[package]]
+name = "tomli"
+version = "2.0.1"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+
+[[package]]
+name = "toolz"
+version = "0.12.0"
+description = "List processing tools and functional utilities"
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "toolz-0.12.0-py3-none-any.whl", hash = "sha256:2059bd4148deb1884bb0eb770a3cde70e7f954cfbbdc2285f1f2de01fd21eb6f"},
+ {file = "toolz-0.12.0.tar.gz", hash = "sha256:88c570861c440ee3f2f6037c4654613228ff40c93a6c25e0eba70d17282c6194"},
+]
+
+[[package]]
+name = "tornado"
+version = "6.4"
+description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+optional = false
+python-versions = ">= 3.8"
+files = [
+ {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"},
+ {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"},
+ {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"},
+ {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"},
+ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"},
+]
+
+[[package]]
+name = "tqdm"
+version = "4.66.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
+ {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[[package]]
+name = "traitlets"
+version = "5.14.0"
+description = "Traitlets Python configuration system"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "traitlets-5.14.0-py3-none-any.whl", hash = "sha256:f14949d23829023013c47df20b4a76ccd1a85effb786dc060f34de7948361b33"},
+ {file = "traitlets-5.14.0.tar.gz", hash = "sha256:fcdaa8ac49c04dfa0ed3ee3384ef6dfdb5d6f3741502be247279407679296772"},
+]
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
+test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"]
+
+[[package]]
+name = "typer"
+version = "0.9.0"
+description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"},
+ {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"},
+]
+
+[package.dependencies]
+click = ">=7.1.1,<9.0.0"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
+dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
+doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
+test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
+
+[[package]]
+name = "types-chardet"
+version = "5.0.4.6"
+description = "Typing stubs for chardet"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-chardet-5.0.4.6.tar.gz", hash = "sha256:caf4c74cd13ccfd8b3313c314aba943b159de562a2573ed03137402b2bb37818"},
+ {file = "types_chardet-5.0.4.6-py3-none-any.whl", hash = "sha256:ea832d87e798abf1e4dfc73767807c2b7fee35d0003ae90348aea4ae00fb004d"},
+]
+
+[[package]]
+name = "types-protobuf"
+version = "4.24.0.4"
+description = "Typing stubs for protobuf"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"},
+ {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"},
+]
+
+[[package]]
+name = "types-pyopenssl"
+version = "23.3.0.0"
+description = "Typing stubs for pyOpenSSL"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "types-pyOpenSSL-23.3.0.0.tar.gz", hash = "sha256:5ffb077fe70b699c88d5caab999ae80e192fe28bf6cda7989b7e79b1e4e2dcd3"},
+ {file = "types_pyOpenSSL-23.3.0.0-py3-none-any.whl", hash = "sha256:00171433653265843b7469ddb9f3c86d698668064cc33ef10537822156130ebf"},
+]
+
+[package.dependencies]
+cryptography = ">=35.0.0"
+
+[[package]]
+name = "types-python-dateutil"
+version = "2.8.19.14"
+description = "Typing stubs for python-dateutil"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"},
+ {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"},
+]
+
+[[package]]
+name = "types-pytz"
+version = "2023.3.1.1"
+description = "Typing stubs for pytz"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-pytz-2023.3.1.1.tar.gz", hash = "sha256:cc23d0192cd49c8f6bba44ee0c81e4586a8f30204970fc0894d209a6b08dab9a"},
+ {file = "types_pytz-2023.3.1.1-py3-none-any.whl", hash = "sha256:1999a123a3dc0e39a2ef6d19f3f8584211de9e6a77fe7a0259f04a524e90a5cf"},
+]
+
+[[package]]
+name = "types-pyyaml"
+version = "6.0.12.12"
+description = "Typing stubs for PyYAML"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
+ {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
+]
+
+[[package]]
+name = "types-redis"
+version = "4.6.0.11"
+description = "Typing stubs for redis"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "types-redis-4.6.0.11.tar.gz", hash = "sha256:c8cfc84635183deca2db4a528966c5566445fd3713983f0034fb0f5a09e0890d"},
+ {file = "types_redis-4.6.0.11-py3-none-any.whl", hash = "sha256:94fc61118601fb4f79206b33b9f4344acff7ca1d7bba67834987fb0efcf6a770"},
+]
+
+[package.dependencies]
+cryptography = ">=35.0.0"
+types-pyOpenSSL = "*"
+
+[[package]]
+name = "types-requests"
+version = "2.31.0.6"
+description = "Typing stubs for requests"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
+ {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
+]
+
+[package.dependencies]
+types-urllib3 = "*"
+
+[[package]]
+name = "types-toml"
+version = "0.10.8.7"
+description = "Typing stubs for toml"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-toml-0.10.8.7.tar.gz", hash = "sha256:58b0781c681e671ff0b5c0319309910689f4ab40e8a2431e205d70c94bb6efb1"},
+ {file = "types_toml-0.10.8.7-py3-none-any.whl", hash = "sha256:61951da6ad410794c97bec035d59376ce1cbf4453dc9b6f90477e81e4442d631"},
+]
+
+[[package]]
+name = "types-urllib3"
+version = "1.26.25.14"
+description = "Typing stubs for urllib3"
+optional = false
+python-versions = "*"
+files = [
+ {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
+ {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
+]
+
+[[package]]
+name = "typing"
+version = "3.7.4.3"
+description = "Type Hints for Python"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "typing-3.7.4.3-py2-none-any.whl", hash = "sha256:283d868f5071ab9ad873e5e52268d611e851c870a2ba354193026f2dfb29d8b5"},
+ {file = "typing-3.7.4.3.tar.gz", hash = "sha256:1187fb9c82fd670d10aa07bbb6cfcfe4bdda42d6fab8d5134f04e8c4d0b71cc9"},
+]
+
+[[package]]
+name = "typing-extensions"
+version = "4.8.0"
+description = "Backported and Experimental Type Hints for Python 3.8+"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"},
+ {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
+]
+
+[[package]]
+name = "typing-inspect"
+version = "0.9.0"
+description = "Runtime inspection utilities for typing module."
+optional = false
+python-versions = "*"
+files = [
+ {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
+ {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=0.3.0"
+typing-extensions = ">=3.7.4"
+
+[[package]]
+name = "tzdata"
+version = "2023.3"
+description = "Provider of IANA time zone data"
+optional = false
+python-versions = ">=2"
+files = [
+ {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
+ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
+]
+
+[[package]]
+name = "tzlocal"
+version = "5.2"
+description = "tzinfo object for the local timezone"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8"},
+ {file = "tzlocal-5.2.tar.gz", hash = "sha256:8d399205578f1a9342816409cc1e46a93ebd5755e39ea2d85334bea911bf0e6e"},
+]
+
+[package.dependencies]
+"backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""}
+tzdata = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
+
+[[package]]
+name = "update-checker"
+version = "0.18.0"
+description = "A python module that will check for package updates."
+optional = true
+python-versions = "*"
+files = [
+ {file = "update_checker-0.18.0-py3-none-any.whl", hash = "sha256:cbba64760a36fe2640d80d85306e8fe82b6816659190993b7bdabadee4d4bbfd"},
+ {file = "update_checker-0.18.0.tar.gz", hash = "sha256:6a2d45bb4ac585884a6b03f9eade9161cedd9e8111545141e9aa9058932acb13"},
+]
+
+[package.dependencies]
+requests = ">=2.3.0"
+
+[package.extras]
+dev = ["black", "flake8", "pytest (>=2.7.3)"]
+lint = ["black", "flake8"]
+test = ["pytest (>=2.7.3)"]
+
+[[package]]
+name = "upstash-redis"
+version = "0.15.0"
+description = "Serverless Redis SDK from Upstash"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "upstash_redis-0.15.0-py3-none-any.whl", hash = "sha256:4a89913cb2bb2422610bc2a9c8d6b9a9d75d0674c22c5ea8037d35d343ee5846"},
+ {file = "upstash_redis-0.15.0.tar.gz", hash = "sha256:910f6a567142167b742c38efecfabf23f47e24fcbddb00a6b5845cb11064c3af"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.8.4,<4.0.0"
+requests = ">=2.31.0,<3.0.0"
+
+[[package]]
+name = "uri-template"
+version = "1.3.0"
+description = "RFC 6570 URI Template Processor"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7"},
+ {file = "uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363"},
+]
+
+[package.extras]
+dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]
+
+[[package]]
+name = "urllib3"
+version = "1.26.18"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+files = [
+ {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
+ {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
+]
+
+[package.extras]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+
+[[package]]
+name = "uuid"
+version = "1.30"
+description = "UUID object and generation functions (Python 2.3 or higher)"
+optional = true
+python-versions = "*"
+files = [
+ {file = "uuid-1.30.tar.gz", hash = "sha256:1f87cc004ac5120466f36c5beae48b4c48cc411968eed0eaecd3da82aa96193f"},
+]
+
+[[package]]
+name = "validators"
+version = "0.22.0"
+description = "Python Data Validation for Humansβ’"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "validators-0.22.0-py3-none-any.whl", hash = "sha256:61cf7d4a62bbae559f2e54aed3b000cea9ff3e2fdbe463f51179b92c58c9585a"},
+ {file = "validators-0.22.0.tar.gz", hash = "sha256:77b2689b172eeeb600d9605ab86194641670cdb73b60afd577142a9397873370"},
+]
+
+[package.extras]
+docs-offline = ["myst-parser (>=2.0.0)", "pypandoc-binary (>=1.11)", "sphinx (>=7.1.1)"]
+docs-online = ["mkdocs (>=1.5.2)", "mkdocs-git-revision-date-localized-plugin (>=1.2.0)", "mkdocs-material (>=9.2.6)", "mkdocstrings[python] (>=0.22.0)", "pyaml (>=23.7.0)"]
+hooks = ["pre-commit (>=3.3.3)"]
+package = ["build (>=1.0.0)", "twine (>=4.0.2)"]
+runner = ["tox (>=4.11.1)"]
+sast = ["bandit[toml] (>=1.7.5)"]
+testing = ["pytest (>=7.4.0)"]
+tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"]
+tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"]
+
+[[package]]
+name = "vcrpy"
+version = "5.1.0"
+description = "Automatically mock your HTTP interactions to simplify and speed up testing"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "vcrpy-5.1.0-py2.py3-none-any.whl", hash = "sha256:605e7b7a63dcd940db1df3ab2697ca7faf0e835c0852882142bafb19649d599e"},
+ {file = "vcrpy-5.1.0.tar.gz", hash = "sha256:bbf1532f2618a04f11bce2a99af3a9647a32c880957293ff91e0a5f187b6b3d2"},
+]
+
+[package.dependencies]
+PyYAML = "*"
+urllib3 = {version = "<2", markers = "python_version < \"3.10\""}
+wrapt = "*"
+yarl = "*"
+
+[[package]]
+name = "watchdog"
+version = "3.0.0"
+description = "Filesystem events monitoring"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"},
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"},
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"},
+ {file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"},
+ {file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"},
+ {file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"},
+ {file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"},
+ {file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"},
+ {file = "watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"},
+ {file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"},
+ {file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"},
+]
+
+[package.extras]
+watchmedo = ["PyYAML (>=3.10)"]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.12"
+description = "Measures the displayed width of unicode strings in a terminal"
+optional = false
+python-versions = "*"
+files = [
+ {file = "wcwidth-0.2.12-py2.py3-none-any.whl", hash = "sha256:f26ec43d96c8cbfed76a5075dac87680124fa84e0855195a6184da9c187f133c"},
+ {file = "wcwidth-0.2.12.tar.gz", hash = "sha256:f01c104efdf57971bcb756f054dd58ddec5204dd15fa31d6503ea57947d97c02"},
+]
+
+[[package]]
+name = "webcolors"
+version = "1.13"
+description = "A library for working with the color formats defined by HTML and CSS."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "webcolors-1.13-py3-none-any.whl", hash = "sha256:29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf"},
+ {file = "webcolors-1.13.tar.gz", hash = "sha256:c225b674c83fa923be93d235330ce0300373d02885cef23238813b0d5668304a"},
+]
+
+[package.extras]
+docs = ["furo", "sphinx", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinxext-opengraph"]
+tests = ["pytest", "pytest-cov"]
+
+[[package]]
+name = "webencodings"
+version = "0.5.1"
+description = "Character encoding aliases for legacy web content"
+optional = false
+python-versions = "*"
+files = [
+ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"},
+ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
+]
+
+[[package]]
+name = "websocket-client"
+version = "1.7.0"
+description = "WebSocket client for Python with low level API options"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"},
+ {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"},
+]
+
+[package.extras]
+docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"]
+optional = ["python-socks", "wsaccel"]
+test = ["websockets"]
+
+[[package]]
+name = "websockets"
+version = "12.0"
+description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"},
+ {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"},
+ {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"},
+ {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"},
+ {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"},
+ {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"},
+ {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"},
+ {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"},
+ {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"},
+ {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"},
+ {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"},
+ {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"},
+ {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"},
+ {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"},
+ {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"},
+ {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"},
+ {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"},
+ {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"},
+ {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"},
+ {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"},
+ {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"},
+ {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"},
+ {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"},
+ {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"},
+ {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"},
+ {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"},
+ {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"},
+ {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"},
+ {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"},
+ {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"},
+ {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"},
+ {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"},
+ {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"},
+ {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"},
+ {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"},
+ {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"},
+ {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"},
+ {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"},
+ {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"},
+ {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"},
+ {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"},
+ {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"},
+ {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"},
+ {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"},
+ {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"},
+ {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"},
+ {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"},
+ {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"},
+ {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"},
+ {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"},
+ {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"},
+ {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"},
+ {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"},
+ {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"},
+ {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"},
+ {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"},
+ {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"},
+ {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"},
+ {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"},
+ {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"},
+ {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"},
+ {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"},
+ {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"},
+ {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"},
+ {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"},
+ {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"},
+ {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"},
+ {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"},
+ {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"},
+ {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"},
+ {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"},
+ {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"},
+]
+
+[[package]]
+name = "widgetsnbextension"
+version = "4.0.9"
+description = "Jupyter interactive widgets for Jupyter Notebook"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"},
+ {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"},
+]
+
+[[package]]
+name = "wrapt"
+version = "1.16.0"
+description = "Module for decorators, wrappers and monkey patching."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"},
+ {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"},
+ {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"},
+ {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"},
+ {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"},
+ {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"},
+ {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"},
+ {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"},
+ {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"},
+ {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"},
+ {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"},
+ {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"},
+ {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"},
+ {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"},
+ {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"},
+ {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"},
+ {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"},
+ {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"},
+ {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"},
+ {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"},
+ {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"},
+ {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"},
+ {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"},
+ {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"},
+ {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"},
+ {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"},
+ {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"},
+ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"},
+]
+
+[[package]]
+name = "xata"
+version = "1.2.0"
+description = "Python SDK for Xata.io"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "xata-1.2.0-py3-none-any.whl", hash = "sha256:a3710a273c0b64464080e332e24a1754a7fc9076a4117af558353c57f25c23e1"},
+ {file = "xata-1.2.0.tar.gz", hash = "sha256:048bf24c8aa3d09241dbe9f2a31513ce62c75c06ea3aa5822f000d2eac116462"},
+]
+
+[package.dependencies]
+deprecation = ">=2.1.0,<3.0.0"
+orjson = ">=3.8.1,<4.0.0"
+python-dotenv = ">=0.21,<2.0"
+requests = ">=2.28.1,<3.0.0"
+
+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = true
+python-versions = ">=3.4"
+files = [
+ {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+ {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+
+[[package]]
+name = "xxhash"
+version = "3.4.1"
+description = "Python binding for xxHash"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"},
+ {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"},
+ {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"},
+ {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"},
+ {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"},
+ {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"},
+ {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"},
+ {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"},
+ {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"},
+ {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"},
+ {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"},
+ {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"},
+ {file = "xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"},
+ {file = "xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"},
+ {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"},
+ {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"},
+ {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"},
+ {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"},
+ {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"},
+ {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"},
+ {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"},
+ {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"},
+ {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"},
+ {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"},
+ {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"},
+ {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"},
+ {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"},
+ {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"},
+ {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"},
+ {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"},
+ {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"},
+ {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"},
+ {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"},
+ {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"},
+ {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"},
+ {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"},
+ {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"},
+ {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"},
+ {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"},
+ {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"},
+ {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"},
+ {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"},
+ {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = "sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"},
+ {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"},
+ {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"},
+ {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"},
+ {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"},
+ {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"},
+ {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"},
+ {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"},
+ {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"},
+ {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"},
+ {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"},
+ {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"},
+ {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"},
+ {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"},
+ {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"},
+ {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"},
+ {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"},
+ {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"},
+ {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"},
+ {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"},
+ {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"},
+ {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"},
+ {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"},
+ {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"},
+ {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"},
+ {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"},
+ {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"},
+ {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"},
+ {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"},
+ {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"},
+ {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"},
+ {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"},
+ {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"},
+ {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"},
+ {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"},
+ {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"},
+ {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"},
+ {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"},
+ {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"},
+ {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"},
+ {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"},
+ {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"},
+ {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = "sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"},
+ {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"},
+ {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"},
+ {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"},
+ {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"},
+ {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"},
+ {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"},
+ {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"},
+ {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"},
+ {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"},
+ {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"},
+ {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"},
+ {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"},
+ {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"},
+ {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"},
+ {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"},
+ {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"},
+ {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"},
+ {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"},
+ {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"},
+ {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"},
+ {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"},
+ {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"},
+ {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"},
+]
+
+[[package]]
+name = "yarl"
+version = "1.9.3"
+description = "Yet another URL library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "yarl-1.9.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32435d134414e01d937cd9d6cc56e8413a8d4741dea36af5840c7750f04d16ab"},
+ {file = "yarl-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9a5211de242754b5e612557bca701f39f8b1a9408dff73c6db623f22d20f470e"},
+ {file = "yarl-1.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:525cd69eff44833b01f8ef39aa33a9cc53a99ff7f9d76a6ef6a9fb758f54d0ff"},
+ {file = "yarl-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc94441bcf9cb8c59f51f23193316afefbf3ff858460cb47b5758bf66a14d130"},
+ {file = "yarl-1.9.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e36021db54b8a0475805acc1d6c4bca5d9f52c3825ad29ae2d398a9d530ddb88"},
+ {file = "yarl-1.9.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0f17d1df951336a02afc8270c03c0c6e60d1f9996fcbd43a4ce6be81de0bd9d"},
+ {file = "yarl-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5f3faeb8100a43adf3e7925d556801d14b5816a0ac9e75e22948e787feec642"},
+ {file = "yarl-1.9.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aed37db837ecb5962469fad448aaae0f0ee94ffce2062cf2eb9aed13328b5196"},
+ {file = "yarl-1.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:721ee3fc292f0d069a04016ef2c3a25595d48c5b8ddc6029be46f6158d129c92"},
+ {file = "yarl-1.9.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b8bc5b87a65a4e64bc83385c05145ea901b613d0d3a434d434b55511b6ab0067"},
+ {file = "yarl-1.9.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:dd952b9c64f3b21aedd09b8fe958e4931864dba69926d8a90c90d36ac4e28c9a"},
+ {file = "yarl-1.9.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:c405d482c320a88ab53dcbd98d6d6f32ada074f2d965d6e9bf2d823158fa97de"},
+ {file = "yarl-1.9.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9df9a0d4c5624790a0dea2e02e3b1b3c69aed14bcb8650e19606d9df3719e87d"},
+ {file = "yarl-1.9.3-cp310-cp310-win32.whl", hash = "sha256:d34c4f80956227f2686ddea5b3585e109c2733e2d4ef12eb1b8b4e84f09a2ab6"},
+ {file = "yarl-1.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:cf7a4e8de7f1092829caef66fd90eaf3710bc5efd322a816d5677b7664893c93"},
+ {file = "yarl-1.9.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d61a0ca95503867d4d627517bcfdc28a8468c3f1b0b06c626f30dd759d3999fd"},
+ {file = "yarl-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:73cc83f918b69110813a7d95024266072d987b903a623ecae673d1e71579d566"},
+ {file = "yarl-1.9.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d81657b23e0edb84b37167e98aefb04ae16cbc5352770057893bd222cdc6e45f"},
+ {file = "yarl-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26a1a8443091c7fbc17b84a0d9f38de34b8423b459fb853e6c8cdfab0eacf613"},
+ {file = "yarl-1.9.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe34befb8c765b8ce562f0200afda3578f8abb159c76de3ab354c80b72244c41"},
+ {file = "yarl-1.9.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c757f64afe53a422e45e3e399e1e3cf82b7a2f244796ce80d8ca53e16a49b9f"},
+ {file = "yarl-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a57b41a0920b9a220125081c1e191b88a4cdec13bf9d0649e382a822705c65"},
+ {file = "yarl-1.9.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632c7aeb99df718765adf58eacb9acb9cbc555e075da849c1378ef4d18bf536a"},
+ {file = "yarl-1.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b0b8c06afcf2bac5a50b37f64efbde978b7f9dc88842ce9729c020dc71fae4ce"},
+ {file = "yarl-1.9.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1d93461e2cf76c4796355494f15ffcb50a3c198cc2d601ad8d6a96219a10c363"},
+ {file = "yarl-1.9.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:4003f380dac50328c85e85416aca6985536812c082387255c35292cb4b41707e"},
+ {file = "yarl-1.9.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4d6d74a97e898c1c2df80339aa423234ad9ea2052f66366cef1e80448798c13d"},
+ {file = "yarl-1.9.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b61e64b06c3640feab73fa4ff9cb64bd8182de52e5dc13038e01cfe674ebc321"},
+ {file = "yarl-1.9.3-cp311-cp311-win32.whl", hash = "sha256:29beac86f33d6c7ab1d79bd0213aa7aed2d2f555386856bb3056d5fdd9dab279"},
+ {file = "yarl-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:f7271d6bd8838c49ba8ae647fc06469137e1c161a7ef97d778b72904d9b68696"},
+ {file = "yarl-1.9.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:dd318e6b75ca80bff0b22b302f83a8ee41c62b8ac662ddb49f67ec97e799885d"},
+ {file = "yarl-1.9.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4b1efb11a8acd13246ffb0bee888dd0e8eb057f8bf30112e3e21e421eb82d4a"},
+ {file = "yarl-1.9.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c6f034386e5550b5dc8ded90b5e2ff7db21f0f5c7de37b6efc5dac046eb19c10"},
+ {file = "yarl-1.9.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd49a908cb6d387fc26acee8b7d9fcc9bbf8e1aca890c0b2fdfd706057546080"},
+ {file = "yarl-1.9.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa4643635f26052401750bd54db911b6342eb1a9ac3e74f0f8b58a25d61dfe41"},
+ {file = "yarl-1.9.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e741bd48e6a417bdfbae02e088f60018286d6c141639359fb8df017a3b69415a"},
+ {file = "yarl-1.9.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c86d0d0919952d05df880a1889a4f0aeb6868e98961c090e335671dea5c0361"},
+ {file = "yarl-1.9.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d5434b34100b504aabae75f0622ebb85defffe7b64ad8f52b8b30ec6ef6e4b9"},
+ {file = "yarl-1.9.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79e1df60f7c2b148722fb6cafebffe1acd95fd8b5fd77795f56247edaf326752"},
+ {file = "yarl-1.9.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:44e91a669c43f03964f672c5a234ae0d7a4d49c9b85d1baa93dec28afa28ffbd"},
+ {file = "yarl-1.9.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:3cfa4dbe17b2e6fca1414e9c3bcc216f6930cb18ea7646e7d0d52792ac196808"},
+ {file = "yarl-1.9.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:88d2c3cc4b2f46d1ba73d81c51ec0e486f59cc51165ea4f789677f91a303a9a7"},
+ {file = "yarl-1.9.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cccdc02e46d2bd7cb5f38f8cc3d9db0d24951abd082b2f242c9e9f59c0ab2af3"},
+ {file = "yarl-1.9.3-cp312-cp312-win32.whl", hash = "sha256:96758e56dceb8a70f8a5cff1e452daaeff07d1cc9f11e9b0c951330f0a2396a7"},
+ {file = "yarl-1.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:c4472fe53ebf541113e533971bd8c32728debc4c6d8cc177f2bff31d011ec17e"},
+ {file = "yarl-1.9.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:126638ab961633f0940a06e1c9d59919003ef212a15869708dcb7305f91a6732"},
+ {file = "yarl-1.9.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c99ddaddb2fbe04953b84d1651149a0d85214780e4d0ee824e610ab549d98d92"},
+ {file = "yarl-1.9.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dab30b21bd6fb17c3f4684868c7e6a9e8468078db00f599fb1c14e324b10fca"},
+ {file = "yarl-1.9.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:828235a2a169160ee73a2fcfb8a000709edf09d7511fccf203465c3d5acc59e4"},
+ {file = "yarl-1.9.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc391e3941045fd0987c77484b2799adffd08e4b6735c4ee5f054366a2e1551d"},
+ {file = "yarl-1.9.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51382c72dd5377861b573bd55dcf680df54cea84147c8648b15ac507fbef984d"},
+ {file = "yarl-1.9.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:28a108cb92ce6cf867690a962372996ca332d8cda0210c5ad487fe996e76b8bb"},
+ {file = "yarl-1.9.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8f18a7832ff85dfcd77871fe677b169b1bc60c021978c90c3bb14f727596e0ae"},
+ {file = "yarl-1.9.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:7eaf13af79950142ab2bbb8362f8d8d935be9aaf8df1df89c86c3231e4ff238a"},
+ {file = "yarl-1.9.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:66a6dbf6ca7d2db03cc61cafe1ee6be838ce0fbc97781881a22a58a7c5efef42"},
+ {file = "yarl-1.9.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a0a4f3aaa18580038cfa52a7183c8ffbbe7d727fe581300817efc1e96d1b0e9"},
+ {file = "yarl-1.9.3-cp37-cp37m-win32.whl", hash = "sha256:946db4511b2d815979d733ac6a961f47e20a29c297be0d55b6d4b77ee4b298f6"},
+ {file = "yarl-1.9.3-cp37-cp37m-win_amd64.whl", hash = "sha256:2dad8166d41ebd1f76ce107cf6a31e39801aee3844a54a90af23278b072f1ccf"},
+ {file = "yarl-1.9.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:bb72d2a94481e7dc7a0c522673db288f31849800d6ce2435317376a345728225"},
+ {file = "yarl-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9a172c3d5447b7da1680a1a2d6ecdf6f87a319d21d52729f45ec938a7006d5d8"},
+ {file = "yarl-1.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2dc72e891672343b99db6d497024bf8b985537ad6c393359dc5227ef653b2f17"},
+ {file = "yarl-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8d51817cf4b8d545963ec65ff06c1b92e5765aa98831678d0e2240b6e9fd281"},
+ {file = "yarl-1.9.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53ec65f7eee8655bebb1f6f1607760d123c3c115a324b443df4f916383482a67"},
+ {file = "yarl-1.9.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cfd77e8e5cafba3fb584e0f4b935a59216f352b73d4987be3af51f43a862c403"},
+ {file = "yarl-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e73db54c967eb75037c178a54445c5a4e7461b5203b27c45ef656a81787c0c1b"},
+ {file = "yarl-1.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c19e5f4404574fcfb736efecf75844ffe8610606f3fccc35a1515b8b6712c4"},
+ {file = "yarl-1.9.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6280353940f7e5e2efaaabd686193e61351e966cc02f401761c4d87f48c89ea4"},
+ {file = "yarl-1.9.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c25ec06e4241e162f5d1f57c370f4078797ade95c9208bd0c60f484834f09c96"},
+ {file = "yarl-1.9.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7217234b10c64b52cc39a8d82550342ae2e45be34f5bff02b890b8c452eb48d7"},
+ {file = "yarl-1.9.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4ce77d289f8d40905c054b63f29851ecbfd026ef4ba5c371a158cfe6f623663e"},
+ {file = "yarl-1.9.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5f74b015c99a5eac5ae589de27a1201418a5d9d460e89ccb3366015c6153e60a"},
+ {file = "yarl-1.9.3-cp38-cp38-win32.whl", hash = "sha256:8a2538806be846ea25e90c28786136932ec385c7ff3bc1148e45125984783dc6"},
+ {file = "yarl-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:6465d36381af057d0fab4e0f24ef0e80ba61f03fe43e6eeccbe0056e74aadc70"},
+ {file = "yarl-1.9.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2f3c8822bc8fb4a347a192dd6a28a25d7f0ea3262e826d7d4ef9cc99cd06d07e"},
+ {file = "yarl-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7831566595fe88ba17ea80e4b61c0eb599f84c85acaa14bf04dd90319a45b90"},
+ {file = "yarl-1.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ff34cb09a332832d1cf38acd0f604c068665192c6107a439a92abfd8acf90fe2"},
+ {file = "yarl-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe8080b4f25dfc44a86bedd14bc4f9d469dfc6456e6f3c5d9077e81a5fedfba7"},
+ {file = "yarl-1.9.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8535e111a064f3bdd94c0ed443105934d6f005adad68dd13ce50a488a0ad1bf3"},
+ {file = "yarl-1.9.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d155a092bf0ebf4a9f6f3b7a650dc5d9a5bbb585ef83a52ed36ba46f55cc39d"},
+ {file = "yarl-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:778df71c8d0c8c9f1b378624b26431ca80041660d7be7c3f724b2c7a6e65d0d6"},
+ {file = "yarl-1.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9f9cafaf031c34d95c1528c16b2fa07b710e6056b3c4e2e34e9317072da5d1a"},
+ {file = "yarl-1.9.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ca6b66f69e30f6e180d52f14d91ac854b8119553b524e0e28d5291a724f0f423"},
+ {file = "yarl-1.9.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e0e7e83f31e23c5d00ff618045ddc5e916f9e613d33c5a5823bc0b0a0feb522f"},
+ {file = "yarl-1.9.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:af52725c7c39b0ee655befbbab5b9a1b209e01bb39128dce0db226a10014aacc"},
+ {file = "yarl-1.9.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:0ab5baaea8450f4a3e241ef17e3d129b2143e38a685036b075976b9c415ea3eb"},
+ {file = "yarl-1.9.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6d350388ba1129bc867c6af1cd17da2b197dff0d2801036d2d7d83c2d771a682"},
+ {file = "yarl-1.9.3-cp39-cp39-win32.whl", hash = "sha256:e2a16ef5fa2382af83bef4a18c1b3bcb4284c4732906aa69422cf09df9c59f1f"},
+ {file = "yarl-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:d92d897cb4b4bf915fbeb5e604c7911021a8456f0964f3b8ebbe7f9188b9eabb"},
+ {file = "yarl-1.9.3-py3-none-any.whl", hash = "sha256:271d63396460b6607b588555ea27a1a02b717ca2e3f2cf53bdde4013d7790929"},
+ {file = "yarl-1.9.3.tar.gz", hash = "sha256:4a14907b597ec55740f63e52d7fee0e9ee09d5b9d57a4f399a7423268e457b57"},
+]
+
+[package.dependencies]
+idna = ">=2.0"
+multidict = ">=4.0"
+
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+ {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
+[extras]
+cli = ["typer"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+
+[metadata]
+lock-version = "2.0"
+python-versions = ">=3.8.1,<4.0"
+content-hash = "6ada6ee5af954616af167a895f59823cdbe01a5381d1a9bc7c9b8b56e4e951bb"
diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml
new file mode 100644
index 00000000000..52e44722cf1
--- /dev/null
+++ b/libs/community/pyproject.toml
@@ -0,0 +1,298 @@
+[tool.poetry]
+name = "langchain-community"
+version = "0.0.1-rc.2"
+description = "Community contributed LangChain integrations."
+authors = []
+license = "MIT"
+readme = "README.md"
+repository = "https://github.com/langchain-ai/langchain"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+langchain-core = { version = ">=0.0.13-rc.2,<0.1", allow-prereleases = true }
+SQLAlchemy = ">=1.4,<3"
+requests = "^2"
+PyYAML = ">=5.3"
+numpy = "^1"
+aiohttp = "^3.8.3"
+tenacity = "^8.1.0"
+dataclasses-json = ">= 0.5.7, < 0.7"
+langsmith = "~0.0.63"
+tqdm = {version = ">=4.48.0", optional = true}
+openapi-pydantic = {version = "^0.3.2", optional = true}
+faiss-cpu = {version = "^1", optional = true}
+beautifulsoup4 = {version = "^4", optional = true}
+jinja2 = {version = "^3", optional = true}
+cohere = {version = "^4", optional = true}
+openai = {version = "<2", optional = true}
+arxiv = {version = "^1.4", optional = true}
+pypdf = {version = "^3.4.0", optional = true}
+aleph-alpha-client = {version="^2.15.0", optional = true}
+pgvector = {version = "^0.1.6", optional = true}
+atlassian-python-api = {version = "^3.36.0", optional=true}
+html2text = {version="^2020.1.16", optional=true}
+numexpr = {version="^2.8.6", optional=true}
+jq = {version = "^1.4.1", optional = true}
+pdfminer-six = {version = "^20221105", optional = true}
+lxml = {version = "^4.9.2", optional = true}
+pymupdf = {version = "^1.22.3", optional = true}
+rapidocr-onnxruntime = {version = "^1.3.2", optional = true, python = ">=3.8.1,<3.12"}
+pypdfium2 = {version = "^4.10.0", optional = true}
+gql = {version = "^3.4.1", optional = true}
+pandas = {version = "^2.0.1", optional = true}
+telethon = {version = "^1.28.5", optional = true}
+chardet = {version="^5.1.0", optional=true}
+requests-toolbelt = {version = "^1.0.0", optional = true}
+scikit-learn = {version = "^1.2.2", optional = true}
+py-trello = {version = "^0.19.0", optional = true}
+bibtexparser = {version = "^1.4.0", optional = true}
+pyspark = {version = "^3.4.0", optional = true}
+mwparserfromhell = {version = "^0.6.4", optional = true}
+mwxml = {version = "^0.3.3", optional = true}
+esprima = {version = "^4.0.1", optional = true}
+streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
+psychicapi = {version = "^0.8.0", optional = true}
+cassio = {version = "^0.1.0", optional = true}
+sympy = {version = "^1.12", optional = true}
+rapidfuzz = {version = "^3.1.1", optional = true}
+jsonschema = {version = ">1", optional = true}
+rank-bm25 = {version = "^0.2.2", optional = true}
+geopandas = {version = "^0.13.1", optional = true}
+gitpython = {version = "^3.1.32", optional = true}
+feedparser = {version = "^6.0.10", optional = true}
+newspaper3k = {version = "^0.2.8", optional = true}
+xata = {version = "^1.0.0a7", optional = true}
+xmltodict = {version = "^0.13.0", optional = true}
+markdownify = {version = "^0.11.6", optional = true}
+assemblyai = {version = "^0.17.0", optional = true}
+dashvector = {version = "^1.0.1", optional = true}
+sqlite-vss = {version = "^0.1.2", optional = true}
+motor = {version = "^3.3.1", optional = true}
+timescale-vector = {version = "^0.0.1", optional = true}
+typer = {version= "^0.9.0", optional = true}
+anthropic = {version = "^0.3.11", optional = true}
+aiosqlite = {version = "^0.19.0", optional = true}
+rspace_client = {version = "^2.5.0", optional = true}
+upstash-redis = {version = "^0.15.0", optional = true}
+google-cloud-documentai = {version = "^2.20.1", optional = true}
+fireworks-ai = {version = "^0.9.0", optional = true}
+javelin-sdk = {version = "^0.1.8", optional = true}
+hologres-vector = {version = "^0.0.6", optional = true}
+praw = {version = "^7.7.1", optional = true}
+msal = {version = "^1.25.0", optional = true}
+databricks-vectorsearch = {version = "^0.21", optional = true}
+dgml-utils = {version = "^0.3.0", optional = true}
+datasets = {version = "^2.15.0", optional = true}
+
+[tool.poetry.group.test]
+optional = true
+
+[tool.poetry.group.test.dependencies]
+# The only dependencies that should be added are
+# dependencies used for running tests (e.g., pytest, freezegun, responses).
+# Any dependencies that do not meet those criteria will be removed.
+pytest = "^7.3.0"
+pytest-cov = "^4.0.0"
+pytest-dotenv = "^0.5.2"
+duckdb-engine = "^0.9.2"
+pytest-watcher = "^0.2.6"
+freezegun = "^1.2.2"
+responses = "^0.22.0"
+pytest-asyncio = "^0.20.3"
+lark = "^1.1.5"
+pandas = "^2.0.0"
+pytest-mock = "^3.10.0"
+pytest-socket = "^0.6.0"
+syrupy = "^4.0.2"
+requests-mock = "^1.11.0"
+langchain-core = {path = "../core", develop = true}
+
+[tool.poetry.group.codespell]
+optional = true
+
+[tool.poetry.group.codespell.dependencies]
+codespell = "^2.2.0"
+
+[tool.poetry.group.test_integration]
+optional = true
+
+[tool.poetry.group.test_integration.dependencies]
+# Do not add dependencies in the test_integration group
+# Instead:
+# 1. Add an optional dependency to the main group
+# poetry add --optional [package name]
+# 2. Add the package name to the extended_testing extra (find it below)
+# 3. Relock the poetry file
+# poetry lock --no-update
+# 4. Favor unit tests over integration tests.
+# Use the @pytest.mark.requires(pkg_name) decorator in unit_tests.
+# Your tests should not rely on network access, since that prevents other
+# developers from easily running them.
+# Instead, write unit tests that use the `responses` library or mock.patch with
+# fixtures. Keep the fixtures minimal (see the sketch below).
+# See CONTRIBUTING.md for more instructions on working with optional dependencies.
+# https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md#working-with-optional-dependencies
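+#
+# A minimal sketch of such a unit test (the package and endpoint below are
+# hypothetical, shown only to illustrate the marker and the `responses` library):
+#
+#   import pytest
+#   import responses
+#
+#   @pytest.mark.requires("somepkg")  # hypothetical optional dependency
+#   @responses.activate
+#   def test_loader_hits_mocked_endpoint() -> None:
+#       responses.add(responses.GET, "https://example.com/api", json={"ok": True})
+#       ...  # construct the component and assert on the parsed output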
+pytest-vcr = "^1.0.2"
+wrapt = "^1.15.0"
+openai = "^1"
+python-dotenv = "^1.0.0"
+cassio = "^0.1.0"
+tiktoken = "^0.3.2"
+anthropic = "^0.3.11"
+langchain-core = { path = "../core", develop = true }
+fireworks-ai = "^0.9.0"
+boto3 = ">=1.28.57,<2"
+google-cloud-aiplatform = ">=1.37.0,<2"
+
+[tool.poetry.group.lint]
+optional = true
+
+[tool.poetry.group.lint.dependencies]
+ruff = "^0.1.5"
+
+[tool.poetry.group.typing.dependencies]
+mypy = "^0.991"
+types-pyyaml = "^6.0.12.2"
+types-requests = "^2.28.11.5"
+types-toml = "^0.10.8.1"
+types-pytz = "^2023.3.0.0"
+types-chardet = "^5.0.4.6"
+types-redis = "^4.3.21.6"
+mypy-protobuf = "^3.0.0"
+langchain-core = {path = "../core", develop = true}
+
+[tool.poetry.group.dev]
+optional = true
+
+[tool.poetry.group.dev.dependencies]
+jupyter = "^1.0.0"
+setuptools = "^67.6.1"
+langchain-core = {path = "../core", develop = true}
+
+[tool.poetry.extras]
+
+cli = ["typer"]
+
+# An extra used to enable extended testing.
+# Please keep each package on its own line to make it easier to add new packages
+# without merge conflicts.
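+# For local development, the extras can be installed with, for example (a sketch;
+# the extra names are the ones defined in this section):
+#   poetry install --with test -E extended_testing
+# or via pip, where the extra name is normalized with a hyphen:
+#   pip install "langchain-community[extended-testing]"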
+extended_testing = [
+ "aleph-alpha-client",
+ "aiosqlite",
+ "assemblyai",
+ "beautifulsoup4",
+ "bibtexparser",
+ "cassio",
+ "chardet",
+ "datasets",
+ "google-cloud-documentai",
+ "esprima",
+ "jq",
+ "pdfminer-six",
+ "pgvector",
+ "pypdf",
+ "pymupdf",
+ "pypdfium2",
+ "tqdm",
+ "lxml",
+ "atlassian-python-api",
+ "mwparserfromhell",
+ "mwxml",
+ "msal",
+ "pandas",
+ "telethon",
+ "psychicapi",
+ "gql",
+ "requests-toolbelt",
+ "html2text",
+ "numexpr",
+ "py-trello",
+ "scikit-learn",
+ "streamlit",
+ "pyspark",
+ "openai",
+ "sympy",
+ "rapidfuzz",
+ "jsonschema",
+ "rank-bm25",
+ "geopandas",
+ "jinja2",
+ "gitpython",
+ "newspaper3k",
+ "feedparser",
+ "xata",
+ "xmltodict",
+ "faiss-cpu",
+ "openapi-pydantic",
+ "markdownify",
+ "arxiv",
+ "dashvector",
+ "sqlite-vss",
+ "rapidocr-onnxruntime",
+ "motor",
+ "timescale-vector",
+ "anthropic",
+ "upstash-redis",
+ "rspace_client",
+ "fireworks-ai",
+ "javelin-sdk",
+ "hologres-vector",
+ "praw",
+ "databricks-vectorsearch",
+ "dgml-utils",
+ "cohere",
+]
+
+[tool.ruff]
+select = [
+ "E", # pycodestyle
+ "F", # pyflakes
+ "I", # isort
+]
+exclude = [
+ "tests/examples/non-utf8-encoding.py",
+ "tests/integration_tests/examples/non-utf8-encoding.py",
+]
+
+[tool.mypy]
+ignore_missing_imports = "True"
+disallow_untyped_defs = "True"
+exclude = ["notebooks", "examples", "example_data"]
+
+[tool.coverage.run]
+omit = [
+ "tests/*",
+]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.pytest.ini_options]
+# --strict-markers will raise errors on unknown marks.
+# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
+#
+# https://docs.pytest.org/en/7.1.x/reference/reference.html
+# --strict-config will raise errors for any warnings encountered while parsing
+# the `pytest` section of the configuration file.
+#
+# https://github.com/tophat/syrupy
+# --snapshot-warn-unused prints a warning on unused snapshots rather than failing the test suite.
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
+# Registering custom markers.
+# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
+markers = [
+ "requires: mark tests as requiring a specific library",
+ "scheduled: mark tests to run in scheduled testing",
+ "compile: mark placeholder test used to compile integration tests without running them"
+]
+asyncio_mode = "auto"
+
+[tool.codespell]
+skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples'
+# Ignore Latin text, etc.
+ignore-regex = '.*(Stati Uniti|Tense=Pres).*'
+# whats - a typo, but used frequently in queries, so kept as is
+# aapply - async apply
+# unsecure - a typo, but part of an API, so left alone for now
+ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin'
diff --git a/libs/community/scripts/check_imports.sh b/libs/community/scripts/check_imports.sh
new file mode 100755
index 00000000000..8da63d925a9
--- /dev/null
+++ b/libs/community/scripts/check_imports.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+# Initialize a variable to keep track of errors
+errors=0
+
+# make sure we are not importing from langchain or langchain_experimental
+git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
+git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
+
+# Decide on an exit status based on the errors
+if [ "$errors" -gt 0 ]; then
+ exit 1
+else
+ exit 0
+fi
diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh
new file mode 100755
index 00000000000..06b5bb81ae2
--- /dev/null
+++ b/libs/community/scripts/check_pydantic.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+#
+# This script searches for lines starting with "import pydantic" or "from pydantic"
+# in tracked files within a Git repository.
+#
+# Usage: ./scripts/check_pydantic.sh /path/to/repository
+
+# Check if a path argument is provided
+if [ $# -ne 1 ]; then
+ echo "Usage: $0 /path/to/repository"
+ exit 1
+fi
+
+repository_path="$1"
+
+# Search for lines matching the pattern within the specified repository
+result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
+
+# Check if any matching lines were found
+if [ -n "$result" ]; then
+ echo "ERROR: The following lines need to be updated:"
+ echo "$result"
+ echo "Please replace the code with an import from langchain_core.pydantic_v1."
+ echo "For example, replace 'from pydantic import BaseModel'"
+ echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
+ exit 1
+fi
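+
+# A hypothetical bulk fix, run from the repository root, might look like the
+# following sketch (GNU sed syntax; review the resulting diff before committing):
+#   git grep -l '^from pydantic' \
+#     | xargs sed -i 's/^from pydantic import/from langchain_core.pydantic_v1 import/'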
diff --git a/libs/community/scripts/lint_imports.sh b/libs/community/scripts/lint_imports.sh
new file mode 100755
index 00000000000..695613c7ba8
--- /dev/null
+++ b/libs/community/scripts/lint_imports.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+# Initialize a variable to keep track of errors
+errors=0
+
+# make sure not importing from langchain or langchain_experimental
+git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
+git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
+
+# Decide on an exit status based on the errors
+if [ "$errors" -gt 0 ]; then
+ exit 1
+else
+ exit 0
+fi
diff --git a/libs/langchain/tests/unit_tests/document_loaders/loaders/__init__.py b/libs/community/tests/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/document_loaders/loaders/__init__.py
rename to libs/community/tests/__init__.py
diff --git a/libs/community/tests/examples/README.org b/libs/community/tests/examples/README.org
new file mode 100644
index 00000000000..5b9f4728040
--- /dev/null
+++ b/libs/community/tests/examples/README.org
@@ -0,0 +1,27 @@
+* Example Docs
+
+The sample docs directory contains the following files:
+
+- ~example-10k.html~ - A 10-K SEC filing in HTML format
+- ~layout-parser-paper.pdf~ - A PDF copy of the layout parser paper
+- ~factbook.xml~ / ~factbook.xsl~ - Example XML/XSL files that you
+ can use to test stylesheets
+
+These documents can be used to test out the parsers in the library. In
+addition, here are instructions for pulling in some sample docs that are
+too big to store in the repo.
+
+** XBRL 10-K
+
+You can get an example 10-K in inline XBRL format using the following
+~curl~. Note that you need to set the user agent in the header, or the
+SEC site will reject your request.
+
+#+BEGIN_SRC bash
+
+ curl -O \
+ -A '${organization} ${email}' \
+ https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt
+#+END_SRC
+
+You can parse this document using the HTML parser.
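+
+For example, a minimal sketch using the community HTML loader (this assumes
+~beautifulsoup4~ is installed, and the file name below is whatever you saved
+the filing as):
+
+#+BEGIN_SRC python
+  from langchain_community.document_loaders import BSHTMLLoader
+
+  # Load the downloaded filing and inspect the extracted text
+  docs = BSHTMLLoader("0001171843-21-001344.txt").load()
+  print(docs[0].page_content[:500])
+#+END_SRC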
diff --git a/libs/community/tests/examples/README.rst b/libs/community/tests/examples/README.rst
new file mode 100644
index 00000000000..45630d0385d
--- /dev/null
+++ b/libs/community/tests/examples/README.rst
@@ -0,0 +1,28 @@
+Example Docs
+------------
+
+The sample docs directory contains the following files:
+
+- ``example-10k.html`` - A 10-K SEC filing in HTML format
+- ``layout-parser-paper.pdf`` - A PDF copy of the layout parser paper
+- ``factbook.xml``/``factbook.xsl`` - Example XML/XSL files that you
+ can use to test stylesheets
+
+These documents can be used to test out the parsers in the library. In
+addition, here are instructions for pulling in some sample docs that are
+too big to store in the repo.
+
+XBRL 10-K
+^^^^^^^^^
+
+You can get an example 10-K in inline XBRL format using the following
+``curl``. Note that you need to set the user agent in the header, or the
+SEC site will reject your request.
+
+.. code:: bash
+
+ curl -O \
+ -A '${organization} ${email}' \
+ https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt
+
+You can parse this document using the HTML parser.
diff --git a/libs/community/tests/examples/brandfetch-brandfetch-2.0.0-resolved.json b/libs/community/tests/examples/brandfetch-brandfetch-2.0.0-resolved.json
new file mode 100644
index 00000000000..de37dbf5fba
--- /dev/null
+++ b/libs/community/tests/examples/brandfetch-brandfetch-2.0.0-resolved.json
@@ -0,0 +1,282 @@
+{
+ "openapi": "3.0.1",
+ "info": {
+ "title": "Brandfetch API",
+ "description": "Brandfetch API (v2) for retrieving brand information.\n\nSee our [documentation](https://docs.brandfetch.com/) for further details. ",
+ "termsOfService": "https://brandfetch.com/terms",
+ "contact": {
+ "url": "https://brandfetch.com/developers"
+ },
+ "version": "2.0.0"
+ },
+ "externalDocs": {
+ "description": "Documentation",
+ "url": "https://docs.brandfetch.com/"
+ },
+ "servers": [
+ {
+ "url": "https://api.brandfetch.io/v2"
+ }
+ ],
+ "paths": {
+ "/brands/{domainOrId}": {
+ "get": {
+ "summary": "Retrieve a brand",
+ "description": "Fetch brand information by domain or ID\n\nFurther details here: https://docs.brandfetch.com/reference/retrieve-brand\n",
+ "parameters": [
+ {
+ "name": "domainOrId",
+ "in": "path",
+ "description": "Domain or ID of the brand",
+ "required": true,
+ "style": "simple",
+ "explode": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "Brand data",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Brand"
+ },
+ "examples": {
+ "brandfetch.com": {
+ "value": "{\"name\":\"Brandfetch\",\"domain\":\"brandfetch.com\",\"claimed\":true,\"description\":\"All brands. In one place\",\"links\":[{\"name\":\"twitter\",\"url\":\"https://twitter.com/brandfetch\"},{\"name\":\"linkedin\",\"url\":\"https://linkedin.com/company/brandfetch\"}],\"logos\":[{\"type\":\"logo\",\"theme\":\"light\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/id9WE9j86h.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":15555}]},{\"type\":\"logo\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idWbsK1VCy.png\",\"background\":\"transparent\",\"format\":\"png\",\"height\":215,\"width\":800,\"size\":33937},{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idtCMfbWO0.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"height\":null,\"width\":null,\"size\":15567}]},{\"type\":\"symbol\",\"theme\":\"light\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idXGq6SIu2.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":2215}]},{\"type\":\"symbol\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/iddCQ52AR5.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":2215}]},{\"type\":\"icon\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idls3LaPPQ.png\",\"background\":null,\"format\":\"png\",\"height\":400,\"width\":400,\"size\":2565}]}],\"colors\":[{\"hex\":\"#0084ff\",\"type\":\"accent\",\"brightness\":113},{\"hex\":\"#00193E\",\"type\":\"brand\",\"brightness\":22},{\"hex\":\"#F03063\",\"type\":\"brand\",\"brightness\":93},{\"hex\":\"#7B0095\",\"type\":\"brand\",\"brightness\":37},{\"hex\":\"#76CC4B\",\"type\":\"brand\",\"brightness\":176},{\"hex\":\"#FFDA00\",\"type\":\"brand\",\"brightness\":210},{\"hex\":\"#000000\",\"type\":\"dark\",\"brightness\":0},{\"hex\":\"#ffffff\",\"type\":\"light\",\"brightness\":255}],\"fonts\":[{\"name\":\"Poppins\",\"type\":\"title\",\"origin\":\"google\",\"originId\":\"Poppins\",\"weights\":[]},{\"name\":\"Inter\",\"type\":\"body\",\"origin\":\"google\",\"originId\":\"Inter\",\"weights\":[]}],\"images\":[{\"type\":\"banner\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idUuia5imo.png\",\"background\":\"transparent\",\"format\":\"png\",\"height\":500,\"width\":1500,\"size\":5539}]}]}"
+ }
+ }
+ }
+ }
+ },
+ "400": {
+ "description": "Invalid domain or ID supplied"
+ },
+ "404": {
+ "description": "The brand does not exist or the domain can't be resolved."
+ }
+ },
+ "security": [
+ {
+ "bearerAuth": []
+ }
+ ]
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "Brand": {
+ "required": [
+ "claimed",
+ "colors",
+ "description",
+ "domain",
+ "fonts",
+ "images",
+ "links",
+ "logos",
+ "name"
+ ],
+ "type": "object",
+ "properties": {
+ "images": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageAsset"
+ }
+ },
+ "fonts": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/FontAsset"
+ }
+ },
+ "domain": {
+ "type": "string"
+ },
+ "claimed": {
+ "type": "boolean"
+ },
+ "name": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "links": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Brand_links"
+ }
+ },
+ "logos": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageAsset"
+ }
+ },
+ "colors": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ColorAsset"
+ }
+ }
+ },
+ "description": "Object representing a brand"
+ },
+ "ColorAsset": {
+ "required": [
+ "brightness",
+ "hex",
+ "type"
+ ],
+ "type": "object",
+ "properties": {
+ "brightness": {
+ "type": "integer"
+ },
+ "hex": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "enum": [
+ "accent",
+ "brand",
+ "customizable",
+ "dark",
+ "light",
+ "vibrant"
+ ]
+ }
+ },
+ "description": "Brand color asset"
+ },
+ "FontAsset": {
+ "type": "object",
+ "properties": {
+ "originId": {
+ "type": "string"
+ },
+ "origin": {
+ "type": "string",
+ "enum": [
+ "adobe",
+ "custom",
+ "google",
+ "system"
+ ]
+ },
+ "name": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string"
+ },
+ "weights": {
+ "type": "array",
+ "items": {
+ "type": "number"
+ }
+ },
+ "items": {
+ "type": "string"
+ }
+ },
+ "description": "Brand font asset"
+ },
+ "ImageAsset": {
+ "required": [
+ "formats",
+ "theme",
+ "type"
+ ],
+ "type": "object",
+ "properties": {
+ "formats": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageFormat"
+ }
+ },
+ "theme": {
+ "type": "string",
+ "enum": [
+ "light",
+ "dark"
+ ]
+ },
+ "type": {
+ "type": "string",
+ "enum": [
+ "logo",
+ "icon",
+ "symbol",
+ "banner"
+ ]
+ }
+ },
+ "description": "Brand image asset"
+ },
+ "ImageFormat": {
+ "required": [
+ "background",
+ "format",
+ "size",
+ "src"
+ ],
+ "type": "object",
+ "properties": {
+ "size": {
+ "type": "integer"
+ },
+ "src": {
+ "type": "string"
+ },
+ "background": {
+ "type": "string",
+ "enum": [
+ "transparent"
+ ]
+ },
+ "format": {
+ "type": "string"
+ },
+ "width": {
+ "type": "integer"
+ },
+ "height": {
+ "type": "integer"
+ }
+ },
+ "description": "Brand image asset image format"
+ },
+ "Brand_links": {
+ "required": [
+ "name",
+ "url"
+ ],
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "url": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "securitySchemes": {
+ "bearerAuth": {
+ "type": "http",
+ "scheme": "bearer",
+ "bearerFormat": "API Key"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/community/tests/examples/default-encoding.py b/libs/community/tests/examples/default-encoding.py
new file mode 100644
index 00000000000..9a09cc8271f
--- /dev/null
+++ b/libs/community/tests/examples/default-encoding.py
@@ -0,0 +1 @@
+u = "🦜🔗"
diff --git a/libs/community/tests/examples/docusaurus-sitemap.xml b/libs/community/tests/examples/docusaurus-sitemap.xml
new file mode 100644
index 00000000000..eebae785b88
--- /dev/null
+++ b/libs/community/tests/examples/docusaurus-sitemap.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url>
+    <loc>https://python.langchain.com/docs/integrations/document_loaders/sitemap</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/cookbook</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/additional_resources</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/modules/chains/how_to/</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/use_cases/question_answering/local_retrieval_qa</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/use_cases/summarization</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+</urlset>
\ No newline at end of file
diff --git a/libs/community/tests/examples/duplicate-chars.pdf b/libs/community/tests/examples/duplicate-chars.pdf
new file mode 100644
index 00000000000..47467cd035d
Binary files /dev/null and b/libs/community/tests/examples/duplicate-chars.pdf differ
diff --git a/libs/community/tests/examples/example-utf8.html b/libs/community/tests/examples/example-utf8.html
new file mode 100644
index 00000000000..f96e20fcedb
--- /dev/null
+++ b/libs/community/tests/examples/example-utf8.html
@@ -0,0 +1,25 @@
+
+
+ Chew dad's slippers
+
+
+
+ Instead of drinking water from the cat bowl, make sure to steal water from
+ the toilet
+
+
Chase the red dot
+
+ Munch, munch, chomp, chomp hate dogs. Spill litter box, scratch at owner,
+ destroy all furniture, especially couch get scared by sudden appearance of
+ cucumber cat is love, cat is life fat baby cat best buddy little guy for
+ catch eat throw up catch eat throw up bad birds jump on fridge. Purr like
+ a car engine oh yes, there is my human woman she does best pats ever that
+ all i like about her hiss meow .
+
+
+ Dead stare with ears cocked when “owners” are asleep, cry for no apparent
+ reason meow all night. Plop down in the middle where everybody walks favor
+ packaging over toy. Sit on the laptop kitty pounce, trip, faceplant.
+
+ Instead of drinking water from the cat bowl, make sure to steal water from
+ the toilet
+
+
Chase the red dot
+
+ Munch, munch, chomp, chomp hate dogs. Spill litter box, scratch at owner,
+ destroy all furniture, especially couch get scared by sudden appearance of
+ cucumber cat is love, cat is life fat baby cat best buddy little guy for
+ catch eat throw up catch eat throw up bad birds jump on fridge. Purr like
+ a car engine oh yes, there is my human woman she does best pats ever that
+ all i like about her hiss meow .
+
+
+ Dead stare with ears cocked when owners are asleep, cry for no apparent
+ reason meow all night. Plop down in the middle where everybody walks favor
+ packaging over toy. Sit on the laptop kitty pounce, trip, faceplant.
+
If you have any comments about our WEB page, you can=20
+write us at the address shown above. However, due to=20
+the limited number of personnel in our corporate office, we are unable to=
+=20
+provide a direct response.
+
+
Copyright =C2=A9 2023-2023 LangChain =
+Inc.=20
+
+
+
+------MultipartBoundary--dYaUgeoeP18TqraaeOwkeZyu1vI09OtkFwH2rcnJMt------
diff --git a/libs/community/tests/examples/facebook_chat.json b/libs/community/tests/examples/facebook_chat.json
new file mode 100644
index 00000000000..68c9c0c2344
--- /dev/null
+++ b/libs/community/tests/examples/facebook_chat.json
@@ -0,0 +1,64 @@
+{
+ "participants": [{"name": "User 1"}, {"name": "User 2"}],
+ "messages": [
+ {"sender_name": "User 2", "timestamp_ms": 1675597571851, "content": "Bye!"},
+ {
+ "sender_name": "User 1",
+ "timestamp_ms": 1675597435669,
+ "content": "Oh no worries! Bye"
+ },
+ {
+ "sender_name": "User 2",
+ "timestamp_ms": 1675596277579,
+ "content": "No Im sorry it was my mistake, the blue one is not for sale"
+ },
+ {
+ "sender_name": "User 1",
+ "timestamp_ms": 1675595140251,
+ "content": "I thought you were selling the blue one!"
+ },
+ {
+ "sender_name": "User 1",
+ "timestamp_ms": 1675595109305,
+ "content": "Im not interested in this bag. Im interested in the blue one!"
+ },
+ {
+ "sender_name": "User 2",
+ "timestamp_ms": 1675595068468,
+ "content": "Here is $129"
+ },
+ {
+ "sender_name": "User 2",
+ "timestamp_ms": 1675595060730,
+ "photos": [
+ {"uri": "url_of_some_picture.jpg", "creation_timestamp": 1675595059}
+ ]
+ },
+ {
+ "sender_name": "User 2",
+ "timestamp_ms": 1675595045152,
+ "content": "Online is at least $100"
+ },
+ {
+ "sender_name": "User 1",
+ "timestamp_ms": 1675594799696,
+ "content": "How much do you want?"
+ },
+ {
+ "sender_name": "User 2",
+ "timestamp_ms": 1675577876645,
+ "content": "Goodmorning! $50 is too low."
+ },
+ {
+ "sender_name": "User 1",
+ "timestamp_ms": 1675549022673,
+ "content": "Hi! Im interested in your bag. Im offering $50. Let me know if you are interested. Thanks!"
+ }
+ ],
+ "title": "User 1 and User 2 chat",
+ "is_still_participant": true,
+ "thread_path": "inbox/User 1 and User 2 chat",
+ "magic_words": [],
+ "image": {"uri": "image_of_the_chat.jpg", "creation_timestamp": 1675549016},
+ "joinable_mode": {"mode": 1, "link": ""}
+}
diff --git a/libs/community/tests/examples/factbook.xml b/libs/community/tests/examples/factbook.xml
new file mode 100644
index 00000000000..d059ee9d0c5
--- /dev/null
+++ b/libs/community/tests/examples/factbook.xml
@@ -0,0 +1,27 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<factbook>
+  <country>
+    <name>United States</name>
+    <capital>Washington, DC</capital>
+    <leader>Joe Biden</leader>
+    <sport>Baseball</sport>
+  </country>
+  <country>
+    <name>Canada</name>
+    <capital>Ottawa</capital>
+    <leader>Justin Trudeau</leader>
+    <sport>Hockey</sport>
+  </country>
+  <country>
+    <name>France</name>
+    <capital>Paris</capital>
+    <leader>Emmanuel Macron</leader>
+    <sport>Soccer</sport>
+  </country>
+  <country>
+    <name>Trinidad &amp; Tobado</name>
+    <capital>Port of Spain</capital>
+    <leader>Keith Rowley</leader>
+    <sport>Track &amp; Field</sport>
+  </country>
+</factbook>
diff --git a/libs/community/tests/examples/fake-email-attachment.eml b/libs/community/tests/examples/fake-email-attachment.eml
new file mode 100644
index 00000000000..5d8b0367247
--- /dev/null
+++ b/libs/community/tests/examples/fake-email-attachment.eml
@@ -0,0 +1,50 @@
+MIME-Version: 1.0
+Date: Fri, 23 Dec 2022 12:08:48 -0600
+Message-ID:
+Subject: Fake email with attachment
+From: Mallori Harrell
+To: Mallori Harrell
+Content-Type: multipart/mixed; boundary="0000000000005d654405f082adb7"
+
+--0000000000005d654405f082adb7
+Content-Type: multipart/alternative; boundary="0000000000005d654205f082adb5"
+
+--0000000000005d654205f082adb5
+Content-Type: text/plain; charset="UTF-8"
+
+Hello!
+
+Here's the attachments!
+
+It includes:
+
+ - Lots of whitespace
+ - Little to no content
+ - and is a quick read
+
+Best,
+
+Mallori
+
+--0000000000005d654205f082adb5
+Content-Type: text/html; charset="UTF-8"
+Content-Transfer-Encoding: quoted-printable
+
+
Hello!=C2=A0
Here's the attachments=
+!
It includes:
Lots of whitespace
Little=C2=
+=A0to no content
and is a quick read
Best,
Mallori
+
+--0000000000005d654205f082adb5--
+--0000000000005d654405f082adb7
+Content-Type: text/plain; charset="US-ASCII"; name="fake-attachment.txt"
+Content-Disposition: attachment; filename="fake-attachment.txt"
+Content-Transfer-Encoding: base64
+X-Attachment-Id: f_lc0tto5j0
+Content-ID:
+
+SGV5IHRoaXMgaXMgYSBmYWtlIGF0dGFjaG1lbnQh
+--0000000000005d654405f082adb7--
\ No newline at end of file
diff --git a/libs/community/tests/examples/fake.odt b/libs/community/tests/examples/fake.odt
new file mode 100644
index 00000000000..90504997238
Binary files /dev/null and b/libs/community/tests/examples/fake.odt differ
diff --git a/libs/community/tests/examples/hello.msg b/libs/community/tests/examples/hello.msg
new file mode 100644
index 00000000000..0dac0e86a9c
Binary files /dev/null and b/libs/community/tests/examples/hello.msg differ
diff --git a/libs/community/tests/examples/hello.pdf b/libs/community/tests/examples/hello.pdf
new file mode 100644
index 00000000000..4eb6f2ac534
Binary files /dev/null and b/libs/community/tests/examples/hello.pdf differ
diff --git a/libs/community/tests/examples/hello_world.js b/libs/community/tests/examples/hello_world.js
new file mode 100644
index 00000000000..1d41c876c8d
--- /dev/null
+++ b/libs/community/tests/examples/hello_world.js
@@ -0,0 +1,12 @@
+class HelloWorld {
+ sayHello() {
+ console.log("Hello World!");
+ }
+}
+
+function main() {
+ const hello = new HelloWorld();
+ hello.sayHello();
+}
+
+main();
diff --git a/libs/community/tests/examples/hello_world.py b/libs/community/tests/examples/hello_world.py
new file mode 100644
index 00000000000..3f0294febb4
--- /dev/null
+++ b/libs/community/tests/examples/hello_world.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+
+import sys
+
+
+def main() -> int:
+ print("Hello World!")
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/libs/langchain/tests/unit_tests/document_loaders/sample_documents/layout-parser-paper.pdf b/libs/community/tests/examples/layout-parser-paper.pdf
similarity index 100%
rename from libs/langchain/tests/unit_tests/document_loaders/sample_documents/layout-parser-paper.pdf
rename to libs/community/tests/examples/layout-parser-paper.pdf
diff --git a/libs/community/tests/examples/multi-page-forms-sample-2-page.pdf b/libs/community/tests/examples/multi-page-forms-sample-2-page.pdf
new file mode 100644
index 00000000000..de6ddd0f7e8
Binary files /dev/null and b/libs/community/tests/examples/multi-page-forms-sample-2-page.pdf differ
diff --git a/libs/community/tests/examples/non-utf8-encoding.py b/libs/community/tests/examples/non-utf8-encoding.py
new file mode 100644
index 00000000000..e00f46c5258
--- /dev/null
+++ b/libs/community/tests/examples/non-utf8-encoding.py
@@ -0,0 +1,3 @@
+# coding: iso-8859-5
+# ±Άΰαβγδεζηθικλμνξο <- Cyrillic characters
+u = "βπΔ"
diff --git a/libs/community/tests/examples/sample_rss_feeds.opml b/libs/community/tests/examples/sample_rss_feeds.opml
new file mode 100644
index 00000000000..290b2c5db2e
--- /dev/null
+++ b/libs/community/tests/examples/sample_rss_feeds.opml
@@ -0,0 +1,13 @@
+
+
+
+
+ Sample RSS feed subscriptions
+
+
+
+
+
+
+
+
diff --git a/libs/community/tests/examples/sitemap.xml b/libs/community/tests/examples/sitemap.xml
new file mode 100644
index 00000000000..1629211233e
--- /dev/null
+++ b/libs/community/tests/examples/sitemap.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url>
+    <loc>
+      https://python.langchain.com/en/stable/
+    </loc>
+    <lastmod>
+      2023-05-04T16:15:31.377584+00:00
+    </lastmod>
+    <changefreq>weekly</changefreq>
+    <priority>1</priority>
+  </url>
+  <url>
+    <loc>
+      https://python.langchain.com/en/latest/
+    </loc>
+    <lastmod>
+      2023-05-05T07:52:19.633878+00:00
+    </lastmod>
+    <changefreq>daily</changefreq>
+    <priority>0.9</priority>
+  </url>
+  <url>
+    <loc>
+      https://python.langchain.com/en/harrison-docs-refactor-3-24/
+    </loc>
+    <lastmod>
+      2023-03-27T02:32:55.132916+00:00
+    </lastmod>
+    <changefreq>monthly</changefreq>
+    <priority>0.8</priority>
+  </url>
+</urlset>
\ No newline at end of file
diff --git a/libs/community/tests/examples/slack_export.zip b/libs/community/tests/examples/slack_export.zip
new file mode 100644
index 00000000000..756809ad719
Binary files /dev/null and b/libs/community/tests/examples/slack_export.zip differ
diff --git a/libs/community/tests/examples/stanley-cups.csv b/libs/community/tests/examples/stanley-cups.csv
new file mode 100644
index 00000000000..482a10ddfd1
--- /dev/null
+++ b/libs/community/tests/examples/stanley-cups.csv
@@ -0,0 +1,5 @@
+Stanley Cups,,
+Team,Location,Stanley Cups
+Blues,STL,1
+Flyers,PHI,2
+Maple Leafs,TOR,13
\ No newline at end of file
diff --git a/libs/community/tests/examples/stanley-cups.tsv b/libs/community/tests/examples/stanley-cups.tsv
new file mode 100644
index 00000000000..314be466da6
--- /dev/null
+++ b/libs/community/tests/examples/stanley-cups.tsv
@@ -0,0 +1,5 @@
+Stanley Cups
+Team Location Stanley Cups
+Blues STL 1
+Flyers PHI 2
+Maple Leafs TOR 13
diff --git a/libs/community/tests/examples/stanley-cups.xlsx b/libs/community/tests/examples/stanley-cups.xlsx
new file mode 100644
index 00000000000..ebc66599b2a
Binary files /dev/null and b/libs/community/tests/examples/stanley-cups.xlsx differ
diff --git a/libs/community/tests/examples/whatsapp_chat.txt b/libs/community/tests/examples/whatsapp_chat.txt
new file mode 100644
index 00000000000..605af130f2f
--- /dev/null
+++ b/libs/community/tests/examples/whatsapp_chat.txt
@@ -0,0 +1,10 @@
+[05.05.23, 15:48:11] James: Hi here
+[11/8/21, 9:41:32 AM] User name: Message 123
+1/23/23, 3:19 AM - User 2: Bye!
+1/23/23, 3:22_AM - User 1: And let me know if anything changes
+[1/24/21, 12:41:03 PM] ~ User name 2: Of course!
+[2023/5/4, 16:13:23] ~ User 2: See you!
+7/19/22, 11:32 PM - User 1: Hello
+7/20/22, 11:32 am - User 2: Goodbye
+4/20/23, 9:42 am - User 3:
+6/29/23, 12:16 am - User 4: This message was deleted
diff --git a/libs/langchain/tests/unit_tests/document_loaders/loaders/vendors/__init__.py b/libs/community/tests/integration_tests/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/document_loaders/loaders/vendors/__init__.py
rename to libs/community/tests/integration_tests/__init__.py
diff --git a/libs/langchain/tests/unit_tests/document_loaders/parsers/language/__init__.py b/libs/community/tests/integration_tests/adapters/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/document_loaders/parsers/language/__init__.py
rename to libs/community/tests/integration_tests/adapters/__init__.py
diff --git a/libs/langchain/tests/integration_tests/adapters/test_openai.py b/libs/community/tests/integration_tests/adapters/test_openai.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/adapters/test_openai.py
rename to libs/community/tests/integration_tests/adapters/test_openai.py
index b9279e8952e..3f8f6c01c8d 100644
--- a/libs/langchain/tests/integration_tests/adapters/test_openai.py
+++ b/libs/community/tests/integration_tests/adapters/test_openai.py
@@ -1,6 +1,6 @@
from typing import Any
-from langchain.adapters import openai as lcopenai
+from langchain_community.adapters import openai as lcopenai
def _test_no_stream(**kwargs: Any) -> None:
diff --git a/libs/langchain/tests/unit_tests/document_loaders/sample_documents/__init__.py b/libs/community/tests/integration_tests/callbacks/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/document_loaders/sample_documents/__init__.py
rename to libs/community/tests/integration_tests/callbacks/__init__.py
diff --git a/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py
new file mode 100644
index 00000000000..d4941b6a377
--- /dev/null
+++ b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py
@@ -0,0 +1,297 @@
+"""Integration tests for the langchain tracer module."""
+import asyncio
+import os
+
+from aiohttp import ClientSession
+from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group
+from langchain_core.prompts import PromptTemplate
+from langchain_core.tracers.context import tracing_enabled, tracing_v2_enabled
+
+from langchain_community.chat_models import ChatOpenAI
+from langchain_community.llms import OpenAI
+
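+# These tests run live agents, so OPENAI_API_KEY and SERPAPI_API_KEY must be set.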
+questions = [
+ (
+ "Who won the US Open men's final in 2019? "
+ "What is his age raised to the 0.334 power?"
+ ),
+ (
+ "Who is Olivia Wilde's boyfriend? "
+ "What is his current age raised to the 0.23 power?"
+ ),
+ (
+ "Who won the most recent formula 1 grand prix? "
+ "What is their age raised to the 0.23 power?"
+ ),
+ (
+ "Who won the US Open women's final in 2019? "
+ "What is her age raised to the 0.34 power?"
+ ),
+ ("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
+]
+
+
+def test_tracing_sequential() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING"] = "true"
+
+ for q in questions[:3]:
+ llm = OpenAI(temperature=0)
+ tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(q)
+
+
+def test_tracing_session_env_var() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING"] = "true"
+ os.environ["LANGCHAIN_SESSION"] = "my_session"
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(questions[0])
+ if "LANGCHAIN_SESSION" in os.environ:
+ del os.environ["LANGCHAIN_SESSION"]
+
+
+async def test_tracing_concurrent() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING"] = "true"
+ aiosession = ClientSession()
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ tasks = [agent.arun(q) for q in questions[:3]]
+ await asyncio.gather(*tasks)
+ await aiosession.close()
+
+
+async def test_tracing_concurrent_bw_compat_environ() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_HANDLER"] = "langchain"
+ if "LANGCHAIN_TRACING" in os.environ:
+ del os.environ["LANGCHAIN_TRACING"]
+ aiosession = ClientSession()
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ tasks = [agent.arun(q) for q in questions[:3]]
+ await asyncio.gather(*tasks)
+ await aiosession.close()
+ if "LANGCHAIN_HANDLER" in os.environ:
+ del os.environ["LANGCHAIN_HANDLER"]
+
+
+def test_tracing_context_manager() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ if "LANGCHAIN_TRACING" in os.environ:
+ del os.environ["LANGCHAIN_TRACING"]
+ with tracing_enabled() as session:
+ assert session
+ agent.run(questions[0]) # this should be traced
+
+ agent.run(questions[0]) # this should not be traced
+
+
+async def test_tracing_context_manager_async() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ if "LANGCHAIN_TRACING" in os.environ:
+ del os.environ["LANGCHAIN_TRACING"]
+
+ # start a background task
+ task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
+ with tracing_enabled() as session:
+ assert session
+ tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
+ await asyncio.gather(*tasks)
+
+ await task
+
+
+async def test_tracing_v2_environment_variable() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+
+ aiosession = ClientSession()
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ tasks = [agent.arun(q) for q in questions[:3]]
+ await asyncio.gather(*tasks)
+ await aiosession.close()
+
+
+def test_tracing_v2_context_manager() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = ChatOpenAI(temperature=0)
+ tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ if "LANGCHAIN_TRACING_V2" in os.environ:
+ del os.environ["LANGCHAIN_TRACING_V2"]
+ with tracing_v2_enabled():
+ agent.run(questions[0]) # this should be traced
+
+ agent.run(questions[0]) # this should not be traced
+
+
+def test_tracing_v2_chain_with_tags() -> None:
+ from langchain.chains.constitutional_ai.base import ConstitutionalChain
+ from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
+ from langchain.chains.llm import LLMChain
+
+ llm = OpenAI(temperature=0)
+ chain = ConstitutionalChain.from_llm(
+ llm,
+ chain=LLMChain.from_string(llm, "Q: {question} A:"),
+ tags=["only-root"],
+ constitutional_principles=[
+ ConstitutionalPrinciple(
+ critique_request="Tell if this answer is good.",
+ revision_request="Give a better answer.",
+ )
+ ],
+ )
+ if "LANGCHAIN_TRACING_V2" in os.environ:
+ del os.environ["LANGCHAIN_TRACING_V2"]
+ with tracing_v2_enabled():
+ chain.run("what is the meaning of life", tags=["a-tag"])
+
+
+def test_tracing_v2_agent_with_metadata() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+ llm = OpenAI(temperature=0)
+ chat = ChatOpenAI(temperature=0)
+ tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ chat_agent = initialize_agent(
+ tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
+ chat_agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
+
+
+async def test_tracing_v2_async_agent_with_metadata() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+ llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"})
+ chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"})
+ async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ chat_agent = initialize_agent(
+ async_tools,
+ chat,
+ agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+ verbose=True,
+ )
+ await agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
+ await chat_agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
+
+
+def test_trace_as_group() -> None:
+ from langchain.chains.llm import LLMChain
+
+ llm = OpenAI(temperature=0.9)
+ prompt = PromptTemplate(
+ input_variables=["product"],
+ template="What is a good name for a company that makes {product}?",
+ )
+ chain = LLMChain(llm=llm, prompt=prompt)
+ with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager:
+ chain.run(product="cars", callbacks=group_manager)
+ chain.run(product="computers", callbacks=group_manager)
+ final_res = chain.run(product="toys", callbacks=group_manager)
+ group_manager.on_chain_end({"output": final_res})
+
+ with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager:
+ final_res = chain.run(product="toys", callbacks=group_manager)
+ group_manager.on_chain_end({"output": final_res})
+
+
+def test_trace_as_group_with_env_set() -> None:
+ from langchain.chains.llm import LLMChain
+
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+ llm = OpenAI(temperature=0.9)
+ prompt = PromptTemplate(
+ input_variables=["product"],
+ template="What is a good name for a company that makes {product}?",
+ )
+ chain = LLMChain(llm=llm, prompt=prompt)
+ with trace_as_chain_group(
+ "my_group_env_set", inputs={"input": "cars"}
+ ) as group_manager:
+ chain.run(product="cars", callbacks=group_manager)
+ chain.run(product="computers", callbacks=group_manager)
+ final_res = chain.run(product="toys", callbacks=group_manager)
+ group_manager.on_chain_end({"output": final_res})
+
+ with trace_as_chain_group(
+ "my_group_2_env_set", inputs={"input": "toys"}
+ ) as group_manager:
+ final_res = chain.run(product="toys", callbacks=group_manager)
+ group_manager.on_chain_end({"output": final_res})
+
+
+async def test_trace_as_group_async() -> None:
+ from langchain.chains.llm import LLMChain
+
+ llm = OpenAI(temperature=0.9)
+ prompt = PromptTemplate(
+ input_variables=["product"],
+ template="What is a good name for a company that makes {product}?",
+ )
+ chain = LLMChain(llm=llm, prompt=prompt)
+ async with atrace_as_chain_group("my_async_group") as group_manager:
+ await chain.arun(product="cars", callbacks=group_manager)
+ await chain.arun(product="computers", callbacks=group_manager)
+ await chain.arun(product="toys", callbacks=group_manager)
+
+ async with atrace_as_chain_group(
+ "my_async_group_2", inputs={"input": "toys"}
+ ) as group_manager:
+ res = await asyncio.gather(
+ *[
+ chain.arun(product="toys", callbacks=group_manager),
+ chain.arun(product="computers", callbacks=group_manager),
+ chain.arun(product="cars", callbacks=group_manager),
+ ]
+ )
+ await group_manager.on_chain_end({"output": res})
diff --git a/libs/community/tests/integration_tests/callbacks/test_openai_callback.py b/libs/community/tests/integration_tests/callbacks/test_openai_callback.py
new file mode 100644
index 00000000000..5112f4dd84e
--- /dev/null
+++ b/libs/community/tests/integration_tests/callbacks/test_openai_callback.py
@@ -0,0 +1,68 @@
+"""Integration tests for the langchain tracer module."""
+import asyncio
+
+from langchain_community.callbacks import get_openai_callback
+from langchain_community.llms import OpenAI
+
+
+async def test_openai_callback() -> None:
+ llm = OpenAI(temperature=0)
+ with get_openai_callback() as cb:
+ llm("What is the square root of 4?")
+
+ total_tokens = cb.total_tokens
+ assert total_tokens > 0
+
+ with get_openai_callback() as cb:
+ llm("What is the square root of 4?")
+ llm("What is the square root of 4?")
+
+ assert cb.total_tokens == total_tokens * 2
+
+ with get_openai_callback() as cb:
+ await asyncio.gather(
+ *[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)]
+ )
+
+ assert cb.total_tokens == total_tokens * 3
+
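+ # A generation started outside the callback context should not add to this callback's count.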
+ task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"]))
+ with get_openai_callback() as cb:
+ await llm.agenerate(["What is the square root of 4?"])
+
+ await task
+ assert cb.total_tokens == total_tokens
+
+
+def test_openai_callback_batch_llm() -> None:
+ llm = OpenAI(temperature=0)
+ with get_openai_callback() as cb:
+ llm.generate(["What is the square root of 4?", "What is the square root of 4?"])
+
+ assert cb.total_tokens > 0
+ total_tokens = cb.total_tokens
+
+ with get_openai_callback() as cb:
+ llm("What is the square root of 4?")
+ llm("What is the square root of 4?")
+
+ assert cb.total_tokens == total_tokens
+
+
+def test_openai_callback_agent() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(["serpapi", "llm-math"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ with get_openai_callback() as cb:
+ agent.run(
+ "Who is Olivia Wilde's boyfriend? "
+ "What is his current age raised to the 0.23 power?"
+ )
+ print(f"Total Tokens: {cb.total_tokens}")
+ print(f"Prompt Tokens: {cb.prompt_tokens}")
+ print(f"Completion Tokens: {cb.completion_tokens}")
+ print(f"Total Cost (USD): ${cb.total_cost}")
diff --git a/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py b/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py
new file mode 100644
index 00000000000..13777d0b9fa
--- /dev/null
+++ b/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py
@@ -0,0 +1,30 @@
+"""Integration tests for the StreamlitCallbackHandler module."""
+
+import pytest
+
+# Import the internal StreamlitCallbackHandler from its module - and not from
+# the `langchain_community.callbacks.streamlit` package - so that we don't end up using
+# Streamlit's externally-provided callback handler.
+from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
+ StreamlitCallbackHandler,
+)
+from langchain_community.llms import OpenAI
+
+
+@pytest.mark.requires("streamlit")
+def test_streamlit_callback_agent() -> None:
+ import streamlit as st
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ streamlit_callback = StreamlitCallbackHandler(st.container())
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(["serpapi", "llm-math"], llm=llm)
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(
+ "Who is Olivia Wilde's boyfriend? "
+ "What is his current age raised to the 0.23 power?",
+ callbacks=[streamlit_callback],
+ )
diff --git a/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py b/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py
new file mode 100644
index 00000000000..7553d3198fe
--- /dev/null
+++ b/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py
@@ -0,0 +1,123 @@
+"""Integration tests for the langchain tracer module."""
+import asyncio
+import os
+
+from aiohttp import ClientSession
+
+from langchain_community.callbacks import wandb_tracing_enabled
+from langchain_community.llms import OpenAI
+
+questions = [
+ (
+ "Who won the US Open men's final in 2019? "
+ "What is his age raised to the 0.334 power?"
+ ),
+ (
+ "Who is Olivia Wilde's boyfriend? "
+ "What is his current age raised to the 0.23 power?"
+ ),
+ (
+ "Who won the most recent formula 1 grand prix? "
+ "What is their age raised to the 0.23 power?"
+ ),
+ (
+ "Who won the US Open women's final in 2019? "
+ "What is her age raised to the 0.34 power?"
+ ),
+ ("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
+]
+
+
+def test_tracing_sequential() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
+ os.environ["WANDB_PROJECT"] = "langchain-tracing"
+
+ for q in questions[:3]:
+ llm = OpenAI(temperature=0)
+ tools = load_tools(
+ ["llm-math", "serpapi"],
+ llm=llm,
+ )
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(q)
+
+
+def test_tracing_session_env_var() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(
+ ["llm-math", "serpapi"],
+ llm=llm,
+ )
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ agent.run(questions[0])
+
+
+async def test_tracing_concurrent() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
+ aiosession = ClientSession()
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(
+ ["llm-math", "serpapi"],
+ llm=llm,
+ aiosession=aiosession,
+ )
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ tasks = [agent.arun(q) for q in questions[:3]]
+ await asyncio.gather(*tasks)
+ await aiosession.close()
+
+
+def test_tracing_context_manager() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = OpenAI(temperature=0)
+ tools = load_tools(
+ ["llm-math", "serpapi"],
+ llm=llm,
+ )
+ agent = initialize_agent(
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ if "LANGCHAIN_WANDB_TRACING" in os.environ:
+ del os.environ["LANGCHAIN_WANDB_TRACING"]
+ with wandb_tracing_enabled():
+ agent.run(questions[0]) # this should be traced
+
+ agent.run(questions[0]) # this should not be traced
+
+
+async def test_tracing_context_manager_async() -> None:
+ from langchain.agents import AgentType, initialize_agent, load_tools
+
+ llm = OpenAI(temperature=0)
+ async_tools = load_tools(
+ ["llm-math", "serpapi"],
+ llm=llm,
+ )
+ agent = initialize_agent(
+ async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+ )
+ if "LANGCHAIN_WANDB_TRACING" in os.environ:
+ del os.environ["LANGCHAIN_TRACING"]
+
+ # start a background task
+ task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
+ with wandb_tracing_enabled():
+ tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
+ await asyncio.gather(*tasks)
+
+ await task
diff --git a/libs/langchain/tests/unit_tests/tools/eden_ai/__init__.py b/libs/community/tests/integration_tests/chat_message_histories/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/tools/eden_ai/__init__.py
rename to libs/community/tests/integration_tests/chat_message_histories/__init__.py
diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py b/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py
rename to libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py
index 84a29bfc013..16a304bf381 100644
--- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py
+++ b/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py
@@ -5,7 +5,7 @@ test_script = """
import json
import streamlit as st
from langchain.memory import ConversationBufferMemory
- from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.messages import message_to_dict
message_history = StreamlitChatMessageHistory()
diff --git a/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py b/libs/community/tests/integration_tests/chat_message_histories/test_zep.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py
rename to libs/community/tests/integration_tests/chat_message_histories/test_zep.py
index fa8e47b1565..9391f84cc8f 100644
--- a/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py
+++ b/libs/community/tests/integration_tests/chat_message_histories/test_zep.py
@@ -4,7 +4,7 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from pytest_mock import MockerFixture
-from langchain.memory.chat_message_histories import ZepChatMessageHistory
+from langchain_community.chat_message_histories import ZepChatMessageHistory
if TYPE_CHECKING:
from zep_python import ZepClient
diff --git a/libs/langchain/tests/unit_tests/tools/file_management/__init__.py b/libs/community/tests/integration_tests/chat_models/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/tools/file_management/__init__.py
rename to libs/community/tests/integration_tests/chat_models/__init__.py
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py b/libs/community/tests/integration_tests/chat_models/test_anthropic.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
rename to libs/community/tests/integration_tests/chat_models/test_anthropic.py
index cce21e39238..a1ddc14ead1 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
+++ b/libs/community/tests/integration_tests/chat_models/test_anthropic.py
@@ -2,11 +2,11 @@
from typing import List
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.anthropic import (
+from langchain_community.chat_models.anthropic import (
ChatAnthropic,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py b/libs/community/tests/integration_tests/chat_models/test_azure_openai.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py
rename to libs/community/tests/integration_tests/chat_models/test_azure_openai.py
index 9a69008f656..23812775c0e 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py
+++ b/libs/community/tests/integration_tests/chat_models/test_azure_openai.py
@@ -3,11 +3,11 @@ import os
from typing import Any
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models import AzureChatOpenAI
+from langchain_community.chat_models import AzureChatOpenAI
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py
rename to libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
index 929074866bd..d8871e784f0 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py
+++ b/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
@@ -3,7 +3,7 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.chat_models.azureml_endpoint import (
+from langchain_community.chat_models.azureml_endpoint import (
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py b/libs/community/tests/integration_tests/chat_models/test_baichuan.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/chat_models/test_baichuan.py
rename to libs/community/tests/integration_tests/chat_models/test_baichuan.py
index d4689641155..0ad3ab74799 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py
+++ b/libs/community/tests/integration_tests/chat_models/test_baichuan.py
@@ -1,6 +1,6 @@
from langchain_core.messages import AIMessage, HumanMessage
-from langchain.chat_models.baichuan import ChatBaichuan
+from langchain_community.chat_models.baichuan import ChatBaichuan
def test_chat_baichuan() -> None:
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py
rename to libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py
index e8a4dfae62e..c0596937618 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py
+++ b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py
@@ -3,7 +3,7 @@ from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
-from langchain.chat_models.baidu_qianfan_endpoint import (
+from langchain_community.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/chat_models/test_bedrock.py
rename to libs/community/tests/integration_tests/chat_models/test_bedrock.py
index 1750f53f730..301260803d9 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
@@ -2,11 +2,11 @@
from typing import Any
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models import BedrockChat
+from langchain_community.chat_models import BedrockChat
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py b/libs/community/tests/integration_tests/chat_models/test_ernie.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/chat_models/test_ernie.py
rename to libs/community/tests/integration_tests/chat_models/test_ernie.py
index 6db6321d1cf..4d472fdd9ef 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py
+++ b/libs/community/tests/integration_tests/chat_models/test_ernie.py
@@ -1,7 +1,7 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage
-from langchain.chat_models.ernie import ErnieBotChat
+from langchain_community.chat_models.ernie import ErnieBotChat
def test_chat_ernie_bot() -> None:
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/community/tests/integration_tests/chat_models/test_fireworks.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
rename to libs/community/tests/integration_tests/chat_models/test_fireworks.py
index fc25e7febcc..9ab943be873 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/community/tests/integration_tests/chat_models/test_fireworks.py
@@ -6,7 +6,7 @@ import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain.chat_models.fireworks import ChatFireworks
+from langchain_community.chat_models.fireworks import ChatFireworks
if sys.version_info < (3, 9):
pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py b/libs/community/tests/integration_tests/chat_models/test_google_palm.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/chat_models/test_google_palm.py
rename to libs/community/tests/integration_tests/chat_models/test_google_palm.py
index 3e2ae6eb046..6bb0b4db474 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py
+++ b/libs/community/tests/integration_tests/chat_models/test_google_palm.py
@@ -7,7 +7,7 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain.chat_models import ChatGooglePalm
+from langchain_community.chat_models import ChatGooglePalm
def test_chat_google_palm() -> None:
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py b/libs/community/tests/integration_tests/chat_models/test_hunyuan.py
similarity index 91%
rename from libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py
rename to libs/community/tests/integration_tests/chat_models/test_hunyuan.py
index 47b60864acf..0901cc4cf6f 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py
+++ b/libs/community/tests/integration_tests/chat_models/test_hunyuan.py
@@ -1,6 +1,6 @@
from langchain_core.messages import AIMessage, HumanMessage
-from langchain.chat_models.hunyuan import ChatHunyuan
+from langchain_community.chat_models.hunyuan import ChatHunyuan
def test_chat_hunyuan() -> None:
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py b/libs/community/tests/integration_tests/chat_models/test_jinachat.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/chat_models/test_jinachat.py
rename to libs/community/tests/integration_tests/chat_models/test_jinachat.py
index 85ec7c9d45f..0d704ce3867 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py
+++ b/libs/community/tests/integration_tests/chat_models/test_jinachat.py
@@ -3,13 +3,13 @@
from typing import cast
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.jinachat import JinaChat
+from langchain_community.chat_models.jinachat import JinaChat
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_konko.py b/libs/community/tests/integration_tests/chat_models/test_konko.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/chat_models/test_konko.py
rename to libs/community/tests/integration_tests/chat_models/test_konko.py
index 56f479b3026..7cfb5be1c25 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_konko.py
+++ b/libs/community/tests/integration_tests/chat_models/test_konko.py
@@ -2,11 +2,11 @@
from typing import Any
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.konko import ChatKonko
+from langchain_community.chat_models.konko import ChatKonko
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_litellm.py b/libs/community/tests/integration_tests/chat_models/test_litellm.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/chat_models/test_litellm.py
rename to libs/community/tests/integration_tests/chat_models/test_litellm.py
index 571a287daaf..c71d0d3ac1c 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_litellm.py
+++ b/libs/community/tests/integration_tests/chat_models/test_litellm.py
@@ -1,13 +1,13 @@
"""Test Anthropic API wrapper."""
from typing import List
+from langchain_core.callbacks import (
+ CallbackManager,
+)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import (
- CallbackManager,
-)
-from langchain.chat_models.litellm import ChatLiteLLM
+from langchain_community.chat_models.litellm import ChatLiteLLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/community/tests/integration_tests/chat_models/test_openai.py b/libs/community/tests/integration_tests/chat_models/test_openai.py
new file mode 100644
index 00000000000..40eed8670a0
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_openai.py
@@ -0,0 +1,332 @@
+"""Test ChatOpenAI wrapper."""
+from typing import Any, Optional
+
+import pytest
+from langchain_core.callbacks import CallbackManager
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatResult,
+ LLMResult,
+)
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from langchain_community.chat_models.openai import ChatOpenAI
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+
+@pytest.mark.scheduled
+def test_chat_openai() -> None:
+ """Test ChatOpenAI wrapper."""
+ chat = ChatOpenAI(
+ temperature=0.7,
+ base_url=None,
+ organization=None,
+ openai_proxy=None,
+ timeout=10.0,
+ max_retries=3,
+ http_client=None,
+ n=1,
+ max_tokens=10,
+ default_headers=None,
+ default_query=None,
+ )
+ message = HumanMessage(content="Hello")
+ response = chat([message])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_chat_openai_model() -> None:
+ """Test ChatOpenAI wrapper handles model_name."""
+ chat = ChatOpenAI(model="foo")
+ assert chat.model_name == "foo"
+ chat = ChatOpenAI(model_name="bar")
+ assert chat.model_name == "bar"
+
+
+def test_chat_openai_system_message() -> None:
+ """Test ChatOpenAI wrapper with system message."""
+ chat = ChatOpenAI(max_tokens=10)
+ system_message = SystemMessage(content="You are to chat with the user.")
+ human_message = HumanMessage(content="Hello")
+ response = chat([system_message, human_message])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+@pytest.mark.scheduled
+def test_chat_openai_generate() -> None:
+ """Test ChatOpenAI wrapper with generate."""
+ chat = ChatOpenAI(max_tokens=10, n=2)
+ message = HumanMessage(content="Hello")
+ response = chat.generate([[message], [message]])
+ assert isinstance(response, LLMResult)
+ assert len(response.generations) == 2
+ assert response.llm_output
+ for generations in response.generations:
+ assert len(generations) == 2
+ for generation in generations:
+ assert isinstance(generation, ChatGeneration)
+ assert isinstance(generation.text, str)
+ assert generation.text == generation.message.content
+
+
+@pytest.mark.scheduled
+def test_chat_openai_multiple_completions() -> None:
+ """Test ChatOpenAI wrapper with multiple completions."""
+ chat = ChatOpenAI(max_tokens=10, n=5)
+ message = HumanMessage(content="Hello")
+ response = chat._generate([message])
+ assert isinstance(response, ChatResult)
+ assert len(response.generations) == 5
+ for generation in response.generations:
+ assert isinstance(generation.message, BaseMessage)
+ assert isinstance(generation.message.content, str)
+
+
+@pytest.mark.scheduled
+def test_chat_openai_streaming() -> None:
+ """Test that streaming correctly invokes on_llm_new_token callback."""
+ callback_handler = FakeCallbackHandler()
+ callback_manager = CallbackManager([callback_handler])
+ chat = ChatOpenAI(
+ max_tokens=10,
+ streaming=True,
+ temperature=0,
+ callback_manager=callback_manager,
+ verbose=True,
+ )
+ message = HumanMessage(content="Hello")
+ response = chat([message])
+ assert callback_handler.llm_streams > 0
+ assert isinstance(response, BaseMessage)
+
+
+@pytest.mark.scheduled
+def test_chat_openai_streaming_generation_info() -> None:
+ """Test that generation info is preserved when streaming."""
+
+ class _FakeCallback(FakeCallbackHandler):
+ saved_things: dict = {}
+
+ def on_llm_end(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Any:
+ # Save the generation
+ self.saved_things["generation"] = args[0]
+
+ callback = _FakeCallback()
+ callback_manager = CallbackManager([callback])
+ chat = ChatOpenAI(
+ max_tokens=2,
+ temperature=0,
+ callback_manager=callback_manager,
+ )
+ list(chat.stream("hi"))
+ generation = callback.saved_things["generation"]
+ # `Hello!` is two tokens, assert that that is what is returned
+ assert generation.generations[0][0].text == "Hello!"
+
+
+def test_chat_openai_llm_output_contains_model_name() -> None:
+ """Test llm_output contains model_name."""
+ chat = ChatOpenAI(max_tokens=10)
+ message = HumanMessage(content="Hello")
+ llm_result = chat.generate([[message]])
+ assert llm_result.llm_output is not None
+ assert llm_result.llm_output["model_name"] == chat.model_name
+
+
+def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
+ """Test llm_output contains model_name."""
+ chat = ChatOpenAI(max_tokens=10, streaming=True)
+ message = HumanMessage(content="Hello")
+ llm_result = chat.generate([[message]])
+ assert llm_result.llm_output is not None
+ assert llm_result.llm_output["model_name"] == chat.model_name
+
+
+def test_chat_openai_invalid_streaming_params() -> None:
+ """Test that streaming correctly invokes on_llm_new_token callback."""
+ with pytest.raises(ValueError):
+ ChatOpenAI(
+ max_tokens=10,
+ streaming=True,
+ temperature=0,
+ n=5,
+ )
+
+
+@pytest.mark.scheduled
+async def test_async_chat_openai() -> None:
+ """Test async generation."""
+ chat = ChatOpenAI(max_tokens=10, n=2)
+ message = HumanMessage(content="Hello")
+ response = await chat.agenerate([[message], [message]])
+ assert isinstance(response, LLMResult)
+ assert len(response.generations) == 2
+ assert response.llm_output
+ for generations in response.generations:
+ assert len(generations) == 2
+ for generation in generations:
+ assert isinstance(generation, ChatGeneration)
+ assert isinstance(generation.text, str)
+ assert generation.text == generation.message.content
+
+
+@pytest.mark.scheduled
+async def test_async_chat_openai_streaming() -> None:
+ """Test that streaming correctly invokes on_llm_new_token callback."""
+ callback_handler = FakeCallbackHandler()
+ callback_manager = CallbackManager([callback_handler])
+ chat = ChatOpenAI(
+ max_tokens=10,
+ streaming=True,
+ temperature=0,
+ callback_manager=callback_manager,
+ verbose=True,
+ )
+ message = HumanMessage(content="Hello")
+ response = await chat.agenerate([[message], [message]])
+ assert callback_handler.llm_streams > 0
+ assert isinstance(response, LLMResult)
+ assert len(response.generations) == 2
+ for generations in response.generations:
+ assert len(generations) == 1
+ for generation in generations:
+ assert isinstance(generation, ChatGeneration)
+ assert isinstance(generation.text, str)
+ assert generation.text == generation.message.content
+
+
+@pytest.mark.scheduled
+async def test_async_chat_openai_bind_functions() -> None:
+ """Test ChatOpenAI wrapper with multiple completions."""
+
+ class Person(BaseModel):
+ """Identifying information about a person."""
+
+ name: str = Field(..., title="Name", description="The person's name")
+ age: int = Field(..., title="Age", description="The person's age")
+ fav_food: Optional[str] = Field(
+ default=None, title="Fav Food", description="The person's favorite food"
+ )
+
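+ # Bind the Person schema as an OpenAI function and force the model to call it.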
+ chat = ChatOpenAI(
+ max_tokens=30,
+ n=1,
+ streaming=True,
+ ).bind_functions(functions=[Person], function_call="Person")
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", "Use the provided Person function"),
+ ("user", "{input}"),
+ ]
+ )
+
+ chain = prompt | chat
+
+ message = HumanMessage(content="Sally is 13 years old")
+ response = await chain.abatch([{"input": message}])
+
+ assert isinstance(response, list)
+ assert len(response) == 1
+ for generation in response:
+ assert isinstance(generation, AIMessage)
+
+
+def test_chat_openai_extra_kwargs() -> None:
+ """Test extra kwargs to chat openai."""
+ # Check that foo is saved in extra_kwargs.
+ llm = ChatOpenAI(foo=3, max_tokens=10)
+ assert llm.max_tokens == 10
+ assert llm.model_kwargs == {"foo": 3}
+
+ # Test that if extra_kwargs are provided, they are added to it.
+ llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
+ assert llm.model_kwargs == {"foo": 3, "bar": 2}
+
+ # Test that if provided twice it errors
+ with pytest.raises(ValueError):
+ ChatOpenAI(foo=3, model_kwargs={"foo": 2})
+
+ # Test that if explicit param is specified in kwargs it errors
+ with pytest.raises(ValueError):
+ ChatOpenAI(model_kwargs={"temperature": 0.2})
+
+ # Test that "model" cannot be specified in kwargs
+ with pytest.raises(ValueError):
+ ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
+
+
+@pytest.mark.scheduled
+def test_openai_streaming() -> None:
+ """Test streaming tokens from OpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ for token in llm.stream("I'm Pickle Rick"):
+ assert isinstance(token.content, str)
+
+
+@pytest.mark.scheduled
+async def test_openai_astream() -> None:
+ """Test streaming tokens from OpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ async for token in llm.astream("I'm Pickle Rick"):
+ assert isinstance(token.content, str)
+
+
+@pytest.mark.scheduled
+async def test_openai_abatch() -> None:
+ """Test streaming tokens from ChatOpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+ for token in result:
+ assert isinstance(token.content, str)
+
+
+@pytest.mark.scheduled
+async def test_openai_abatch_tags() -> None:
+ """Test batch tokens from ChatOpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ result = await llm.abatch(
+ ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
+ )
+ for token in result:
+ assert isinstance(token.content, str)
+
+
+@pytest.mark.scheduled
+def test_openai_batch() -> None:
+ """Test batch tokens from ChatOpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+ for token in result:
+ assert isinstance(token.content, str)
+
+
+@pytest.mark.scheduled
+async def test_openai_ainvoke() -> None:
+ """Test invoke tokens from ChatOpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
+ assert isinstance(result.content, str)
+
+
+@pytest.mark.scheduled
+def test_openai_invoke() -> None:
+ """Test invoke tokens from ChatOpenAI."""
+ llm = ChatOpenAI(max_tokens=10)
+
+ result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+ assert isinstance(result.content, str)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py
rename to libs/community/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py
index 0095a5e4a28..136153e5855 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py
+++ b/libs/community/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py
@@ -1,11 +1,11 @@
"""Test AliCloud Pai Eas Chat Model."""
import os
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
+from langchain_community.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py b/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py
rename to libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py
index 3622701d08f..d037056d72d 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py
+++ b/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py
@@ -1,11 +1,11 @@
"""Test PromptLayerChatOpenAI wrapper."""
import pytest
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
+from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py
new file mode 100644
index 00000000000..88bfc66a382
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py
@@ -0,0 +1,219 @@
+"""Test Baidu Qianfan Chat Endpoint."""
+
+from typing import Any
+
+from langchain_core.callbacks import CallbackManager
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ FunctionMessage,
+ HumanMessage,
+)
+from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
+
+from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
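+# OpenAI-style function schemas used to exercise Qianfan's function-calling support.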
+_FUNCTIONS: Any = [
+ {
+ "name": "format_person_info",
+ "description": (
+ "Output formatter. Should always be used to format your response to the"
+ " user."
+ ),
+ "parameters": {
+ "title": "Person",
+ "description": "Identifying information about a person.",
+ "type": "object",
+ "properties": {
+ "name": {
+ "title": "Name",
+ "description": "The person's name",
+ "type": "string",
+ },
+ "age": {
+ "title": "Age",
+ "description": "The person's age",
+ "type": "integer",
+ },
+ "fav_food": {
+ "title": "Fav Food",
+ "description": "The person's favorite food",
+ "type": "string",
+ },
+ },
+ "required": ["name", "age"],
+ },
+ },
+ {
+ "name": "get_current_temperature",
+ "description": ("Used to get the location's temperature."),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "city name",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["centigrade", "Fahrenheit"],
+ },
+ },
+ "required": ["location", "unit"],
+ },
+ "responses": {
+ "type": "object",
+ "properties": {
+ "temperature": {
+ "type": "integer",
+ "description": "city temperature",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["centigrade", "Fahrenheit"],
+ },
+ },
+ },
+ },
+]
+
+
+def test_default_call() -> None:
+ """Test default model(`ERNIE-Bot`) call."""
+ chat = QianfanChatEndpoint()
+ response = chat(messages=[HumanMessage(content="Hello")])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_model() -> None:
+ """Test model kwarg works."""
+ chat = QianfanChatEndpoint(model="BLOOMZ-7B")
+ response = chat(messages=[HumanMessage(content="Hello")])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_model_param() -> None:
+ """Test model params works."""
+ chat = QianfanChatEndpoint()
+ response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_endpoint() -> None:
+ """Test user custom model deployments like some open source models."""
+ chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
+ response = chat(messages=[HumanMessage(content="Hello")])
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_endpoint_param() -> None:
+ """Test user custom model deployments like some open source models."""
+ chat = QianfanChatEndpoint()
+ response = chat(
+ messages=[
+ HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello")
+ ]
+ )
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_multiple_history() -> None:
+ """Tests multiple history works."""
+ chat = QianfanChatEndpoint()
+
+ response = chat(
+ messages=[
+ HumanMessage(content="Hello."),
+ AIMessage(content="Hello!"),
+ HumanMessage(content="How are you doing?"),
+ ]
+ )
+ assert isinstance(response, BaseMessage)
+ assert isinstance(response.content, str)
+
+
+def test_stream() -> None:
+ """Test that stream works."""
+ chat = QianfanChatEndpoint(streaming=True)
+ callback_handler = FakeCallbackHandler()
+ callback_manager = CallbackManager([callback_handler])
+ response = chat(
+ messages=[
+ HumanMessage(content="Hello."),
+ AIMessage(content="Hello!"),
+ HumanMessage(content="Who are you?"),
+ ],
+ stream=True,
+ callbacks=callback_manager,
+ )
+ assert callback_handler.llm_streams > 0
+ assert isinstance(response.content, str)
+
+
+def test_multiple_messages() -> None:
+ """Tests multiple messages works."""
+ chat = QianfanChatEndpoint()
+ message = HumanMessage(content="Hi, how are you.")
+ response = chat.generate([[message], [message]])
+
+ assert isinstance(response, LLMResult)
+ assert len(response.generations) == 2
+ for generations in response.generations:
+ assert len(generations) == 1
+ for generation in generations:
+ assert isinstance(generation, ChatGeneration)
+ assert isinstance(generation.text, str)
+ assert generation.text == generation.message.content
+
+
+def test_functions_call_thoughts() -> None:
+ chat = QianfanChatEndpoint(model="ERNIE-Bot")
+
+ prompt_tmpl = "Use the given functions to answer following question: {input}"
+ prompt_msgs = [
+ HumanMessagePromptTemplate.from_template(prompt_tmpl),
+ ]
+ prompt = ChatPromptTemplate(messages=prompt_msgs)
+
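+ # Bind the function schemas to the chat model and compose with the prompt via LCEL.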
+ chain = prompt | chat.bind(functions=_FUNCTIONS)
+
+ message = HumanMessage(content="What's the temperature in Shanghai today?")
+ response = chain.batch([{"input": message}])
+ assert isinstance(response[0], AIMessage)
+ assert "function_call" in response[0].additional_kwargs
+
+
+def test_functions_call() -> None:
+ chat = QianfanChatEndpoint(model="ERNIE-Bot")
+
+ prompt = ChatPromptTemplate(
+ messages=[
+ HumanMessage(content="What's the temperature in Shanghai today?"),
+ AIMessage(
+ content="",
+ additional_kwargs={
+ "function_call": {
+ "name": "get_current_temperature",
+ "thoughts": "i will use get_current_temperature "
+ "to resolve the questions",
+ "arguments": '{"location":"Shanghai","unit":"centigrade"}',
+ }
+ },
+ ),
+ FunctionMessage(
+ name="get_current_weather",
+ content='{"temperature": "25", \
+ "unit": "ζζ°εΊ¦", "description": "ζ΄ζ"}',
+ ),
+ ]
+ )
+ chain = prompt | chat.bind(functions=_FUNCTIONS)
+ resp = chain.invoke({})
+ assert isinstance(resp, AIMessage)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py b/libs/community/tests/integration_tests/chat_models/test_tongyi.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/chat_models/test_tongyi.py
rename to libs/community/tests/integration_tests/chat_models/test_tongyi.py
index a743cdb16a7..ebb92b24b23 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py
+++ b/libs/community/tests/integration_tests/chat_models/test_tongyi.py
@@ -1,10 +1,10 @@
"""Test Alibaba Tongyi Chat Model."""
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.tongyi import ChatTongyi
+from langchain_community.chat_models.tongyi import ChatTongyi
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/community/tests/integration_tests/chat_models/test_vertexai.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
rename to libs/community/tests/integration_tests/chat_models/test_vertexai.py
index 790fccd70a0..9e84645e66a 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
+++ b/libs/community/tests/integration_tests/chat_models/test_vertexai.py
@@ -19,8 +19,11 @@ from langchain_core.messages import (
)
from langchain_core.outputs import LLMResult
-from langchain.chat_models import ChatVertexAI
-from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
+from langchain_community.chat_models import ChatVertexAI
+from langchain_community.chat_models.vertexai import (
+ _parse_chat_history,
+ _parse_examples,
+)
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py b/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py
rename to libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py
index bb52a50d7ae..2aa84b30de4 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py
+++ b/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py
@@ -1,10 +1,10 @@
"""Test volc engine maas chat model."""
+from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
+from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
diff --git a/libs/langchain/tests/integration_tests/document_loaders/__init__.py b/libs/community/tests/integration_tests/document_loaders/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/document_loaders/__init__.py
rename to libs/community/tests/integration_tests/document_loaders/__init__.py
diff --git a/libs/langchain/tests/unit_tests/tools/openapi/__init__.py b/libs/community/tests/integration_tests/document_loaders/parsers/__init__.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/tools/openapi/__init__.py
rename to libs/community/tests/integration_tests/document_loaders/parsers/__init__.py
diff --git a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_docai.py b/libs/community/tests/integration_tests/document_loaders/parsers/test_docai.py
similarity index 93%
rename from libs/langchain/tests/unit_tests/document_loaders/parsers/test_docai.py
rename to libs/community/tests/integration_tests/document_loaders/parsers/test_docai.py
index d2d6d4c4ff8..ffbb42215ef 100644
--- a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_docai.py
+++ b/libs/community/tests/integration_tests/document_loaders/parsers/test_docai.py
@@ -3,7 +3,7 @@ from unittest.mock import ANY, patch
import pytest
-from langchain.document_loaders.parsers import DocAIParser
+from langchain_community.document_loaders.parsers import DocAIParser
@pytest.mark.requires("google.cloud", "google.cloud.documentai")
diff --git a/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py b/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py
new file mode 100644
index 00000000000..c28789c7cd3
--- /dev/null
+++ b/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py
@@ -0,0 +1,182 @@
+from pathlib import Path
+
+import pytest
+
+from langchain_community.document_loaders.concurrent import ConcurrentLoader
+from langchain_community.document_loaders.generic import GenericLoader
+from langchain_community.document_loaders.parsers import LanguageParser
+
+
+def test_language_loader_for_python() -> None:
+ """Test Python loader with parser enabled."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = GenericLoader.from_filesystem(
+ file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
+ )
+ docs = loader.load()
+
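+ # With the parser active, hello_world.py is split into one "functions_classes" document and one "simplified_code" document.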
+ assert len(docs) == 2
+
+ metadata = docs[0].metadata
+ assert metadata["source"] == str(file_path / "hello_world.py")
+ assert metadata["content_type"] == "functions_classes"
+ assert metadata["language"] == "python"
+ metadata = docs[1].metadata
+ assert metadata["source"] == str(file_path / "hello_world.py")
+ assert metadata["content_type"] == "simplified_code"
+ assert metadata["language"] == "python"
+
+ assert (
+ docs[0].page_content
+ == """def main():
+ print("Hello World!")
+
+ return 0"""
+ )
+ assert (
+ docs[1].page_content
+ == """#!/usr/bin/env python3
+
+import sys
+
+
+# Code for: def main():
+
+
+if __name__ == "__main__":
+ sys.exit(main())"""
+ )
+
+
+def test_language_loader_for_python_with_parser_threshold() -> None:
+ """Test Python loader with parser enabled and below threshold."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = GenericLoader.from_filesystem(
+ file_path,
+ glob="hello_world.py",
+ parser=LanguageParser(language="python", parser_threshold=1000),
+ )
+ docs = loader.load()
+
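+ # The file is shorter than parser_threshold=1000, so no splitting occurs and a single document is returned.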
+ assert len(docs) == 1
+
+
+def esprima_installed() -> bool:
+ try:
+ import esprima # noqa: F401
+
+ return True
+ except Exception as e:
+ print(f"esprima not installed, skipping test {e}")
+ return False
+
+
+@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
+def test_language_loader_for_javascript() -> None:
+ """Test JavaScript loader with parser enabled."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = GenericLoader.from_filesystem(
+ file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
+ )
+ docs = loader.load()
+
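+ # hello_world.js yields one document per extracted class/function (HelloWorld, main) plus one "simplified_code" document.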
+ assert len(docs) == 3
+
+ metadata = docs[0].metadata
+ assert metadata["source"] == str(file_path / "hello_world.js")
+ assert metadata["content_type"] == "functions_classes"
+ assert metadata["language"] == "js"
+ metadata = docs[1].metadata
+ assert metadata["source"] == str(file_path / "hello_world.js")
+ assert metadata["content_type"] == "functions_classes"
+ assert metadata["language"] == "js"
+ metadata = docs[2].metadata
+ assert metadata["source"] == str(file_path / "hello_world.js")
+ assert metadata["content_type"] == "simplified_code"
+ assert metadata["language"] == "js"
+
+ assert (
+ docs[0].page_content
+ == """class HelloWorld {
+ sayHello() {
+ console.log("Hello World!");
+ }
+}"""
+ )
+ assert (
+ docs[1].page_content
+ == """function main() {
+ const hello = new HelloWorld();
+ hello.sayHello();
+}"""
+ )
+ assert (
+ docs[2].page_content
+ == """// Code for: class HelloWorld {
+
+// Code for: function main() {
+
+main();"""
+ )
+
+
+def test_language_loader_for_javascript_with_parser_threshold() -> None:
+ """Test JavaScript loader with parser enabled and below threshold."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = GenericLoader.from_filesystem(
+ file_path,
+ glob="hello_world.js",
+ parser=LanguageParser(language="js", parser_threshold=1000),
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+
+
+def test_concurrent_language_loader_for_javascript_with_parser_threshold() -> None:
+ """Test JavaScript ConcurrentLoader with parser enabled and below threshold."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = ConcurrentLoader.from_filesystem(
+ file_path,
+ glob="hello_world.js",
+ parser=LanguageParser(language="js", parser_threshold=1000),
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+
+
+def test_concurrent_language_loader_for_python_with_parser_threshold() -> None:
+ """Test Python ConcurrentLoader with parser enabled and below threshold."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = ConcurrentLoader.from_filesystem(
+ file_path,
+ glob="hello_world.py",
+ parser=LanguageParser(language="python", parser_threshold=1000),
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+
+
+@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
+def test_concurrent_language_loader_for_javascript() -> None:
+ """Test JavaScript ConcurrentLoader with parser enabled."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = ConcurrentLoader.from_filesystem(
+ file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
+ )
+ docs = loader.load()
+
+ assert len(docs) == 3
+
+
+def test_concurrent_language_loader_for_python() -> None:
+ """Test Python ConcurrentLoader with parser enabled."""
+ file_path = Path(__file__).parent.parent.parent / "examples"
+ loader = ConcurrentLoader.from_filesystem(
+ file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
+ )
+ docs = loader.load()
+
+ assert len(docs) == 2
diff --git a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py b/libs/community/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py
rename to libs/community/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py
index 408498c126a..6b4d55d8822 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py
+++ b/libs/community/tests/integration_tests/document_loaders/parsers/test_pdf_parsers.py
@@ -2,9 +2,9 @@
from pathlib import Path
from typing import Iterator
-from langchain.document_loaders.base import BaseBlobParser
-from langchain.document_loaders.blob_loaders import Blob
-from langchain.document_loaders.parsers.pdf import (
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders import Blob
+from langchain_community.document_loaders.parsers.pdf import (
PDFMinerParser,
PDFPlumberParser,
PyMuPDFParser,
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py b/libs/community/tests/integration_tests/document_loaders/test_arxiv.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py
rename to libs/community/tests/integration_tests/document_loaders/test_arxiv.py
index 5cbf8957672..55ad01b9c40 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_arxiv.py
@@ -3,7 +3,7 @@ from typing import List
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders.arxiv import ArxivLoader
+from langchain_community.document_loaders.arxiv import ArxivLoader
def assert_docs(docs: List[Document]) -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_bigquery.py b/libs/community/tests/integration_tests/document_loaders/test_bigquery.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_bigquery.py
rename to libs/community/tests/integration_tests/document_loaders/test_bigquery.py
index 24b55056b43..df492acd53f 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_bigquery.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_bigquery.py
@@ -1,6 +1,6 @@
import pytest
-from langchain.document_loaders.bigquery import BigQueryLoader
+from langchain_community.document_loaders.bigquery import BigQueryLoader
try:
from google.cloud import bigquery # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_bilibili.py b/libs/community/tests/integration_tests/document_loaders/test_bilibili.py
similarity index 88%
rename from libs/langchain/tests/integration_tests/document_loaders/test_bilibili.py
rename to libs/community/tests/integration_tests/document_loaders/test_bilibili.py
index cf1affdf2e3..e9bea92917c 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_bilibili.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_bilibili.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders import BiliBiliLoader
+from langchain_community.document_loaders import BiliBiliLoader
def test_bilibili_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_blockchain.py b/libs/community/tests/integration_tests/document_loaders/test_blockchain.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_blockchain.py
rename to libs/community/tests/integration_tests/document_loaders/test_blockchain.py
index b9dbd92d257..8ea688df584 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_blockchain.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_blockchain.py
@@ -3,8 +3,8 @@ import time
import pytest
-from langchain.document_loaders import BlockchainDocumentLoader
-from langchain.document_loaders.blockchain import BlockchainType
+from langchain_community.document_loaders import BlockchainDocumentLoader
+from langchain_community.document_loaders.blockchain import BlockchainType
if "ALCHEMY_API_KEY" in os.environ:
alchemyKeySet = True
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_confluence.py b/libs/community/tests/integration_tests/document_loaders/test_confluence.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_confluence.py
rename to libs/community/tests/integration_tests/document_loaders/test_confluence.py
index 983bc254704..2a469fb9573 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_confluence.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_confluence.py
@@ -1,6 +1,6 @@
import pytest
-from langchain.document_loaders.confluence import ConfluenceLoader
+from langchain_community.document_loaders.confluence import ConfluenceLoader
try:
from atlassian import Confluence # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py b/libs/community/tests/integration_tests/document_loaders/test_couchbase.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py
rename to libs/community/tests/integration_tests/document_loaders/test_couchbase.py
index d4585d0796b..f4008679626 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_couchbase.py
@@ -1,6 +1,6 @@
import unittest
-from langchain.document_loaders.couchbase import CouchbaseLoader
+from langchain_community.document_loaders.couchbase import CouchbaseLoader
try:
import couchbase # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_csv_loader.py b/libs/community/tests/integration_tests/document_loaders/test_csv_loader.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_csv_loader.py
rename to libs/community/tests/integration_tests/document_loaders/test_csv_loader.py
index ffce01cf17b..9b57d859d03 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_csv_loader.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_csv_loader.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredCSVLoader
+from langchain_community.document_loaders import UnstructuredCSVLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py b/libs/community/tests/integration_tests/document_loaders/test_dataframe.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py
rename to libs/community/tests/integration_tests/document_loaders/test_dataframe.py
index c1be686995b..a41d4f38d60 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_dataframe.py
@@ -2,7 +2,7 @@ import pandas as pd
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders import DataFrameLoader
+from langchain_community.document_loaders import DataFrameLoader
@pytest.fixture
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_docusaurus.py b/libs/community/tests/integration_tests/document_loaders/test_docusaurus.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_docusaurus.py
rename to libs/community/tests/integration_tests/document_loaders/test_docusaurus.py
index 53323ae9e4e..b72f3495240 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_docusaurus.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_docusaurus.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders import DocusaurusLoader
+from langchain_community.document_loaders import DocusaurusLoader
DOCS_URL = str(Path(__file__).parent.parent / "examples/docusaurus-sitemap.xml")
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_duckdb.py b/libs/community/tests/integration_tests/document_loaders/test_duckdb.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_duckdb.py
rename to libs/community/tests/integration_tests/document_loaders/test_duckdb.py
index a91e352b5ad..62ec846b8f6 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_duckdb.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_duckdb.py
@@ -1,6 +1,6 @@
import unittest
-from langchain.document_loaders.duckdb_loader import DuckDBLoader
+from langchain_community.document_loaders.duckdb_loader import DuckDBLoader
try:
import duckdb # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_email.py b/libs/community/tests/integration_tests/document_loaders/test_email.py
similarity index 91%
rename from libs/langchain/tests/integration_tests/document_loaders/test_email.py
rename to libs/community/tests/integration_tests/document_loaders/test_email.py
index b89cc19c29b..4e2c6e04a0a 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_email.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_email.py
@@ -1,6 +1,9 @@
from pathlib import Path
-from langchain.document_loaders import OutlookMessageLoader, UnstructuredEmailLoader
+from langchain_community.document_loaders import (
+ OutlookMessageLoader,
+ UnstructuredEmailLoader,
+)
def test_outlook_message_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_etherscan.py b/libs/community/tests/integration_tests/document_loaders/test_etherscan.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/document_loaders/test_etherscan.py
rename to libs/community/tests/integration_tests/document_loaders/test_etherscan.py
index e88ea51f3d2..7f503879565 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_etherscan.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_etherscan.py
@@ -2,7 +2,7 @@ import os
import pytest
-from langchain.document_loaders import EtherscanLoader
+from langchain_community.document_loaders import EtherscanLoader
if "ETHERSCAN_API_KEY" in os.environ:
etherscan_key_set = True
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_excel.py b/libs/community/tests/integration_tests/document_loaders/test_excel.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_excel.py
rename to libs/community/tests/integration_tests/document_loaders/test_excel.py
index c8fbe07f692..8d60065b59a 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_excel.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_excel.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredExcelLoader
+from langchain_community.document_loaders import UnstructuredExcelLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_facebook_chat.py b/libs/community/tests/integration_tests/document_loaders/test_facebook_chat.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_facebook_chat.py
rename to libs/community/tests/integration_tests/document_loaders/test_facebook_chat.py
index eaa8f91242b..0770e3a5c31 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_facebook_chat.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_facebook_chat.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders import FacebookChatLoader
+from langchain_community.document_loaders import FacebookChatLoader
def test_facebook_chat_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_fauna.py b/libs/community/tests/integration_tests/document_loaders/test_fauna.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_fauna.py
rename to libs/community/tests/integration_tests/document_loaders/test_fauna.py
index 81588d93422..c91afcb1c83 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_fauna.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_fauna.py
@@ -1,6 +1,6 @@
import unittest
-from langchain.document_loaders.fauna import FaunaLoader
+from langchain_community.document_loaders.fauna import FaunaLoader
try:
import fauna # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_figma.py b/libs/community/tests/integration_tests/document_loaders/test_figma.py
similarity index 75%
rename from libs/langchain/tests/integration_tests/document_loaders/test_figma.py
rename to libs/community/tests/integration_tests/document_loaders/test_figma.py
index 00fa6488e26..ac78335adea 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_figma.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_figma.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.figma import FigmaFileLoader
+from langchain_community.document_loaders.figma import FigmaFileLoader
ACCESS_TOKEN = ""
IDS = ""
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py b/libs/community/tests/integration_tests/document_loaders/test_geodataframe.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py
rename to libs/community/tests/integration_tests/document_loaders/test_geodataframe.py
index 9417a478442..ef2a62f9531 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_geodataframe.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders import GeoDataFrameLoader
+from langchain_community.document_loaders import GeoDataFrameLoader
if TYPE_CHECKING:
from geopandas import GeoDataFrame
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_gitbook.py b/libs/community/tests/integration_tests/document_loaders/test_gitbook.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_gitbook.py
rename to libs/community/tests/integration_tests/document_loaders/test_gitbook.py
index d6519a55e60..fa2a6838bd4 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_gitbook.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_gitbook.py
@@ -2,7 +2,7 @@ from typing import Optional
import pytest
-from langchain.document_loaders.gitbook import GitbookLoader
+from langchain_community.document_loaders.gitbook import GitbookLoader
class TestGitbookLoader:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_github.py b/libs/community/tests/integration_tests/document_loaders/test_github.py
similarity index 82%
rename from libs/langchain/tests/integration_tests/document_loaders/test_github.py
rename to libs/community/tests/integration_tests/document_loaders/test_github.py
index 2e437a6b554..a3ad86a0f99 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_github.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_github.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.github import GitHubIssuesLoader
+from langchain_community.document_loaders.github import GitHubIssuesLoader
def test_issues_load() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_google_speech_to_text.py b/libs/community/tests/integration_tests/document_loaders/test_google_speech_to_text.py
similarity index 90%
rename from libs/langchain/tests/integration_tests/document_loaders/test_google_speech_to_text.py
rename to libs/community/tests/integration_tests/document_loaders/test_google_speech_to_text.py
index c5041cb16bc..fe3b768c675 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_google_speech_to_text.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_google_speech_to_text.py
@@ -9,7 +9,9 @@ to set up the app and configure authentication.
import pytest
-from langchain.document_loaders.google_speech_to_text import GoogleSpeechToTextLoader
+from langchain_community.document_loaders.google_speech_to_text import (
+ GoogleSpeechToTextLoader,
+)
@pytest.mark.requires("google.api_core")
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_ifixit.py b/libs/community/tests/integration_tests/document_loaders/test_ifixit.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_ifixit.py
rename to libs/community/tests/integration_tests/document_loaders/test_ifixit.py
index c97be49e1a1..28a83a5b8bf 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_ifixit.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_ifixit.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.ifixit import IFixitLoader
+from langchain_community.document_loaders.ifixit import IFixitLoader
def test_ifixit_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_joplin.py b/libs/community/tests/integration_tests/document_loaders/test_joplin.py
similarity index 80%
rename from libs/langchain/tests/integration_tests/document_loaders/test_joplin.py
rename to libs/community/tests/integration_tests/document_loaders/test_joplin.py
index 76a1918b6a1..a15078233de 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_joplin.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_joplin.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.joplin import JoplinLoader
+from langchain_community.document_loaders.joplin import JoplinLoader
def test_joplin_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_json_loader.py b/libs/community/tests/integration_tests/document_loaders/test_json_loader.py
similarity index 88%
rename from libs/langchain/tests/integration_tests/document_loaders/test_json_loader.py
rename to libs/community/tests/integration_tests/document_loaders/test_json_loader.py
index bdca42c40ce..8f85d9b0191 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_json_loader.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_json_loader.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders import JSONLoader
+from langchain_community.document_loaders import JSONLoader
def test_json_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_lakefs.py b/libs/community/tests/integration_tests/document_loaders/test_lakefs.py
similarity index 81%
rename from libs/langchain/tests/integration_tests/document_loaders/test_lakefs.py
rename to libs/community/tests/integration_tests/document_loaders/test_lakefs.py
index c840ce80743..209a40dfd8e 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_lakefs.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_lakefs.py
@@ -6,12 +6,14 @@ import pytest
import requests_mock
from requests_mock.mocker import Mocker
-from langchain.document_loaders.lakefs import LakeFSLoader
+from langchain_community.document_loaders.lakefs import LakeFSLoader
@pytest.fixture
def mock_lakefs_client() -> Any:
- with patch("langchain.document_loaders.lakefs.LakeFSClient") as mock_lakefs_client:
+ with patch(
+ "langchain_community.document_loaders.lakefs.LakeFSClient"
+ ) as mock_lakefs_client:
mock_lakefs_client.return_value.ls_objects.return_value = [
("path_bla.txt", "https://physical_address_bla")
]
@@ -21,7 +23,9 @@ def mock_lakefs_client() -> Any:
@pytest.fixture
def mock_lakefs_client_no_presign_not_local() -> Any:
- with patch("langchain.document_loaders.lakefs.LakeFSClient") as mock_lakefs_client:
+ with patch(
+ "langchain_community.document_loaders.lakefs.LakeFSClient"
+ ) as mock_lakefs_client:
mock_lakefs_client.return_value.ls_objects.return_value = [
("path_bla.txt", "https://physical_address_bla")
]
@@ -32,7 +36,7 @@ def mock_lakefs_client_no_presign_not_local() -> Any:
@pytest.fixture
def mock_unstructured_local() -> Any:
with patch(
- "langchain.document_loaders.lakefs.UnstructuredLakeFSLoader"
+ "langchain_community.document_loaders.lakefs.UnstructuredLakeFSLoader"
) as mock_unstructured_lakefs:
mock_unstructured_lakefs.return_value.load.return_value = [
("text content", "pdf content")
@@ -42,7 +46,9 @@ def mock_unstructured_local() -> Any:
@pytest.fixture
def mock_lakefs_client_no_presign_local() -> Any:
- with patch("langchain.document_loaders.lakefs.LakeFSClient") as mock_lakefs_client:
+ with patch(
+ "langchain_community.document_loaders.lakefs.LakeFSClient"
+ ) as mock_lakefs_client:
mock_lakefs_client.return_value.ls_objects.return_value = [
("path_bla.txt", "local:///physical_address_bla")
]
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_max_compute.py b/libs/community/tests/integration_tests/document_loaders/test_language.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/document_loaders/test_max_compute.py
rename to libs/community/tests/integration_tests/document_loaders/test_language.py
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_larksuite.py b/libs/community/tests/integration_tests/document_loaders/test_larksuite.py
similarity index 79%
rename from libs/langchain/tests/integration_tests/document_loaders/test_larksuite.py
rename to libs/community/tests/integration_tests/document_loaders/test_larksuite.py
index 147d8ee8018..61d251dd929 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_larksuite.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_larksuite.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.larksuite import LarkSuiteDocLoader
+from langchain_community.document_loaders.larksuite import LarkSuiteDocLoader
DOMAIN = ""
ACCESS_TOKEN = ""
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_mastodon.py b/libs/community/tests/integration_tests/document_loaders/test_mastodon.py
similarity index 85%
rename from libs/langchain/tests/integration_tests/document_loaders/test_mastodon.py
rename to libs/community/tests/integration_tests/document_loaders/test_mastodon.py
index 6988c0758af..b7f04fa67fb 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_mastodon.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_mastodon.py
@@ -1,5 +1,5 @@
"""Tests for the Mastodon toots loader"""
-from langchain.document_loaders import MastodonTootsLoader
+from langchain_community.document_loaders import MastodonTootsLoader
def test_mastodon_toots_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/fixtures.py b/libs/community/tests/integration_tests/document_loaders/test_max_compute.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/vectorstores/qdrant/fixtures.py
rename to libs/community/tests/integration_tests/document_loaders/test_max_compute.py
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_modern_treasury.py b/libs/community/tests/integration_tests/document_loaders/test_modern_treasury.py
similarity index 73%
rename from libs/langchain/tests/integration_tests/document_loaders/test_modern_treasury.py
rename to libs/community/tests/integration_tests/document_loaders/test_modern_treasury.py
index 3ce8c71123a..62bad5a0b91 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_modern_treasury.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_modern_treasury.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.modern_treasury import ModernTreasuryLoader
+from langchain_community.document_loaders.modern_treasury import ModernTreasuryLoader
def test_modern_treasury_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_news.py b/libs/community/tests/integration_tests/document_loaders/test_news.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_news.py
rename to libs/community/tests/integration_tests/document_loaders/test_news.py
index 2507df34f68..e34848b302a 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_news.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_news.py
@@ -3,7 +3,7 @@ import random
import pytest
import requests
-from langchain.document_loaders import NewsURLLoader
+from langchain_community.document_loaders import NewsURLLoader
def get_random_news_url() -> str:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_nuclia.py b/libs/community/tests/integration_tests/document_loaders/test_nuclia.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_nuclia.py
rename to libs/community/tests/integration_tests/document_loaders/test_nuclia.py
index 6d7131e19f3..c63949fff68 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_nuclia.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_nuclia.py
@@ -3,8 +3,8 @@ import os
from typing import Any
from unittest import mock
-from langchain.document_loaders.nuclia import NucliaLoader
-from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
+from langchain_community.document_loaders.nuclia import NucliaLoader
+from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
def fakerun(**args: Any) -> Any:
@@ -32,7 +32,8 @@ def fakerun(**args: Any) -> Any:
@mock.patch.dict(os.environ, {"NUCLIA_NUA_KEY": "_a_key_"})
def test_nuclia_loader() -> None:
with mock.patch(
- "langchain.tools.nuclia.tool.NucliaUnderstandingAPI._run", new_callable=fakerun
+ "langchain_community.tools.nuclia.tool.NucliaUnderstandingAPI._run",
+ new_callable=fakerun,
):
nua = NucliaUnderstandingAPI(enable_ml=False)
loader = NucliaLoader("/whatever/file.mp3", nua)
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_odt.py b/libs/community/tests/integration_tests/document_loaders/test_odt.py
similarity index 79%
rename from libs/langchain/tests/integration_tests/document_loaders/test_odt.py
rename to libs/community/tests/integration_tests/document_loaders/test_odt.py
index 0aa833ceb6c..7ffd36c9f1d 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_odt.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_odt.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders import UnstructuredODTLoader
+from langchain_community.document_loaders import UnstructuredODTLoader
def test_unstructured_odt_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_org_mode.py b/libs/community/tests/integration_tests/document_loaders/test_org_mode.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_org_mode.py
rename to libs/community/tests/integration_tests/document_loaders/test_org_mode.py
index 157d76c0c40..a72b0ddda20 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_org_mode.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_org_mode.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredOrgModeLoader
+from langchain_community.document_loaders import UnstructuredOrgModeLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_pdf.py b/libs/community/tests/integration_tests/document_loaders/test_pdf.py
similarity index 99%
rename from libs/langchain/tests/integration_tests/document_loaders/test_pdf.py
rename to libs/community/tests/integration_tests/document_loaders/test_pdf.py
index d83ba9cf7a1..97c005b46d0 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_pdf.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_pdf.py
@@ -3,7 +3,7 @@ from typing import Sequence, Union
import pytest
-from langchain.document_loaders import (
+from langchain_community.document_loaders import (
AmazonTextractPDFLoader,
MathpixPDFLoader,
PDFMinerLoader,
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py b/libs/community/tests/integration_tests/document_loaders/test_polars_dataframe.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py
rename to libs/community/tests/integration_tests/document_loaders/test_polars_dataframe.py
index 80ca3892115..10abae72724 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_polars_dataframe.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders import PolarsDataFrameLoader
+from langchain_community.document_loaders import PolarsDataFrameLoader
if TYPE_CHECKING:
import polars as pl
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py b/libs/community/tests/integration_tests/document_loaders/test_pubmed.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py
rename to libs/community/tests/integration_tests/document_loaders/test_pubmed.py
index 45a23157f9f..bea8afabbad 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_pubmed.py
@@ -4,7 +4,7 @@ from typing import List
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders import PubMedLoader
+from langchain_community.document_loaders import PubMedLoader
xmltodict = pytest.importorskip("xmltodict")
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py b/libs/community/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py
similarity index 90%
rename from libs/langchain/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py
rename to libs/community/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py
index 0d979939e08..1a453e1e22e 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py
@@ -3,7 +3,9 @@ import string
from langchain_core.documents import Document
-from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
+from langchain_community.document_loaders.pyspark_dataframe import (
+ PySparkDataFrameLoader,
+)
def test_pyspark_loader_load_valid_data() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_python.py b/libs/community/tests/integration_tests/document_loaders/test_python.py
similarity index 86%
rename from libs/langchain/tests/integration_tests/document_loaders/test_python.py
rename to libs/community/tests/integration_tests/document_loaders/test_python.py
index f4b2b3ae6fe..6b847ec1e3d 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_python.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_python.py
@@ -2,7 +2,7 @@ from pathlib import Path
import pytest
-from langchain.document_loaders.python import PythonLoader
+from langchain_community.document_loaders.python import PythonLoader
@pytest.mark.parametrize("filename", ["default-encoding.py", "non-utf8-encoding.py"])
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_quip.py b/libs/community/tests/integration_tests/document_loaders/test_quip.py
similarity index 99%
rename from libs/langchain/tests/integration_tests/document_loaders/test_quip.py
rename to libs/community/tests/integration_tests/document_loaders/test_quip.py
index 6d4ea446891..aea0056db86 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_quip.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_quip.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders.quip import QuipLoader
+from langchain_community.document_loaders.quip import QuipLoader
try:
from quip_api.quip import QuipClient # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py
rename to libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
index bbfc5586134..ff8083bcd50 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader
+from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
def test_async_recursive_url_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_rocksetdb.py b/libs/community/tests/integration_tests/document_loaders/test_rocksetdb.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_rocksetdb.py
rename to libs/community/tests/integration_tests/document_loaders/test_rocksetdb.py
index c1d2edd0a19..7df2f4592f7 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_rocksetdb.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_rocksetdb.py
@@ -3,7 +3,7 @@ import os
from langchain_core.documents import Document
-from langchain.document_loaders import RocksetLoader
+from langchain_community.document_loaders import RocksetLoader
logger = logging.getLogger(__name__)
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_rss.py b/libs/community/tests/integration_tests/document_loaders/test_rss.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_rss.py
rename to libs/community/tests/integration_tests/document_loaders/test_rss.py
index 093a38ffb44..1f9b41fa634 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_rss.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_rss.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders.rss import RSSFeedLoader
+from langchain_community.document_loaders.rss import RSSFeedLoader
def test_rss_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_rst.py b/libs/community/tests/integration_tests/document_loaders/test_rst.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_rst.py
rename to libs/community/tests/integration_tests/document_loaders/test_rst.py
index ead71c3dc6c..3c129d81a56 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_rst.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_rst.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredRSTLoader
+from langchain_community.document_loaders import UnstructuredRSTLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sitemap.py b/libs/community/tests/integration_tests/document_loaders/test_sitemap.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/document_loaders/test_sitemap.py
rename to libs/community/tests/integration_tests/document_loaders/test_sitemap.py
index 7ab448dfe76..de43f697b6f 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_sitemap.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_sitemap.py
@@ -3,8 +3,8 @@ from typing import Any
import pytest
-from langchain.document_loaders import SitemapLoader
-from langchain.document_loaders.sitemap import _extract_scheme_and_domain
+from langchain_community.document_loaders import SitemapLoader
+from langchain_community.document_loaders.sitemap import _extract_scheme_and_domain
def test_sitemap() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_slack.py b/libs/community/tests/integration_tests/document_loaders/test_slack.py
similarity index 91%
rename from libs/langchain/tests/integration_tests/document_loaders/test_slack.py
rename to libs/community/tests/integration_tests/document_loaders/test_slack.py
index 7baa1319fc9..20df9e8ff8e 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_slack.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_slack.py
@@ -1,7 +1,7 @@
"""Tests for the Slack directory loader"""
from pathlib import Path
-from langchain.document_loaders import SlackDirectoryLoader
+from langchain_community.document_loaders import SlackDirectoryLoader
def test_slack_directory_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_spreedly.py b/libs/community/tests/integration_tests/document_loaders/test_spreedly.py
similarity index 77%
rename from libs/langchain/tests/integration_tests/document_loaders/test_spreedly.py
rename to libs/community/tests/integration_tests/document_loaders/test_spreedly.py
index bb49802569c..e7dfbc1615c 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_spreedly.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_spreedly.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.spreedly import SpreedlyLoader
+from langchain_community.document_loaders.spreedly import SpreedlyLoader
def test_spreedly_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_stripe.py b/libs/community/tests/integration_tests/document_loaders/test_stripe.py
similarity index 72%
rename from libs/langchain/tests/integration_tests/document_loaders/test_stripe.py
rename to libs/community/tests/integration_tests/document_loaders/test_stripe.py
index e8484ab6e05..0f9fda9e9e8 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_stripe.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_stripe.py
@@ -1,4 +1,4 @@
-from langchain.document_loaders.stripe import StripeLoader
+from langchain_community.document_loaders.stripe import StripeLoader
def test_stripe_loader() -> None:
diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_telegram.py b/libs/community/tests/integration_tests/document_loaders/test_telegram.py
similarity index 91%
rename from libs/langchain/tests/unit_tests/document_loaders/test_telegram.py
rename to libs/community/tests/integration_tests/document_loaders/test_telegram.py
index 4fbe32a7244..df90af025f9 100644
--- a/libs/langchain/tests/unit_tests/document_loaders/test_telegram.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_telegram.py
@@ -2,7 +2,10 @@ from pathlib import Path
import pytest
-from langchain.document_loaders import TelegramChatApiLoader, TelegramChatFileLoader
+from langchain_community.document_loaders import (
+ TelegramChatApiLoader,
+ TelegramChatFileLoader,
+)
def test_telegram_chat_file_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py b/libs/community/tests/integration_tests/document_loaders/test_tensorflow_datasets.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py
rename to libs/community/tests/integration_tests/document_loaders/test_tensorflow_datasets.py
index 82c3dc87809..7498b65e56e 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_tensorflow_datasets.py
@@ -7,7 +7,9 @@ import pytest
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import ValidationError
-from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
+from langchain_community.document_loaders.tensorflow_datasets import (
+ TensorflowDatasetLoader,
+)
if TYPE_CHECKING:
import tensorflow as tf # noqa: E402
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_tsv.py b/libs/community/tests/integration_tests/document_loaders/test_tsv.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_tsv.py
rename to libs/community/tests/integration_tests/document_loaders/test_tsv.py
index 2834fc61c36..3321ca3af1c 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_tsv.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_tsv.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredTSVLoader
+from langchain_community.document_loaders import UnstructuredTSVLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_unstructured.py b/libs/community/tests/integration_tests/document_loaders/test_unstructured.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/document_loaders/test_unstructured.py
rename to libs/community/tests/integration_tests/document_loaders/test_unstructured.py
index 735f3ee0b14..bb1d809ca5d 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_unstructured.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_unstructured.py
@@ -2,7 +2,7 @@ import os
from contextlib import ExitStack
from pathlib import Path
-from langchain.document_loaders import (
+from langchain_community.document_loaders import (
UnstructuredAPIFileIOLoader,
UnstructuredAPIFileLoader,
UnstructuredFileLoader,
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_url.py b/libs/community/tests/integration_tests/document_loaders/test_url.py
similarity index 86%
rename from libs/langchain/tests/integration_tests/document_loaders/test_url.py
rename to libs/community/tests/integration_tests/document_loaders/test_url.py
index f61a8114124..8c8e15c8f55 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_url.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_url.py
@@ -1,6 +1,6 @@
import pytest
-from langchain.document_loaders import UnstructuredURLLoader
+from langchain_community.document_loaders import UnstructuredURLLoader
def test_continue_on_failure_true() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_url_playwright.py b/libs/community/tests/integration_tests/document_loaders/test_url_playwright.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/document_loaders/test_url_playwright.py
rename to libs/community/tests/integration_tests/document_loaders/test_url_playwright.py
index 22e0cac01ab..ee70736f0d0 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_url_playwright.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_url_playwright.py
@@ -1,8 +1,8 @@
"""Tests for the Playwright URL loader"""
from typing import TYPE_CHECKING
-from langchain.document_loaders import PlaywrightURLLoader
-from langchain.document_loaders.url_playwright import PlaywrightEvaluator
+from langchain_community.document_loaders import PlaywrightURLLoader
+from langchain_community.document_loaders.url_playwright import PlaywrightEvaluator
if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_whatsapp_chat.py b/libs/community/tests/integration_tests/document_loaders/test_whatsapp_chat.py
similarity index 92%
rename from libs/langchain/tests/integration_tests/document_loaders/test_whatsapp_chat.py
rename to libs/community/tests/integration_tests/document_loaders/test_whatsapp_chat.py
index be59d4f2f22..2884026ceb5 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_whatsapp_chat.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_whatsapp_chat.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from langchain.document_loaders import WhatsAppChatLoader
+from langchain_community.document_loaders import WhatsAppChatLoader
def test_whatsapp_chat_loader() -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py b/libs/community/tests/integration_tests/document_loaders/test_wikipedia.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py
rename to libs/community/tests/integration_tests/document_loaders/test_wikipedia.py
index 929b68f402a..63e9c84d852 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_wikipedia.py
@@ -3,7 +3,7 @@ from typing import List
from langchain_core.documents import Document
-from langchain.document_loaders import WikipediaLoader
+from langchain_community.document_loaders import WikipediaLoader
def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_xml.py b/libs/community/tests/integration_tests/document_loaders/test_xml.py
similarity index 83%
rename from libs/langchain/tests/integration_tests/document_loaders/test_xml.py
rename to libs/community/tests/integration_tests/document_loaders/test_xml.py
index a4ea69e728d..133c863d316 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_xml.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_xml.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from langchain.document_loaders import UnstructuredXMLLoader
+from langchain_community.document_loaders import UnstructuredXMLLoader
EXAMPLE_DIRECTORY = file_path = Path(__file__).parent.parent / "examples"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py b/libs/community/tests/integration_tests/document_loaders/test_xorbits.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py
rename to libs/community/tests/integration_tests/document_loaders/test_xorbits.py
index 9f6407e1062..c030684b788 100644
--- a/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_xorbits.py
@@ -1,7 +1,7 @@
import pytest
from langchain_core.documents import Document
-from langchain.document_loaders import XorbitsLoader
+from langchain_community.document_loaders import XorbitsLoader
try:
import xorbits # noqa: F401
diff --git a/libs/langchain/tests/integration_tests/embeddings/__init__.py b/libs/community/tests/integration_tests/embeddings/__init__.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/embeddings/__init__.py
rename to libs/community/tests/integration_tests/embeddings/__init__.py
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_awa.py b/libs/community/tests/integration_tests/embeddings/test_awa.py
similarity index 89%
rename from libs/langchain/tests/integration_tests/embeddings/test_awa.py
rename to libs/community/tests/integration_tests/embeddings/test_awa.py
index bbe5e3b30f0..68a71085cd3 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_awa.py
+++ b/libs/community/tests/integration_tests/embeddings/test_awa.py
@@ -1,5 +1,5 @@
"""Test Awa Embedding"""
-from langchain.embeddings.awa import AwaEmbeddings
+from langchain_community.embeddings.awa import AwaEmbeddings
def test_awa_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_azure_openai.py b/libs/community/tests/integration_tests/embeddings/test_azure_openai.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/embeddings/test_azure_openai.py
rename to libs/community/tests/integration_tests/embeddings/test_azure_openai.py
index 2ab52ca7d63..ec1e5f47209 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_azure_openai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_azure_openai.py
@@ -5,7 +5,7 @@ from typing import Any
import numpy as np
import pytest
-from langchain.embeddings import AzureOpenAIEmbeddings
+from langchain_community.embeddings import AzureOpenAIEmbeddings
OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_bookend.py b/libs/community/tests/integration_tests/embeddings/test_bookend.py
similarity index 92%
rename from libs/langchain/tests/integration_tests/embeddings/test_bookend.py
rename to libs/community/tests/integration_tests/embeddings/test_bookend.py
index 940f6706380..c15036f14bd 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_bookend.py
+++ b/libs/community/tests/integration_tests/embeddings/test_bookend.py
@@ -1,5 +1,5 @@
"""Test Bookend AI embeddings."""
-from langchain.embeddings.bookend import BookendEmbeddings
+from langchain_community.embeddings.bookend import BookendEmbeddings
def test_bookend_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py b/libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py
similarity index 93%
rename from libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py
rename to libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py
index 24ac0313717..55261f51725 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py
@@ -2,7 +2,9 @@
import responses
-from langchain.embeddings.cloudflare_workersai import CloudflareWorkersAIEmbeddings
+from langchain_community.embeddings.cloudflare_workersai import (
+ CloudflareWorkersAIEmbeddings,
+)
@responses.activate
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_cohere.py b/libs/community/tests/integration_tests/embeddings/test_cohere.py
similarity index 88%
rename from libs/langchain/tests/integration_tests/embeddings/test_cohere.py
rename to libs/community/tests/integration_tests/embeddings/test_cohere.py
index 4e2aec50d23..3ca7fd6f2ce 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_cohere.py
+++ b/libs/community/tests/integration_tests/embeddings/test_cohere.py
@@ -1,5 +1,5 @@
"""Test cohere embeddings."""
-from langchain.embeddings.cohere import CohereEmbeddings
+from langchain_community.embeddings.cohere import CohereEmbeddings
def test_cohere_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_dashscope.py b/libs/community/tests/integration_tests/embeddings/test_dashscope.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/embeddings/test_dashscope.py
rename to libs/community/tests/integration_tests/embeddings/test_dashscope.py
index f61c3805e1a..4c189c53550 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_dashscope.py
+++ b/libs/community/tests/integration_tests/embeddings/test_dashscope.py
@@ -1,7 +1,7 @@
"""Test dashscope embeddings."""
import numpy as np
-from langchain.embeddings.dashscope import DashScopeEmbeddings
+from langchain_community.embeddings.dashscope import DashScopeEmbeddings
def test_dashscope_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_deepinfra.py b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py
similarity index 90%
rename from libs/langchain/tests/integration_tests/embeddings/test_deepinfra.py
rename to libs/community/tests/integration_tests/embeddings/test_deepinfra.py
index 17099615322..8b3fe25e667 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_deepinfra.py
+++ b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py
@@ -1,6 +1,6 @@
"""Test DeepInfra API wrapper."""
-from langchain.embeddings import DeepInfraEmbeddings
+from langchain_community.embeddings import DeepInfraEmbeddings
def test_deepinfra_call() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_edenai.py b/libs/community/tests/integration_tests/embeddings/test_edenai.py
similarity index 90%
rename from libs/langchain/tests/integration_tests/embeddings/test_edenai.py
rename to libs/community/tests/integration_tests/embeddings/test_edenai.py
index bfc392da0b9..f200e30d35f 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_edenai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_edenai.py
@@ -1,6 +1,6 @@
"""Test edenai embeddings."""
-from langchain.embeddings.edenai import EdenAiEmbeddings
+from langchain_community.embeddings.edenai import EdenAiEmbeddings
def test_edenai_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_elasticsearch.py b/libs/community/tests/integration_tests/embeddings/test_elasticsearch.py
similarity index 92%
rename from libs/langchain/tests/integration_tests/embeddings/test_elasticsearch.py
rename to libs/community/tests/integration_tests/embeddings/test_elasticsearch.py
index 2c01683113e..8547eb691b1 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_elasticsearch.py
+++ b/libs/community/tests/integration_tests/embeddings/test_elasticsearch.py
@@ -2,7 +2,7 @@
import pytest
-from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
+from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings
@pytest.fixture
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_embaas.py b/libs/community/tests/integration_tests/embeddings/test_embaas.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/embeddings/test_embaas.py
rename to libs/community/tests/integration_tests/embeddings/test_embaas.py
index 8a13f4d9965..27d52f189a7 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_embaas.py
+++ b/libs/community/tests/integration_tests/embeddings/test_embaas.py
@@ -1,7 +1,7 @@
"""Test embaas embeddings."""
import responses
-from langchain.embeddings.embaas import EMBAAS_API_URL, EmbaasEmbeddings
+from langchain_community.embeddings.embaas import EMBAAS_API_URL, EmbaasEmbeddings
def test_embaas_embed_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_ernie.py b/libs/community/tests/integration_tests/embeddings/test_ernie.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/embeddings/test_ernie.py
rename to libs/community/tests/integration_tests/embeddings/test_ernie.py
index 9f47f1572fd..0a261fe30bd 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_ernie.py
+++ b/libs/community/tests/integration_tests/embeddings/test_ernie.py
@@ -1,6 +1,6 @@
import pytest
-from langchain.embeddings.ernie import ErnieEmbeddings
+from langchain_community.embeddings.ernie import ErnieEmbeddings
def test_embedding_documents_1() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_fastembed.py b/libs/community/tests/integration_tests/embeddings/test_fastembed.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/embeddings/test_fastembed.py
rename to libs/community/tests/integration_tests/embeddings/test_fastembed.py
index d690037ab77..80215e90563 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_fastembed.py
+++ b/libs/community/tests/integration_tests/embeddings/test_fastembed.py
@@ -1,7 +1,7 @@
"""Test FastEmbed embeddings."""
import pytest
-from langchain.embeddings.fastembed import FastEmbedEmbeddings
+from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
@pytest.mark.parametrize(
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_google_palm.py b/libs/community/tests/integration_tests/embeddings/test_google_palm.py
similarity index 92%
rename from libs/langchain/tests/integration_tests/embeddings/test_google_palm.py
rename to libs/community/tests/integration_tests/embeddings/test_google_palm.py
index 251c266cd9b..047edc76e16 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_google_palm.py
+++ b/libs/community/tests/integration_tests/embeddings/test_google_palm.py
@@ -3,7 +3,7 @@
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
-from langchain.embeddings.google_palm import GooglePalmEmbeddings
+from langchain_community.embeddings.google_palm import GooglePalmEmbeddings
def test_google_palm_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_huggingface.py b/libs/community/tests/integration_tests/embeddings/test_huggingface.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/embeddings/test_huggingface.py
rename to libs/community/tests/integration_tests/embeddings/test_huggingface.py
index 9558d3a0ced..c780a9f3c5e 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_huggingface.py
+++ b/libs/community/tests/integration_tests/embeddings/test_huggingface.py
@@ -1,6 +1,6 @@
"""Test huggingface embeddings."""
-from langchain.embeddings.huggingface import (
+from langchain_community.embeddings.huggingface import (
HuggingFaceEmbeddings,
HuggingFaceInstructEmbeddings,
)
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_huggingface_hub.py b/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py
similarity index 92%
rename from libs/langchain/tests/integration_tests/embeddings/test_huggingface_hub.py
rename to libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py
index 42dd55dbe63..b22bc940887 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_huggingface_hub.py
+++ b/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py
@@ -1,7 +1,7 @@
"""Test HuggingFaceHub embeddings."""
import pytest
-from langchain.embeddings import HuggingFaceHubEmbeddings
+from langchain_community.embeddings import HuggingFaceHubEmbeddings
def test_huggingfacehub_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_jina.py b/libs/community/tests/integration_tests/embeddings/test_jina.py
similarity index 89%
rename from libs/langchain/tests/integration_tests/embeddings/test_jina.py
rename to libs/community/tests/integration_tests/embeddings/test_jina.py
index 668c258f6f5..4c1dcca9b12 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_jina.py
+++ b/libs/community/tests/integration_tests/embeddings/test_jina.py
@@ -1,5 +1,5 @@
"""Test jina embeddings."""
-from langchain.embeddings.jina import JinaEmbeddings
+from langchain_community.embeddings.jina import JinaEmbeddings
def test_jina_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_johnsnowlabs.py b/libs/community/tests/integration_tests/embeddings/test_johnsnowlabs.py
similarity index 87%
rename from libs/langchain/tests/integration_tests/embeddings/test_johnsnowlabs.py
rename to libs/community/tests/integration_tests/embeddings/test_johnsnowlabs.py
index 3def60b56e7..157853f586b 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_johnsnowlabs.py
+++ b/libs/community/tests/integration_tests/embeddings/test_johnsnowlabs.py
@@ -1,6 +1,6 @@
"""Test johnsnowlabs embeddings."""
-from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
+from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
def test_johnsnowlabs_embed_document() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_llamacpp.py b/libs/community/tests/integration_tests/embeddings/test_llamacpp.py
similarity index 95%
rename from libs/langchain/tests/integration_tests/embeddings/test_llamacpp.py
rename to libs/community/tests/integration_tests/embeddings/test_llamacpp.py
index 36aed8e9f8e..d604f8aa107 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_llamacpp.py
+++ b/libs/community/tests/integration_tests/embeddings/test_llamacpp.py
@@ -3,7 +3,7 @@
import os
from urllib.request import urlretrieve
-from langchain.embeddings.llamacpp import LlamaCppEmbeddings
+from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
def get_model() -> str:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_modelscope_hub.py b/libs/community/tests/integration_tests/embeddings/test_modelscope_hub.py
similarity index 87%
rename from libs/langchain/tests/integration_tests/embeddings/test_modelscope_hub.py
rename to libs/community/tests/integration_tests/embeddings/test_modelscope_hub.py
index 103568af889..de4ed78f5da 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_modelscope_hub.py
+++ b/libs/community/tests/integration_tests/embeddings/test_modelscope_hub.py
@@ -1,5 +1,5 @@
"""Test modelscope embeddings."""
-from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
+from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings
def test_modelscope_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py b/libs/community/tests/integration_tests/embeddings/test_mosaicml.py
similarity index 96%
rename from libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py
rename to libs/community/tests/integration_tests/embeddings/test_mosaicml.py
index ae0bec3ddac..902d9315a58 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py
+++ b/libs/community/tests/integration_tests/embeddings/test_mosaicml.py
@@ -1,5 +1,5 @@
"""Test mosaicml embeddings."""
-from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
+from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings
def test_mosaicml_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_octoai_embeddings.py b/libs/community/tests/integration_tests/embeddings/test_octoai_embeddings.py
similarity index 93%
rename from libs/langchain/tests/integration_tests/embeddings/test_octoai_embeddings.py
rename to libs/community/tests/integration_tests/embeddings/test_octoai_embeddings.py
index 5cd9aea5dc3..05cd3925521 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_octoai_embeddings.py
+++ b/libs/community/tests/integration_tests/embeddings/test_octoai_embeddings.py
@@ -1,6 +1,6 @@
"""Test octoai embeddings."""
-from langchain.embeddings.octoai_embeddings import (
+from langchain_community.embeddings.octoai_embeddings import (
OctoAIEmbeddings,
)
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_openai.py b/libs/community/tests/integration_tests/embeddings/test_openai.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/embeddings/test_openai.py
rename to libs/community/tests/integration_tests/embeddings/test_openai.py
index 8fab5789ccf..f59fb9f55df 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_openai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_openai.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_community.embeddings.openai import OpenAIEmbeddings
@pytest.mark.scheduled
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py
similarity index 87%
rename from libs/langchain/tests/integration_tests/embeddings/test_qianfan_endpoint.py
rename to libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py
index 5c707bcc2f7..f257f61a021 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_qianfan_endpoint.py
+++ b/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py
@@ -1,5 +1,7 @@
"""Test Baidu Qianfan Embedding Endpoint."""
-from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
+from langchain_community.embeddings.baidu_qianfan_endpoint import (
+ QianfanEmbeddingsEndpoint,
+)
def test_embedding_multiple_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_self_hosted.py b/libs/community/tests/integration_tests/embeddings/test_self_hosted.py
similarity index 98%
rename from libs/langchain/tests/integration_tests/embeddings/test_self_hosted.py
rename to libs/community/tests/integration_tests/embeddings/test_self_hosted.py
index cb317d11c28..a270c2ce6b7 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_self_hosted.py
+++ b/libs/community/tests/integration_tests/embeddings/test_self_hosted.py
@@ -1,7 +1,7 @@
"""Test self-hosted embeddings."""
from typing import Any
-from langchain.embeddings import (
+from langchain_community.embeddings import (
SelfHostedEmbeddings,
SelfHostedHuggingFaceEmbeddings,
SelfHostedHuggingFaceInstructEmbeddings,
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_sentence_transformer.py b/libs/community/tests/integration_tests/embeddings/test_sentence_transformer.py
similarity index 89%
rename from libs/langchain/tests/integration_tests/embeddings/test_sentence_transformer.py
rename to libs/community/tests/integration_tests/embeddings/test_sentence_transformer.py
index ce253ef49c8..3890b2fd3c8 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_sentence_transformer.py
+++ b/libs/community/tests/integration_tests/embeddings/test_sentence_transformer.py
@@ -1,8 +1,10 @@
# flake8: noqa
"""Test sentence_transformer embeddings."""
-from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
-from langchain.vectorstores import Chroma
+from langchain_community.embeddings.sentence_transformer import (
+ SentenceTransformerEmbeddings,
+)
+from langchain_community.vectorstores import Chroma
def test_sentence_transformer_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_tensorflow_hub.py b/libs/community/tests/integration_tests/embeddings/test_tensorflow_hub.py
similarity index 89%
rename from libs/langchain/tests/integration_tests/embeddings/test_tensorflow_hub.py
rename to libs/community/tests/integration_tests/embeddings/test_tensorflow_hub.py
index 96bb007361f..be8062228da 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_tensorflow_hub.py
+++ b/libs/community/tests/integration_tests/embeddings/test_tensorflow_hub.py
@@ -1,5 +1,5 @@
"""Test TensorflowHub embeddings."""
-from langchain.embeddings import TensorflowHubEmbeddings
+from langchain_community.embeddings import TensorflowHubEmbeddings
def test_tensorflowhub_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py b/libs/community/tests/integration_tests/embeddings/test_vertexai.py
similarity index 94%
rename from libs/langchain/tests/integration_tests/embeddings/test_vertexai.py
rename to libs/community/tests/integration_tests/embeddings/test_vertexai.py
index bdfe1d24b26..10128469f2e 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_vertexai.py
@@ -5,7 +5,7 @@ pip install google-cloud-aiplatform>=1.35.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
-from langchain.embeddings import VertexAIEmbeddings
+from langchain_community.embeddings import VertexAIEmbeddings
def test_embedding_documents() -> None:
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_voyageai.py b/libs/community/tests/integration_tests/embeddings/test_voyageai.py
similarity index 93%
rename from libs/langchain/tests/integration_tests/embeddings/test_voyageai.py
rename to libs/community/tests/integration_tests/embeddings/test_voyageai.py
index 623ea4551b1..b23dbd7f538 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_voyageai.py
+++ b/libs/community/tests/integration_tests/embeddings/test_voyageai.py
@@ -1,5 +1,5 @@
"""Test voyage embeddings."""
-from langchain.embeddings.voyageai import VoyageEmbeddings
+from langchain_community.embeddings.voyageai import VoyageEmbeddings
# Please set VOYAGE_API_KEY in the environment variables
MODEL = "voyage-01"
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_xinference.py b/libs/community/tests/integration_tests/embeddings/test_xinference.py
similarity index 97%
rename from libs/langchain/tests/integration_tests/embeddings/test_xinference.py
rename to libs/community/tests/integration_tests/embeddings/test_xinference.py
index f488d497f7c..f09fe8fe4f1 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_xinference.py
+++ b/libs/community/tests/integration_tests/embeddings/test_xinference.py
@@ -4,7 +4,7 @@ from typing import AsyncGenerator, Tuple
import pytest_asyncio
-from langchain.embeddings import XinferenceEmbeddings
+from langchain_community.embeddings import XinferenceEmbeddings
@pytest_asyncio.fixture
diff --git a/libs/community/tests/integration_tests/examples/README.org b/libs/community/tests/integration_tests/examples/README.org
new file mode 100644
index 00000000000..5b9f4728040
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/README.org
@@ -0,0 +1,27 @@
+* Example Docs
+
+The sample docs directory contains the following files:
+
+- ~example-10k.html~ - A 10-K SEC filing in HTML format
+- ~layout-parser-paper.pdf~ - A PDF copy of the layout parser paper
+- ~factbook.xml~ / ~factbook.xsl~ - Example XML/XSL files that you
+ can use to test stylesheets
+
+These documents can be used to test out the parsers in the library. In
+addition, here are instructions for pulling in some sample docs that are
+too big to store in the repo.
+
+** XBRL 10-K
+
+You can get an example 10-K in inline XBRL format using the following
+~curl~. Note, you need to have the user agent set in the header or the
+SEC site will reject your request.
+
+#+BEGIN_SRC bash
+
+ curl -O \
+ -A '${organization} ${email}' \
+ https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt
+#+END_SRC
+
+You can parse this document using the HTML parser.
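
The README ends with "You can parse this document using the HTML parser"; a minimal sketch of doing that with the community loader, assuming the curl command above was run in the current directory (the loader choice and the output handling are assumptions, not part of the README):

    from langchain_community.document_loaders import UnstructuredHTMLLoader

    # `curl -O` saves the filing under its remote file name (the SEC accession number).
    loader = UnstructuredHTMLLoader("0001171843-21-001344.txt")
    docs = loader.load()
    print(len(docs), docs[0].metadata)
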
diff --git a/libs/community/tests/integration_tests/examples/README.rst b/libs/community/tests/integration_tests/examples/README.rst
new file mode 100644
index 00000000000..45630d0385d
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/README.rst
@@ -0,0 +1,28 @@
+Example Docs
+------------
+
+The sample docs directory contains the following files:
+
+- ``example-10k.html`` - A 10-K SEC filing in HTML format
+- ``layout-parser-paper.pdf`` - A PDF copy of the layout parser paper
+- ``factbook.xml``/``factbook.xsl`` - Example XML/XSL files that you
+ can use to test stylesheets
+
+These documents can be used to test out the parsers in the library. In
+addition, here are instructions for pulling in some sample docs that are
+too big to store in the repo.
+
+XBRL 10-K
+^^^^^^^^^
+
+You can get an example 10-K in inline XBRL format using the following
+``curl``. Note, you need to have the user agent set in the header or the
+SEC site will reject your request.
+
+.. code:: bash
+
+ curl -O \
+ -A '${organization} ${email}' \
+ https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt
+
+You can parse this document using the HTML parser.
diff --git a/libs/community/tests/integration_tests/examples/brandfetch-brandfetch-2.0.0-resolved.json b/libs/community/tests/integration_tests/examples/brandfetch-brandfetch-2.0.0-resolved.json
new file mode 100644
index 00000000000..de37dbf5fba
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/brandfetch-brandfetch-2.0.0-resolved.json
@@ -0,0 +1,282 @@
+{
+ "openapi": "3.0.1",
+ "info": {
+ "title": "Brandfetch API",
+ "description": "Brandfetch API (v2) for retrieving brand information.\n\nSee our [documentation](https://docs.brandfetch.com/) for further details. ",
+ "termsOfService": "https://brandfetch.com/terms",
+ "contact": {
+ "url": "https://brandfetch.com/developers"
+ },
+ "version": "2.0.0"
+ },
+ "externalDocs": {
+ "description": "Documentation",
+ "url": "https://docs.brandfetch.com/"
+ },
+ "servers": [
+ {
+ "url": "https://api.brandfetch.io/v2"
+ }
+ ],
+ "paths": {
+ "/brands/{domainOrId}": {
+ "get": {
+ "summary": "Retrieve a brand",
+ "description": "Fetch brand information by domain or ID\n\nFurther details here: https://docs.brandfetch.com/reference/retrieve-brand\n",
+ "parameters": [
+ {
+ "name": "domainOrId",
+ "in": "path",
+ "description": "Domain or ID of the brand",
+ "required": true,
+ "style": "simple",
+ "explode": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "Brand data",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Brand"
+ },
+ "examples": {
+ "brandfetch.com": {
+ "value": "{\"name\":\"Brandfetch\",\"domain\":\"brandfetch.com\",\"claimed\":true,\"description\":\"All brands. In one place\",\"links\":[{\"name\":\"twitter\",\"url\":\"https://twitter.com/brandfetch\"},{\"name\":\"linkedin\",\"url\":\"https://linkedin.com/company/brandfetch\"}],\"logos\":[{\"type\":\"logo\",\"theme\":\"light\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/id9WE9j86h.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":15555}]},{\"type\":\"logo\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idWbsK1VCy.png\",\"background\":\"transparent\",\"format\":\"png\",\"height\":215,\"width\":800,\"size\":33937},{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idtCMfbWO0.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"height\":null,\"width\":null,\"size\":15567}]},{\"type\":\"symbol\",\"theme\":\"light\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idXGq6SIu2.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":2215}]},{\"type\":\"symbol\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/iddCQ52AR5.svg\",\"background\":\"transparent\",\"format\":\"svg\",\"size\":2215}]},{\"type\":\"icon\",\"theme\":\"dark\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idls3LaPPQ.png\",\"background\":null,\"format\":\"png\",\"height\":400,\"width\":400,\"size\":2565}]}],\"colors\":[{\"hex\":\"#0084ff\",\"type\":\"accent\",\"brightness\":113},{\"hex\":\"#00193E\",\"type\":\"brand\",\"brightness\":22},{\"hex\":\"#F03063\",\"type\":\"brand\",\"brightness\":93},{\"hex\":\"#7B0095\",\"type\":\"brand\",\"brightness\":37},{\"hex\":\"#76CC4B\",\"type\":\"brand\",\"brightness\":176},{\"hex\":\"#FFDA00\",\"type\":\"brand\",\"brightness\":210},{\"hex\":\"#000000\",\"type\":\"dark\",\"brightness\":0},{\"hex\":\"#ffffff\",\"type\":\"light\",\"brightness\":255}],\"fonts\":[{\"name\":\"Poppins\",\"type\":\"title\",\"origin\":\"google\",\"originId\":\"Poppins\",\"weights\":[]},{\"name\":\"Inter\",\"type\":\"body\",\"origin\":\"google\",\"originId\":\"Inter\",\"weights\":[]}],\"images\":[{\"type\":\"banner\",\"formats\":[{\"src\":\"https://asset.brandfetch.io/idL0iThUh6/idUuia5imo.png\",\"background\":\"transparent\",\"format\":\"png\",\"height\":500,\"width\":1500,\"size\":5539}]}]}"
+ }
+ }
+ }
+ }
+ },
+ "400": {
+ "description": "Invalid domain or ID supplied"
+ },
+ "404": {
+ "description": "The brand does not exist or the domain can't be resolved."
+ }
+ },
+ "security": [
+ {
+ "bearerAuth": []
+ }
+ ]
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "Brand": {
+ "required": [
+ "claimed",
+ "colors",
+ "description",
+ "domain",
+ "fonts",
+ "images",
+ "links",
+ "logos",
+ "name"
+ ],
+ "type": "object",
+ "properties": {
+ "images": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageAsset"
+ }
+ },
+ "fonts": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/FontAsset"
+ }
+ },
+ "domain": {
+ "type": "string"
+ },
+ "claimed": {
+ "type": "boolean"
+ },
+ "name": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "links": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Brand_links"
+ }
+ },
+ "logos": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageAsset"
+ }
+ },
+ "colors": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ColorAsset"
+ }
+ }
+ },
+ "description": "Object representing a brand"
+ },
+ "ColorAsset": {
+ "required": [
+ "brightness",
+ "hex",
+ "type"
+ ],
+ "type": "object",
+ "properties": {
+ "brightness": {
+ "type": "integer"
+ },
+ "hex": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "enum": [
+ "accent",
+ "brand",
+ "customizable",
+ "dark",
+ "light",
+ "vibrant"
+ ]
+ }
+ },
+ "description": "Brand color asset"
+ },
+ "FontAsset": {
+ "type": "object",
+ "properties": {
+ "originId": {
+ "type": "string"
+ },
+ "origin": {
+ "type": "string",
+ "enum": [
+ "adobe",
+ "custom",
+ "google",
+ "system"
+ ]
+ },
+ "name": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string"
+ },
+ "weights": {
+ "type": "array",
+ "items": {
+ "type": "number"
+ }
+ },
+ "items": {
+ "type": "string"
+ }
+ },
+ "description": "Brand font asset"
+ },
+ "ImageAsset": {
+ "required": [
+ "formats",
+ "theme",
+ "type"
+ ],
+ "type": "object",
+ "properties": {
+ "formats": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ImageFormat"
+ }
+ },
+ "theme": {
+ "type": "string",
+ "enum": [
+ "light",
+ "dark"
+ ]
+ },
+ "type": {
+ "type": "string",
+ "enum": [
+ "logo",
+ "icon",
+ "symbol",
+ "banner"
+ ]
+ }
+ },
+ "description": "Brand image asset"
+ },
+ "ImageFormat": {
+ "required": [
+ "background",
+ "format",
+ "size",
+ "src"
+ ],
+ "type": "object",
+ "properties": {
+ "size": {
+ "type": "integer"
+ },
+ "src": {
+ "type": "string"
+ },
+ "background": {
+ "type": "string",
+ "enum": [
+ "transparent"
+ ]
+ },
+ "format": {
+ "type": "string"
+ },
+ "width": {
+ "type": "integer"
+ },
+ "height": {
+ "type": "integer"
+ }
+ },
+ "description": "Brand image asset image format"
+ },
+ "Brand_links": {
+ "required": [
+ "name",
+ "url"
+ ],
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "url": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "securitySchemes": {
+ "bearerAuth": {
+ "type": "http",
+ "scheme": "bearer",
+ "bearerFormat": "API Key"
+ }
+ }
+ }
+}
\ No newline at end of file
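
For reference, a minimal sketch of exercising the one endpoint this spec describes (GET /brands/{domainOrId} on https://api.brandfetch.io/v2 with bearer auth); the environment variable name and example domain are placeholders, not defined by the spec:

    import os

    import requests

    token = os.environ.get("BRANDFETCH_API_KEY", "<api-key>")  # placeholder variable name
    domain_or_id = "brandfetch.com"  # the spec's own example brand

    resp = requests.get(
        f"https://api.brandfetch.io/v2/brands/{domain_or_id}",
        headers={"Authorization": f"Bearer {token}"},
        timeout=30,
    )
    resp.raise_for_status()  # per the spec: 400 = invalid domain/ID, 404 = brand not found
    brand = resp.json()
    print(brand["name"], [c["hex"] for c in brand["colors"]])
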
diff --git a/libs/community/tests/integration_tests/examples/default-encoding.py b/libs/community/tests/integration_tests/examples/default-encoding.py
new file mode 100644
index 00000000000..9a09cc8271f
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/default-encoding.py
@@ -0,0 +1 @@
+u = "π¦π"
diff --git a/libs/community/tests/integration_tests/examples/docusaurus-sitemap.xml b/libs/community/tests/integration_tests/examples/docusaurus-sitemap.xml
new file mode 100644
index 00000000000..eebae785b88
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/docusaurus-sitemap.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url>
+    <loc>https://python.langchain.com/docs/integrations/document_loaders/sitemap</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/cookbook</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/additional_resources</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/modules/chains/how_to/</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/use_cases/question_answering/local_retrieval_qa</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/docs/use_cases/summarization</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+  <url>
+    <loc>https://python.langchain.com/</loc>
+    <changefreq>weekly</changefreq>
+    <priority>0.5</priority>
+  </url>
+</urlset>
\ No newline at end of file
diff --git a/libs/community/tests/integration_tests/examples/duplicate-chars.pdf b/libs/community/tests/integration_tests/examples/duplicate-chars.pdf
new file mode 100644
index 00000000000..47467cd035d
Binary files /dev/null and b/libs/community/tests/integration_tests/examples/duplicate-chars.pdf differ
diff --git a/libs/community/tests/integration_tests/examples/example-utf8.html b/libs/community/tests/integration_tests/examples/example-utf8.html
new file mode 100644
index 00000000000..f96e20fcedb
--- /dev/null
+++ b/libs/community/tests/integration_tests/examples/example-utf8.html
@@ -0,0 +1,25 @@
+
+
+ Chew dad's slippers
+
+
+
+ Instead of drinking water from the cat bowl, make sure to steal water from
+ the toilet
+
+
+ Chase the red dot
+
+ Munch, munch, chomp, chomp hate dogs. Spill litter box, scratch at owner,
+ destroy all furniture, especially couch get scared by sudden appearance of
+ cucumber cat is love, cat is life fat baby cat best buddy little guy for
+ catch eat throw up catch eat throw up bad birds jump on fridge. Purr like
+ a car engine oh yes, there is my human woman she does best pats ever that
+ all i like about her hiss meow .
+
+
+ Dead stare with ears cocked when “owners” are asleep, cry for no apparent
+ reason meow all night. Plop down in the middle where everybody walks favor
+ packaging over toy. Sit on the laptop kitty pounce, trip, faceplant.
+
+ Instead of drinking water from the cat bowl, make sure to steal water from
+ the toilet
+
+
+ Chase the red dot
+
+ Munch, munch, chomp, chomp hate dogs. Spill litter box, scratch at owner,
+ destroy all furniture, especially couch get scared by sudden appearance of
+ cucumber cat is love, cat is life fat baby cat best buddy little guy for
+ catch eat throw up catch eat throw up bad birds jump on fridge. Purr like
+ a car engine oh yes, there is my human woman she does best pats ever that
+ all i like about her hiss meow .
+
+
+ Dead stare with ears cocked when owners are asleep, cry for no apparent
+ reason meow all night. Plop down in the middle where everybody walks favor
+ packaging over toy. Sit on the laptop kitty pounce, trip, faceplant.
+