diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 8a39944be09..3a8a7ebd9be 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -70,17 +70,18 @@ Install Poetry: **[documentation on how to install it](https://python-poetry.org ❗Note: If you use `Conda` or `Pyenv` as your environment/package manager, after installing Poetry, tell Poetry to use the virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`) -### Core vs. Experimental +### Different packages -This repository contains three separate projects: -- `langchain`: core langchain code, abstractions, and use cases. -- `langchain_core`: contain interfaces for key abstractions as well as logic for combining them in chains (LCEL). -- `langchain_experimental`: see the [Experimental README](https://github.com/langchain-ai/langchain/tree/master/libs/experimental/README.md) for more information. +This repository contains multiple packages: +- `langchain-core`: Base interfaces for key abstractions as well as logic for combining them in chains (LangChain Expression Language). +- `langchain-community`: Third-party integrations of various components. +- `langchain`: Chains, agents, and retrieval logic that makes up the cognitive architecture of your applications. +- `langchain-experimental`: Components and chains that are experimental, either in the sense that the techniques are novel and still being tested, or they require giving the LLM more access than would be possible in most production systems. Each of these has its own development environment. Docs are run from the top-level makefile, but development is split across separate test & release flows. -For this quickstart, start with langchain core: +For this quickstart, start with langchain: ```bash cd libs/langchain @@ -330,15 +331,50 @@ what you wanted by clicking the `View deployment` or `Visit Preview` buttons on This will take you to a preview of the documentation changes. This preview is created by [Vercel](https://vercel.com/docs/getting-started-with-vercel). -## 🏭 Release Process +## πŸ“• Releases & Versioning As of now, LangChain has an ad hoc release process: releases are cut with high frequency by -a developer and published to [PyPI](https://pypi.org/project/langchain/). +a maintainer and published to [PyPI](https://pypi.org/). +The different packages are versioned slightly differently. -LangChain follows the [semver](https://semver.org/) versioning standard. However, as pre-1.0 software, -even patch releases may contain [non-backwards-compatible changes](https://semver.org/#spec-item-4). +### `langchain-core` -### 🌟 Recognition +`langchain-core` is currently on version `0.1.x`. + +As `langchain-core` contains the base abstractions and runtime for the whole LangChain ecosystem, we will communicate any breaking changes with advance notice and version bumps. The exception for this is anything in `langchain_core.beta`. The reason for `langchain_core.beta` is that given the rate of change of the field, being able to move quickly is still a priority, and this module is our attempt to do so. + +Minor version increases will occur for: + +- Breaking changes for any public interfaces NOT in `langchain_core.beta` + +Patch version increases will occur for: + +- Bug fixes +- New features +- Any changes to private interfaces +- Any changes to `langchain_core.beta` + +### `langchain` + +`langchain` is currently on version `0.0.x` + +All changes will be accompanied by a patch version increase. 
Any changes to public interfaces are nearly always done in a backwards compatible way and will be communicated ahead of time when they are not backwards compatible. + +We are targeting January 2024 for a release of `langchain` v0.1, at which point `langchain` will adopt the same versioning policy as `langchain-core`. + +### `langchain-community` + +`langchain-community` is currently on version `0.0.x` + +All changes will be accompanied by a patch version increase. + +### `langchain-experimental` + +`langchain-experimental` is currently on version `0.0.x` + +All changes will be accompanied by a patch version increase. + +## 🌟 Recognition If your contribution has made its way into a release, we will want to give you credit on Twitter (only if you want though)! If you have a Twitter account you would like us to mention, please let us know in the PR or through another means. diff --git a/.github/scripts/check_diff.py b/.github/scripts/check_diff.py index 3a93baa588e..939fd9a31b8 100644 --- a/.github/scripts/check_diff.py +++ b/.github/scripts/check_diff.py @@ -5,6 +5,7 @@ ALL_DIRS = { "libs/core", "libs/langchain", "libs/experimental", + "libs/community", } if __name__ == "__main__": diff --git a/.github/workflows/_all_ci.yml b/.github/workflows/_all_ci.yml index 37aacd42653..58bcc80336a 100644 --- a/.github/workflows/_all_ci.yml +++ b/.github/workflows/_all_ci.yml @@ -18,6 +18,7 @@ on: - libs/langchain - libs/core - libs/experimental + - libs/community # If another push to the same PR or branch happens while this workflow is still running, diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml index c94b23ae0d9..e34305aa0bc 100644 --- a/.github/workflows/_lint.yml +++ b/.github/workflows/_lint.yml @@ -85,7 +85,8 @@ jobs: with: path: | ${{ env.WORKDIR }}/.mypy_cache - key: mypy-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }} + key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }} + - name: Analysing the code with our lint working-directory: ${{ inputs.working-directory }} @@ -105,13 +106,13 @@ jobs: run: | poetry install --with test - - name: Get .mypy_cache to speed up mypy + - name: Get .mypy_cache_test to speed up mypy uses: actions/cache@v3 env: SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" with: path: | - ${{ env.WORKDIR }}/.mypy_cache + ${{ env.WORKDIR }}/.mypy_cache_test key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }} - name: Analysing the code with our lint diff --git a/.github/workflows/_release.yml b/.github/workflows/_release.yml index 9c306063447..7eb32ba9f55 100644 --- a/.github/workflows/_release.yml +++ b/.github/workflows/_release.yml @@ -7,6 +7,17 @@ on: required: true type: string description: "From which folder this pipeline executes" + workflow_dispatch: + inputs: + working-directory: + required: true + type: choice + default: 'libs/langchain' + options: + - libs/langchain + - libs/core + - libs/experimental + - libs/community env: PYTHON_VERSION: "3.10" @@ -14,7 +25,7 @@ env: jobs: build: - if: github.ref == 'refs/heads/master' +# if: github.ref == 'refs/heads/master' runs-on: ubuntu-latest outputs: diff --git a/.github/workflows/_test_release.yml b/.github/workflows/_test_release.yml index 0fc25a75164..dc67d1b430e 100644 --- 
a/.github/workflows/_test_release.yml +++ b/.github/workflows/_test_release.yml @@ -14,7 +14,7 @@ env: jobs: build: - if: github.ref == 'refs/heads/master' +# if: github.ref == 'refs/heads/master' runs-on: ubuntu-latest outputs: diff --git a/.github/workflows/langchain_openai_release.yml b/.github/workflows/langchain_openai_release.yml new file mode 100644 index 00000000000..244c292c2e3 --- /dev/null +++ b/.github/workflows/langchain_openai_release.yml @@ -0,0 +1,13 @@ +--- +name: libs/core Release + +on: + workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI + +jobs: + release: + uses: + ./.github/workflows/_release.yml + with: + working-directory: libs/core + secrets: inherit diff --git a/.scripts/community_split/libs/community/langchain_community/__init__.py b/.scripts/community_split/libs/community/langchain_community/__init__.py new file mode 100644 index 00000000000..ac7aeef6faf --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/__init__.py @@ -0,0 +1,9 @@ +"""Main entrypoint into package.""" +from importlib import metadata + +try: + __version__ = metadata.version(__package__) +except metadata.PackageNotFoundError: + # Case where package metadata is not available. + __version__ = "" +del metadata # optional, avoids polluting the results of dir(__package__) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/__init__.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/__init__.py new file mode 100644 index 00000000000..f23fb6a7e82 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/__init__.py @@ -0,0 +1,79 @@ +"""Agent toolkits contain integrations with various resources and services. + +LangChain has a large ecosystem of integrations with various external resources +like local and remote file systems, APIs and databases. + +These integrations allow developers to create versatile applications that combine the +power of LLMs with the ability to access, interact with and manipulate external +resources. + +When developing an application, developers should inspect the capabilities and +permissions of the tools that underlie the given agent toolkit, and determine +whether permissions of the given toolkit are appropriate for the application. + +See [Security](https://python.langchain.com/docs/security) for more information. 
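For example, a minimal sketch of inspecting a toolkit before handing its tools
to an agent (illustrative only; the sandbox path and the choice of read-only
tools are assumptions, not requirements):

.. code-block:: python

    from langchain_community.agent_toolkits import FileManagementToolkit

    # Restrict the toolkit to read-only operations and review each tool's
    # name and description before giving an agent access to it.
    toolkit = FileManagementToolkit(
        root_dir="/tmp/agent-scratch",  # hypothetical sandbox directory
        selected_tools=["read_file", "list_directory"],
    )
    for tool in toolkit.get_tools():
        print(tool.name, "-", tool.description)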
+""" +from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit +from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit +from langchain_community.agent_toolkits.azure_cognitive_services import ( + AzureCognitiveServicesToolkit, +) +from langchain_community.agent_toolkits.conversational_retrieval.openai_functions import ( # noqa: E501 + create_conversational_retrieval_agent, +) +from langchain_community.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) +from langchain_community.agent_toolkits.gmail.toolkit import GmailToolkit +from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit +from langchain_community.agent_toolkits.json.base import create_json_agent +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit +from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit +from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit +from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit +from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit +from langchain_community.agent_toolkits.openapi.base import create_openapi_agent +from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit +from langchain_community.agent_toolkits.playwright.toolkit import ( + PlayWrightBrowserToolkit, +) +from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent +from langchain_community.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.agent_toolkits.slack.toolkit import SlackToolkit +from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent +from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit +from langchain_community.agent_toolkits.sql.base import create_sql_agent +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit +from langchain_community.agent_toolkits.zapier.toolkit import ZapierToolkit + + +__all__ = [ + "AINetworkToolkit", + "AmadeusToolkit", + "AzureCognitiveServicesToolkit", + "FileManagementToolkit", + "GmailToolkit", + "JiraToolkit", + "JsonToolkit", + "MultionToolkit", + "NasaToolkit", + "NLAToolkit", + "O365Toolkit", + "OpenAPIToolkit", + "PlayWrightBrowserToolkit", + "PowerBIToolkit", + "SlackToolkit", + "SteamToolkit", + "SQLDatabaseToolkit", + "SparkSQLToolkit", + "ZapierToolkit", + "create_json_agent", + "create_openapi_agent", + "create_pbi_agent", + "create_pbi_chat_agent", + "create_spark_sql_agent", + "create_sql_agent", + "create_conversational_retrieval_agent", +] diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py new file mode 100644 index 00000000000..e1c2c40bcfb --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any, List, Optional, TYPE_CHECKING + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.memory import BaseMemory +from langchain_core.messages import SystemMessage +from langchain_core.prompts.chat import 
MessagesPlaceholder +from langchain_core.tools import BaseTool + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def _get_default_system_message() -> SystemMessage: + return SystemMessage( + content=( + "Do your best to answer the questions. " + "Feel free to use any tools available to look up " + "relevant information, only if necessary" + ) + ) + +def create_conversational_retrieval_agent( + llm: BaseLanguageModel, + tools: List[BaseTool], + remember_intermediate_steps: bool = True, + memory_key: str = "chat_history", + system_message: Optional[SystemMessage] = None, + verbose: bool = False, + max_token_limit: int = 2000, + **kwargs: Any, +) -> AgentExecutor: + """A convenience method for creating a conversational retrieval agent. + + Args: + llm: The language model to use, should be ChatOpenAI + tools: A list of tools the agent has access to + remember_intermediate_steps: Whether the agent should remember intermediate + steps or not. Intermediate steps refer to prior action/observation + pairs from previous questions. The benefit of remembering these is if + there is relevant information in there, the agent can use it to answer + follow up questions. The downside is it will take up more tokens. + memory_key: The name of the memory key in the prompt. + system_message: The system message to use. By default, a basic one will + be used. + verbose: Whether or not the final AgentExecutor should be verbose or not, + defaults to False. + max_token_limit: The max number of tokens to keep around in memory. + Defaults to 2000. + + Returns: + An agent executor initialized appropriately + """ + from langchain.agents.agent import AgentExecutor + from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( + AgentTokenBufferMemory, + ) + from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent + from langchain.memory.token_buffer import ConversationTokenBufferMemory + + if remember_intermediate_steps: + memory: BaseMemory = AgentTokenBufferMemory( + memory_key=memory_key, llm=llm, max_token_limit=max_token_limit + ) + else: + memory = ConversationTokenBufferMemory( + memory_key=memory_key, + return_messages=True, + output_key="output", + llm=llm, + max_token_limit=max_token_limit, + ) + + _system_message = system_message or _get_default_system_message() + prompt = OpenAIFunctionsAgent.create_prompt( + system_message=_system_message, + extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)], + ) + agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) + return AgentExecutor( + agent=agent, + tools=tools, + memory=memory, + verbose=verbose, + return_intermediate_steps=remember_intermediate_steps, + **kwargs, + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/json/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/json/base.py new file mode 100644 index 00000000000..0d9f10d14c6 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/json/base.py @@ -0,0 +1,53 @@ +"""Json agent.""" +from __future__ import annotations +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit + +if TYPE_CHECKING: + from langchain.agents.agent 
import AgentExecutor + + +def create_json_agent( + llm: BaseLanguageModel, + toolkit: JsonToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = JSON_PREFIX, + suffix: str = JSON_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a json agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + tools = toolkit.get_tools() + prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {} + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/nla/tool.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/nla/tool.py new file mode 100644 index 00000000000..7d94b2f9bca --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/nla/tool.py @@ -0,0 +1,57 @@ +"""Tool for interacting with a single API with natural language definition.""" + +from __future__ import annotations +from typing import Any, Optional, TYPE_CHECKING + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools import Tool + +from langchain_community.tools.openapi.utils.api_models import APIOperation +from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec +from langchain_community.utilities.requests import Requests + +if TYPE_CHECKING: + from langchain.chains.api.openapi.chain import OpenAPIEndpointChain + + +class NLATool(Tool): + """Natural Language API Tool.""" + + @classmethod + def from_open_api_endpoint_chain( + cls, chain: OpenAPIEndpointChain, api_title: str + ) -> "NLATool": + """Convert an endpoint chain to an API endpoint tool.""" + expanded_name = ( + f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}' + ) + description = ( + f"I'm an AI from {api_title}. 
Instruct what you want," + " and I'll assist via an API with description:" + f" {chain.api_operation.description}" + ) + return cls(name=expanded_name, func=chain.run, description=description) + + @classmethod + def from_llm_and_method( + cls, + llm: BaseLanguageModel, + path: str, + method: str, + spec: OpenAPISpec, + requests: Optional[Requests] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + **kwargs: Any, + ) -> "NLATool": + """Instantiate the tool from the specified path and method.""" + api_operation = APIOperation.from_openapi_spec(spec, path, method) + chain = OpenAPIEndpointChain.from_api_operation( + api_operation, + llm, + requests=requests, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + **kwargs, + ) + return cls.from_open_api_endpoint_chain(chain, spec.info.title) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/base.py new file mode 100644 index 00000000000..f8dc14d85f3 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/base.py @@ -0,0 +1,77 @@ +"""OpenAPI spec agent.""" +from __future__ import annotations +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.openapi.prompt import ( + OPENAPI_PREFIX, + OPENAPI_SUFFIX, +) +from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_openapi_agent( + llm: BaseLanguageModel, + toolkit: OpenAPIToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = OPENAPI_PREFIX, + suffix: str = OPENAPI_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + return_intermediate_steps: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct an OpenAPI agent from an LLM and tools. + + *Security Note*: When creating an OpenAPI agent, check the permissions + and capabilities of the underlying toolkit. + + For example, if the default implementation of OpenAPIToolkit + uses the RequestsToolkit which contains tools to make arbitrary + network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE), + + Control access to who can submit issue requests using this toolkit and + what network access it has. + + See https://python.langchain.com/docs/security for more information. 
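    A minimal usage sketch (illustrative only; assumes an OpenAPI spec file on
    disk, an OpenAI API key in the environment, and the community `ChatOpenAI`
    chat model):

    .. code-block:: python

        import yaml

        from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
        from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
        from langchain_community.chat_models import ChatOpenAI
        from langchain_community.tools.json.tool import JsonSpec
        from langchain_community.utilities.requests import TextRequestsWrapper

        with open("openapi.yaml") as f:  # hypothetical spec file
            raw_spec = yaml.safe_load(f)

        llm = ChatOpenAI(temperature=0)
        toolkit = OpenAPIToolkit.from_llm(
            llm, JsonSpec(dict_=raw_spec, max_value_length=4000), TextRequestsWrapper()
        )
        agent = create_openapi_agent(llm, toolkit, verbose=True)
        agent.run("What endpoints does this API expose?")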
+ """ + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + tools = toolkit.get_tools() + prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {} + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/planner.py new file mode 100644 index 00000000000..55702b3bb09 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -0,0 +1,370 @@ +"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach.""" +import json +import re +from functools import partial +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING + +import yaml +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool, Tool +from langchain_community.llms import OpenAI + +from langchain_community.agent_toolkits.openapi.planner_prompt import ( + API_CONTROLLER_PROMPT, + API_CONTROLLER_TOOL_DESCRIPTION, + API_CONTROLLER_TOOL_NAME, + API_ORCHESTRATOR_PROMPT, + API_PLANNER_PROMPT, + API_PLANNER_TOOL_DESCRIPTION, + API_PLANNER_TOOL_NAME, + PARSING_DELETE_PROMPT, + PARSING_GET_PROMPT, + PARSING_PATCH_PROMPT, + PARSING_POST_PROMPT, + PARSING_PUT_PROMPT, + REQUESTS_DELETE_TOOL_DESCRIPTION, + REQUESTS_GET_TOOL_DESCRIPTION, + REQUESTS_PATCH_TOOL_DESCRIPTION, + REQUESTS_POST_TOOL_DESCRIPTION, + REQUESTS_PUT_TOOL_DESCRIPTION, +) +from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec +from langchain_community.tools.requests.tool import BaseRequestsTool +from langchain_community.utilities.requests import RequestsWrapper + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + from langchain.chains.llm import LLMChain + from langchain.memory import ReadOnlySharedMemory + +# +# Requests tools with LLM-instructed extraction of truncated responses. +# +# Of course, truncating so bluntly may lose a lot of valuable +# information in the response. +# However, the goal for now is to have only a single inference step. 
+MAX_RESPONSE_LENGTH = 5000 +"""Maximum length of the response to be returned.""" + + +def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain: + from langchain.chains.llm import LLMChain + return LLMChain( + llm=OpenAI(), + prompt=prompt, + ) + + +def _get_default_llm_chain_factory( + prompt: BasePromptTemplate, +) -> Callable[[], LLMChain]: + """Returns a default LLMChain factory.""" + return partial(_get_default_llm_chain, prompt) + + +class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): + """Requests GET tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_get" + """Tool name.""" + description = REQUESTS_GET_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + data_params = data.get("params") + response = self.requests_wrapper.get(data["url"], params=data_params) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): + """Requests POST tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_post" + """Tool name.""" + description = REQUESTS_POST_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.post(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): + """Requests PATCH tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_patch" + """Tool name.""" + description = REQUESTS_PATCH_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.patch(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class 
RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): + """Requests PUT tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_put" + """Tool name.""" + description = REQUESTS_PUT_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.put(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): + """A tool that sends a DELETE request and parses the response.""" + + name: str = "requests_delete" + """The name of the tool.""" + description = REQUESTS_DELETE_TOOL_DESCRIPTION + """The description of the tool.""" + + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """The maximum length of the response.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT) + ) + """The LLM chain used to parse the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.delete(data["url"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +# +# Orchestrator, planner, controller. 
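# The hierarchy built below is, roughly:
#   * a planner tool turns the user query into a plain-text plan of API calls,
#   * a controller tool spins up a short-lived agent that executes that plan
#     with the GET/POST requests tools, seeing only the docs for the endpoints
#     the plan mentions,
#   * a top-level "orchestrator" agent decides when to call the planner and
#     when to call the controller.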
+# +def _create_api_planner_tool( + api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel +) -> Tool: + from langchain.chains.llm import LLMChain + endpoint_descriptions = [ + f"{name} {description}" for name, description, _ in api_spec.endpoints + ] + prompt = PromptTemplate( + template=API_PLANNER_PROMPT, + input_variables=["query"], + partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)}, + ) + chain = LLMChain(llm=llm, prompt=prompt) + tool = Tool( + name=API_PLANNER_TOOL_NAME, + description=API_PLANNER_TOOL_DESCRIPTION, + func=chain.run, + ) + return tool + + +def _create_api_controller_agent( + api_url: str, + api_docs: str, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, +) -> AgentExecutor: + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.agents.agent import AgentExecutor + from langchain.chains.llm import LLMChain + get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) + post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) + tools: List[BaseTool] = [ + RequestsGetToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=get_llm_chain + ), + RequestsPostToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=post_llm_chain + ), + ] + prompt = PromptTemplate( + template=API_CONTROLLER_PROMPT, + input_variables=["input", "agent_scratchpad"], + partial_variables={ + "api_url": api_url, + "api_docs": api_docs, + "tool_names": ", ".join([tool.name for tool in tools]), + "tool_descriptions": "\n".join( + [f"{tool.name}: {tool.description}" for tool in tools] + ), + }, + ) + agent = ZeroShotAgent( + llm_chain=LLMChain(llm=llm, prompt=prompt), + allowed_tools=[tool.name for tool in tools], + ) + return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True) + + +def _create_api_controller_tool( + api_spec: ReducedOpenAPISpec, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, +) -> Tool: + """Expose controller as a tool. + + The tool is invoked with a plan from the planner, and dynamically + creates a controller agent with relevant documentation only to + constrain the context. + """ + + base_url = api_spec.servers[0]["url"] # TODO: do better. 
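    # The helper below scans the plan text for steps shaped like
    # "GET /jokes/random" or "POST /users" (illustrative plan step:
    # "1. GET /jokes/random to fetch a joke"), pulls the docs for just those
    # endpoints out of the reduced spec, and runs a one-off controller agent
    # over them.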
+ + def _create_and_run_api_controller_agent(plan_str: str) -> str: + pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*" + matches = re.findall(pattern, plan_str) + endpoint_names = [ + "{method} {route}".format(method=method, route=route.split("?")[0]) + for method, route in matches + ] + docs_str = "" + for endpoint_name in endpoint_names: + found_match = False + for name, _, docs in api_spec.endpoints: + regex_name = re.compile(re.sub("\{.*?\}", ".*", name)) + if regex_name.match(endpoint_name): + found_match = True + docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n" + if not found_match: + raise ValueError(f"{endpoint_name} endpoint does not exist.") + + agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm) + return agent.run(plan_str) + + return Tool( + name=API_CONTROLLER_TOOL_NAME, + func=_create_and_run_api_controller_agent, + description=API_CONTROLLER_TOOL_DESCRIPTION, + ) + + +def create_openapi_agent( + api_spec: ReducedOpenAPISpec, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, + shared_memory: Optional[ReadOnlySharedMemory] = None, + callback_manager: Optional[BaseCallbackManager] = None, + verbose: bool = True, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Instantiate OpenAI API planner and controller for a given spec. + + Inject credentials via requests_wrapper. + + We use a top-level "orchestrator" agent to invoke the planner and controller, + rather than a top-level planner + that invokes a controller with its plan. This is to keep the planner simple. + """ + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.agents.agent import AgentExecutor + from langchain.chains.llm import LLMChain + tools = [ + _create_api_planner_tool(api_spec, llm), + _create_api_controller_tool(api_spec, requests_wrapper, llm), + ] + prompt = PromptTemplate( + template=API_ORCHESTRATOR_PROMPT, + input_variables=["input", "agent_scratchpad"], + partial_variables={ + "tool_names": ", ".join([tool.name for tool in tools]), + "tool_descriptions": "\n".join( + [f"{tool.name}: {tool.description}" for tool in tools] + ), + }, + ) + agent = ZeroShotAgent( + llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory), + allowed_tools=[tool.name for tool in tools], + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py new file mode 100644 index 00000000000..5b7f3fcbd52 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py @@ -0,0 +1,90 @@ +"""Requests toolkit.""" +from __future__ import annotations + +from typing import Any, List + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools import Tool + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.agent_toolkits.json.base import create_json_agent +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit +from langchain_community.agent_toolkits.openapi.prompt import DESCRIPTION +from langchain_community.tools import BaseTool +from langchain_community.tools.json.tool import JsonSpec +from langchain_community.tools.requests.tool import ( + 
RequestsDeleteTool, + RequestsGetTool, + RequestsPatchTool, + RequestsPostTool, + RequestsPutTool, +) +from langchain_community.utilities.requests import TextRequestsWrapper + + +class RequestsToolkit(BaseToolkit): + """Toolkit for making REST requests. + + *Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT, + and DELETE requests to an API. + + Exercise care in who is allowed to use this toolkit. If exposing + to end users, consider that users will be able to make arbitrary + requests on behalf of the server hosting the code. For example, + users could ask the server to make a request to a private API + that is only accessible from the server. + + Control access to who can submit issue requests using this toolkit and + what network access it has. + + See https://python.langchain.com/docs/security for more information. + """ + + requests_wrapper: TextRequestsWrapper + + def get_tools(self) -> List[BaseTool]: + """Return a list of tools.""" + return [ + RequestsGetTool(requests_wrapper=self.requests_wrapper), + RequestsPostTool(requests_wrapper=self.requests_wrapper), + RequestsPatchTool(requests_wrapper=self.requests_wrapper), + RequestsPutTool(requests_wrapper=self.requests_wrapper), + RequestsDeleteTool(requests_wrapper=self.requests_wrapper), + ] + + +class OpenAPIToolkit(BaseToolkit): + """Toolkit for interacting with an OpenAPI API. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by creating, deleting, or updating, + reading underlying data. + + For example, this toolkit can be used to delete data exposed via + an OpenAPI compliant API. + """ + + json_agent: Any + requests_wrapper: TextRequestsWrapper + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + json_agent_tool = Tool( + name="json_explorer", + func=self.json_agent.run, + description=DESCRIPTION, + ) + request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper) + return [*request_toolkit.get_tools(), json_agent_tool] + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + json_spec: JsonSpec, + requests_wrapper: TextRequestsWrapper, + **kwargs: Any, + ) -> OpenAPIToolkit: + """Create json agent from llm, then initialize.""" + json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs) + return cls(json_agent=json_agent, requests_wrapper=requests_wrapper) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/base.py new file mode 100644 index 00000000000..ef70ee7b43e --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/base.py @@ -0,0 +1,68 @@ +"""Power BI agent.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.powerbi.prompt import ( + POWERBI_PREFIX, + POWERBI_SUFFIX, +) +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.agents import AgentExecutor + + +def create_pbi_agent( + llm: BaseLanguageModel, + toolkit: Optional[PowerBIToolkit] = None, + powerbi: Optional[PowerBIDataset] = None, + callback_manager: Optional[BaseCallbackManager] = None, + 
prefix: str = POWERBI_PREFIX, + suffix: str = POWERBI_SUFFIX, + format_instructions: Optional[str] = None, + examples: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Power BI agent from an LLM and tools.""" + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.agents import AgentExecutor + from langchain.chains.llm import LLMChain + if toolkit is None: + if powerbi is None: + raise ValueError("Must provide either a toolkit or powerbi dataset") + toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples) + tools = toolkit.get_tools() + tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names + prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {} + agent = ZeroShotAgent( + llm_chain=LLMChain( + llm=llm, + prompt=ZeroShotAgent.create_prompt( + tools, + prefix=prefix.format(top_k=top_k).format(tables=tables), + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ), + callback_manager=callback_manager, # type: ignore + verbose=verbose, + ), + allowed_tools=[tool.name for tool in tools], + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py new file mode 100644 index 00000000000..0171e037135 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py @@ -0,0 +1,69 @@ +"""Power BI agent.""" +from __future__ import annotations +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models.chat_models import BaseChatModel + +from langchain_community.agent_toolkits.powerbi.prompt import ( + POWERBI_CHAT_PREFIX, + POWERBI_CHAT_SUFFIX, +) +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.agents import AgentExecutor + from langchain.agents.agent import AgentOutputParser + from langchain.memory.chat_memory import BaseChatMemory + + +def create_pbi_chat_agent( + llm: BaseChatModel, + toolkit: Optional[PowerBIToolkit] = None, + powerbi: Optional[PowerBIDataset] = None, + callback_manager: Optional[BaseCallbackManager] = None, + output_parser: Optional[AgentOutputParser] = None, + prefix: str = POWERBI_CHAT_PREFIX, + suffix: str = POWERBI_CHAT_SUFFIX, + examples: Optional[str] = None, + input_variables: Optional[List[str]] = None, + memory: Optional[BaseChatMemory] = None, + top_k: int = 10, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Power BI agent from a Chat LLM and tools. + + If you supply only a toolkit and no Power BI dataset, the same LLM is used for both. 
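    A minimal sketch (illustrative only; assumes `azure-identity` is installed,
    the credential can read the dataset, and the dataset id is known):

    .. code-block:: python

        from azure.identity import DefaultAzureCredential
        from langchain_community.chat_models import ChatOpenAI
        from langchain_community.utilities.powerbi import PowerBIDataset

        dataset = PowerBIDataset(
            dataset_id="<dataset-guid>",  # hypothetical dataset id
            table_names=["Sales"],
            credential=DefaultAzureCredential(),
        )
        agent = create_pbi_chat_agent(ChatOpenAI(temperature=0), powerbi=dataset)
        agent.run("How many rows are in the Sales table?")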
+ """ + from langchain.agents import AgentExecutor + from langchain.agents.conversational_chat.base import ConversationalChatAgent + from langchain.memory import ConversationBufferMemory + if toolkit is None: + if powerbi is None: + raise ValueError("Must provide either a toolkit or powerbi dataset") + toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples) + tools = toolkit.get_tools() + tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names + agent = ConversationalChatAgent.from_llm_and_tools( + llm=llm, + tools=tools, + system_message=prefix.format(top_k=top_k).format(tables=tables), + human_message=suffix, + input_variables=input_variables, + callback_manager=callback_manager, + output_parser=output_parser, + verbose=verbose, + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + memory=memory + or ConversationBufferMemory(memory_key="chat_history", return_messages=True), + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py new file mode 100644 index 00000000000..bf4969b889c --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py @@ -0,0 +1,106 @@ +"""Toolkit for interacting with a Power BI dataset.""" +from __future__ import annotations +from typing import List, Optional, Union, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.powerbi.prompt import ( + QUESTION_TO_QUERY_BASE, + SINGLE_QUESTION_TO_QUERY, + USER_INPUT, +) +from langchain_community.tools.powerbi.tool import ( + InfoPowerBITool, + ListPowerBITool, + QueryPowerBITool, +) +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.chains.llm import LLMChain + + +class PowerBIToolkit(BaseToolkit): + """Toolkit for interacting with Power BI dataset. + + *Security Note*: This toolkit interacts with an external service. + + Control access to who can use this toolkit. + + Make sure that the capabilities given by this toolkit to the calling + code are appropriately scoped to the application. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + powerbi: PowerBIDataset = Field(exclude=True) + llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True) + examples: Optional[str] = None + max_iterations: int = 5 + callback_manager: Optional[BaseCallbackManager] = None + output_token_limit: Optional[int] = None + tiktoken_model_name: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + QueryPowerBITool( + llm_chain=self._get_chain(), + powerbi=self.powerbi, + examples=self.examples, + max_iterations=self.max_iterations, + output_token_limit=self.output_token_limit, + tiktoken_model_name=self.tiktoken_model_name, + ), + InfoPowerBITool(powerbi=self.powerbi), + ListPowerBITool(powerbi=self.powerbi), + ] + + def _get_chain(self) -> LLMChain: + """Construct the chain based on the callback manager and model type.""" + from langchain.chains.llm import LLMChain + if isinstance(self.llm, BaseLanguageModel): + return LLMChain( + llm=self.llm, + callback_manager=self.callback_manager + if self.callback_manager + else None, + prompt=PromptTemplate( + template=SINGLE_QUESTION_TO_QUERY, + input_variables=["tool_input", "tables", "schemas", "examples"], + ), + ) + + system_prompt = SystemMessagePromptTemplate( + prompt=PromptTemplate( + template=QUESTION_TO_QUERY_BASE, + input_variables=["tables", "schemas", "examples"], + ) + ) + human_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template=USER_INPUT, + input_variables=["tool_input"], + ) + ) + return LLMChain( + llm=self.llm, + callback_manager=self.callback_manager if self.callback_manager else None, + prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/spark_sql/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/spark_sql/base.py new file mode 100644 index 00000000000..8b5cd9d9078 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/spark_sql/base.py @@ -0,0 +1,64 @@ +"""Spark SQL agent.""" +from __future__ import annotations +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager, Callbacks +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX +from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_spark_sql_agent( + llm: BaseLanguageModel, + toolkit: SparkSQLToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + prefix: str = SQL_PREFIX, + suffix: str = SQL_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Spark SQL agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + tools = toolkit.get_tools() + prefix = prefix.format(top_k=top_k) + 
prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {} + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + callbacks=callbacks, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + callbacks=callbacks, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/sql/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/sql/base.py new file mode 100644 index 00000000000..bab14e6497a --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -0,0 +1,102 @@ +"""SQL agent.""" +from __future__ import annotations +from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.messages import AIMessage, SystemMessage +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) + +from langchain_community.agent_toolkits.sql.prompt import ( + SQL_FUNCTIONS_SUFFIX, + SQL_PREFIX, + SQL_SUFFIX, +) +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.tools import BaseTool + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + from langchain.agents.agent_types import AgentType + + +def create_sql_agent( + llm: BaseLanguageModel, + toolkit: SQLDatabaseToolkit, + agent_type: Optional[AgentType] = None, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = SQL_PREFIX, + suffix: Optional[str] = None, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + extra_tools: Sequence[BaseTool] = (), + **kwargs: Any, +) -> AgentExecutor: + """Construct an SQL agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent + from langchain.agents.agent_types import AgentType + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent + from langchain.chains.llm import LLMChain + agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION + tools = toolkit.get_tools() + list(extra_tools) + prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {} + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix or SQL_SUFFIX, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + 
callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + + elif agent_type == AgentType.OPENAI_FUNCTIONS: + messages = [ + SystemMessage(content=prefix), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + input_variables = ["input", "agent_scratchpad"] + _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages) + + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/agent_toolkits/vectorstore/base.py b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/vectorstore/base.py new file mode 100644 index 00000000000..3e25a06b306 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/agent_toolkits/vectorstore/base.py @@ -0,0 +1,103 @@ +"""VectorStore agent.""" +from __future__ import annotations +from typing import Any, Dict, Optional, TYPE_CHECKING + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX +from langchain_community.agent_toolkits.vectorstore.toolkit import ( + VectorStoreRouterToolkit, + VectorStoreToolkit, +) + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_vectorstore_agent( + llm: BaseLanguageModel, + toolkit: VectorStoreToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = PREFIX, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a VectorStore agent from an LLM and tools. + + Args: + llm (BaseLanguageModel): LLM that will be used by the agent + toolkit (VectorStoreToolkit): Set of tools for the agent + callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. ] + prefix (str, optional): The prefix prompt for the agent. If not provided uses default PREFIX. + verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ] + agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ] + **kwargs: Additional named parameters to pass to the ZeroShotAgent. + + Returns: + AgentExecutor: Returns a callable AgentExecutor object. 
Either you can call it or use run method with the query to get the response + """ # noqa: E501 + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + tools = toolkit.get_tools() + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) + + +def create_vectorstore_router_agent( + llm: BaseLanguageModel, + toolkit: VectorStoreRouterToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = ROUTER_PREFIX, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a VectorStore router agent from an LLM and tools. + + Args: + llm (BaseLanguageModel): LLM that will be used by the agent + toolkit (VectorStoreRouterToolkit): Set of tools for the agent which have routing capability with multiple vector stores + callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. ] + prefix (str, optional): The prefix prompt for the router agent. If not provided uses default ROUTER_PREFIX. + verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ] + agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ] + **kwargs: Additional named parameters to pass to the ZeroShotAgent. + + Returns: + AgentExecutor: Returns a callable AgentExecutor object. Either you can call it or use run method with the query to get the response. + """ # noqa: E501 + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + tools = toolkit.get_tools() + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/.scripts/community_split/libs/community/langchain_community/callbacks/__init__.py b/.scripts/community_split/libs/community/langchain_community/callbacks/__init__.py new file mode 100644 index 00000000000..6016e8304d7 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/callbacks/__init__.py @@ -0,0 +1,66 @@ +"""**Callback handlers** allow listening to events in LangChain. + +**Class hierarchy:** + +.. 
code-block:: + + BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler +""" + +from langchain_community.callbacks.aim_callback import AimCallbackHandler +from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler +from langchain_community.callbacks.arize_callback import ArizeCallbackHandler +from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler +from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler +from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler +from langchain_community.callbacks.context_callback import ContextCallbackHandler +from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler +from langchain_community.callbacks.human import HumanApprovalCallbackHandler +from langchain_community.callbacks.infino_callback import InfinoCallbackHandler +from langchain_community.callbacks.labelstudio_callback import ( + LabelStudioCallbackHandler, +) +from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler +from langchain_community.callbacks.manager import ( + get_openai_callback, + wandb_tracing_enabled, +) +from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler +from langchain_community.callbacks.openai_info import OpenAICallbackHandler +from langchain_community.callbacks.promptlayer_callback import ( + PromptLayerCallbackHandler, +) +from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler +from langchain_community.callbacks.streamlit import ( + LLMThoughtLabeler, + StreamlitCallbackHandler, +) +from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler +from langchain_community.callbacks.wandb_callback import WandbCallbackHandler +from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler + +__all__ = [ + "AimCallbackHandler", + "ArgillaCallbackHandler", + "ArizeCallbackHandler", + "PromptLayerCallbackHandler", + "ArthurCallbackHandler", + "ClearMLCallbackHandler", + "CometCallbackHandler", + "ContextCallbackHandler", + "HumanApprovalCallbackHandler", + "InfinoCallbackHandler", + "MlflowCallbackHandler", + "LLMonitorCallbackHandler", + "OpenAICallbackHandler", + "LLMThoughtLabeler", + "StreamlitCallbackHandler", + "WandbCallbackHandler", + "WhyLabsCallbackHandler", + "get_openai_callback", + "wandb_tracing_enabled", + "FlyteCallbackHandler", + "SageMakerCallbackHandler", + "LabelStudioCallbackHandler", + "TrubricsCallbackHandler", +] diff --git a/.scripts/community_split/libs/community/langchain_community/callbacks/manager.py b/.scripts/community_split/libs/community/langchain_community/callbacks/manager.py new file mode 100644 index 00000000000..196afed3d0f --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/callbacks/manager.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from contextvars import ContextVar +from typing import ( + Generator, + Optional, +) + +from langchain_core.tracers.context import register_configure_hook + +from langchain_community.callbacks.openai_info import OpenAICallbackHandler +from langchain_community.callbacks.tracers.wandb import WandbTracer + +logger = logging.getLogger(__name__) + +openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( + "openai_callback", default=None +) +wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( # noqa: E501 + 
"tracing_wandb_callback", default=None +) + +register_configure_hook(openai_callback_var, True) +register_configure_hook( + wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING" +) + + +@contextmanager +def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: + """Get the OpenAI callback handler in a context manager. + which conveniently exposes token and cost information. + + Returns: + OpenAICallbackHandler: The OpenAI callback handler. + + Example: + >>> with get_openai_callback() as cb: + ... # Use the OpenAI callback handler + """ + cb = OpenAICallbackHandler() + openai_callback_var.set(cb) + yield cb + openai_callback_var.set(None) + + +@contextmanager +def wandb_tracing_enabled( + session_name: str = "default", +) -> Generator[None, None, None]: + """Get the WandbTracer in a context manager. + + Args: + session_name (str, optional): The name of the session. + Defaults to "default". + + Returns: + None + + Example: + >>> with wandb_tracing_enabled() as session: + ... # Use the WandbTracer session + """ + cb = WandbTracer() + wandb_tracing_callback_var.set(cb) + yield None + wandb_tracing_callback_var.set(None) diff --git a/.scripts/community_split/libs/community/langchain_community/callbacks/tracers/__init__.py b/.scripts/community_split/libs/community/langchain_community/callbacks/tracers/__init__.py new file mode 100644 index 00000000000..8af691585a6 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/callbacks/tracers/__init__.py @@ -0,0 +1,18 @@ +"""Tracers that record execution of LangChain runs.""" + +from langchain_core.tracers.langchain import LangChainTracer +from langchain_core.tracers.langchain_v1 import LangChainTracerV1 +from langchain_core.tracers.stdout import ( + ConsoleCallbackHandler, + FunctionCallbackHandler, +) + +from langchain_community.callbacks.tracers.wandb import WandbTracer + +__all__ = [ + "ConsoleCallbackHandler", + "FunctionCallbackHandler", + "LangChainTracer", + "LangChainTracerV1", + "WandbTracer", +] diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/base.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/base.py new file mode 100644 index 00000000000..770cae74429 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/base.py @@ -0,0 +1,101 @@ +"""Abstract interface for document loader implementations.""" +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Iterator, List, Optional, TYPE_CHECKING + +from langchain_core.documents import Document + +from langchain_community.document_loaders.blob_loaders import Blob + +if TYPE_CHECKING: + from langchain.text_splitter import TextSplitter + + +class BaseLoader(ABC): + """Interface for Document Loader. + + Implementations should implement the lazy-loading method using generators + to avoid loading all Documents into memory at once. + + The `load` method will remain as is for backwards compatibility, but its + implementation should be just `list(self.lazy_load())`. + """ + + # Sub-classes should implement this method + # as return list(self.lazy_load()). + # This method returns a List which is materialized in memory. + @abstractmethod + def load(self) -> List[Document]: + """Load data into Document objects.""" + + def load_and_split( + self, text_splitter: Optional[TextSplitter] = None + ) -> List[Document]: + """Load Documents and split into chunks. Chunks are returned as Documents. 
+ + Args: + text_splitter: TextSplitter instance to use for splitting documents. + Defaults to RecursiveCharacterTextSplitter. + + Returns: + List of Documents. + """ + from langchain.text_splitter import RecursiveCharacterTextSplitter + + if text_splitter is None: + _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() + else: + _text_splitter = text_splitter + docs = self.load() + return _text_splitter.split_documents(docs) + + # Attention: This method will be upgraded into an abstractmethod once it's + # implemented in all the existing subclasses. + def lazy_load( + self, + ) -> Iterator[Document]: + """A lazy loader for Documents.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement lazy_load()" + ) + + +class BaseBlobParser(ABC): + """Abstract interface for blob parsers. + + A blob parser provides a way to parse raw data stored in a blob into one + or more documents. + + The parser can be composed with blob loaders, making it easy to reuse + a parser independent of how the blob was originally loaded. + """ + + @abstractmethod + def lazy_parse(self, blob: Blob) -> Iterator[Document]: + """Lazy parsing interface. + + Subclasses are required to implement this method. + + Args: + blob: Blob instance + + Returns: + Generator of documents + """ + + def parse(self, blob: Blob) -> List[Document]: + """Eagerly parse the blob into a document or documents. + + This is a convenience method for interactive development environment. + + Production applications should favor the lazy_parse method instead. + + Subclasses should generally not over-ride this parse method. + + Args: + blob: Blob instance + + Returns: + List of documents + """ + return list(self.lazy_parse(blob)) diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py new file mode 100644 index 00000000000..0fcdd4438ee --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/blob_loaders/file_system.py @@ -0,0 +1,147 @@ +"""Use to load blobs from the local file system.""" +from pathlib import Path +from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union + +from langchain_community.document_loaders.blob_loaders.schema import Blob, BlobLoader + +T = TypeVar("T") + + +def _make_iterator( + length_func: Callable[[], int], show_progress: bool = False +) -> Callable[[Iterable[T]], Iterator[T]]: + """Create a function that optionally wraps an iterable in tqdm.""" + if show_progress: + try: + from tqdm.auto import tqdm + except ImportError: + raise ImportError( + "You must install tqdm to use show_progress=True." + "You can install tqdm with `pip install tqdm`." + ) + + # Make sure to provide `total` here so that tqdm can show + # a progress bar that takes into account the total number of files. + def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]: + """Wrap an iterable in a tqdm progress bar.""" + return tqdm(iterable, total=length_func()) + + iterator = _with_tqdm + else: + iterator = iter # type: ignore + + return iterator + + +# PUBLIC API + + +class FileSystemBlobLoader(BlobLoader): + """Load blobs in the local file system. + + Example: + + .. 
code-block:: python + + from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader + loader = FileSystemBlobLoader("/path/to/directory") + for blob in loader.yield_blobs(): + print(blob) + """ # noqa: E501 + + def __init__( + self, + path: Union[str, Path], + *, + glob: str = "**/[!.]*", + exclude: Sequence[str] = (), + suffixes: Optional[Sequence[str]] = None, + show_progress: bool = False, + ) -> None: + """Initialize with a path to directory and how to glob over it. + + Args: + path: Path to directory to load from or path to file to load. + If a path to a file is provided, glob/exclude/suffixes are ignored. + glob: Glob pattern relative to the specified path + by default set to pick up all non-hidden files + exclude: patterns to exclude from results, use glob syntax + suffixes: Provide to keep only files with these suffixes + Useful when wanting to keep files with different suffixes + Suffixes must include the dot, e.g. ".txt" + show_progress: If true, will show a progress bar as the files are loaded. + This forces an iteration through all matching files + to count them prior to loading them. + + Examples: + + .. code-block:: python + from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader + + # Load a single file. + loader = FileSystemBlobLoader("/path/to/file.txt") + + # Recursively load all text files in a directory. + loader = FileSystemBlobLoader("/path/to/directory", glob="**/*.txt") + + # Recursively load all non-hidden files in a directory. + loader = FileSystemBlobLoader("/path/to/directory", glob="**/[!.]*") + + # Load all files in a directory without recursion. + loader = FileSystemBlobLoader("/path/to/directory", glob="*") + + # Recursively load all files in a directory, except for py or pyc files. + loader = FileSystemBlobLoader( + "/path/to/directory", + glob="**/*.txt", + exclude=["**/*.py", "**/*.pyc"] + ) + """ # noqa: E501 + if isinstance(path, Path): + _path = path + elif isinstance(path, str): + _path = Path(path) + else: + raise TypeError(f"Expected str or Path, got {type(path)}") + + self.path = _path.expanduser() # Expand user to handle ~ + self.glob = glob + self.suffixes = set(suffixes or []) + self.show_progress = show_progress + self.exclude = exclude + + def yield_blobs( + self, + ) -> Iterable[Blob]: + """Yield blobs that match the requested pattern.""" + iterator = _make_iterator( + length_func=self.count_matching_files, show_progress=self.show_progress + ) + + for path in iterator(self._yield_paths()): + yield Blob.from_path(path) + + def _yield_paths(self) -> Iterable[Path]: + """Yield paths that match the requested pattern.""" + if self.path.is_file(): + yield self.path + return + + paths = self.path.glob(self.glob) + for path in paths: + if self.exclude: + if any(path.match(glob) for glob in self.exclude): + continue + if path.is_file(): + if self.suffixes and path.suffix not in self.suffixes: + continue + yield path + + def count_matching_files(self) -> int: + """Count files that match the pattern without loading them.""" + # Carry out a full iteration to count the files without + # materializing anything expensive in memory. 
+ num = 0 + for _ in self._yield_paths(): + num += 1 + return num diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/generic.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/generic.py new file mode 100644 index 00000000000..0ec6ca60bdf --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/generic.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + List, + Literal, + Optional, + Sequence, + Union, +) + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader +from langchain_community.document_loaders.blob_loaders import ( + BlobLoader, + FileSystemBlobLoader, +) +from langchain_community.document_loaders.parsers.registry import get_parser + +if TYPE_CHECKING: + from langchain.text_splitter import TextSplitter + +_PathLike = Union[str, Path] + +DEFAULT = Literal["default"] + + +class GenericLoader(BaseLoader): + """Generic Document Loader. + + A generic document loader that allows combining an arbitrary blob loader with + a blob parser. + + Examples: + + Parse a specific PDF file: + + .. code-block:: python + + from langchain_community.document_loaders import GenericLoader + from langchain_community.document_loaders.parsers.pdf import PyPDFParser + + # Recursively load all text files in a directory. + loader = GenericLoader.from_filesystem( + "my_lovely_pdf.pdf", + parser=PyPDFParser() + ) + + .. code-block:: python + + from langchain_community.document_loaders import GenericLoader + from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader + + + loader = GenericLoader.from_filesystem( + path="path/to/directory", + glob="**/[!.]*", + suffixes=[".pdf"], + show_progress=True, + ) + + docs = loader.lazy_load() + next(docs) + + Example instantiations to change which files are loaded: + + .. code-block:: python + + # Recursively load all text files in a directory. + loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/*.txt") + + # Recursively load all non-hidden files in a directory. + loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/[!.]*") + + # Load all files in a directory without recursion. + loader = GenericLoader.from_filesystem("/path/to/dir", glob="*") + + Example instantiations to change which parser is used: + + .. code-block:: python + + from langchain_community.document_loaders.parsers.pdf import PyPDFParser + + # Recursively load all text files in a directory. + loader = GenericLoader.from_filesystem( + "/path/to/dir", + glob="**/*.pdf", + parser=PyPDFParser() + ) + + """ # noqa: E501 + + def __init__( + self, + blob_loader: BlobLoader, + blob_parser: BaseBlobParser, + ) -> None: + """A generic document loader. + + Args: + blob_loader: A blob loader which knows how to yield blobs + blob_parser: A blob parser which knows how to parse blobs into documents + """ + self.blob_loader = blob_loader + self.blob_parser = blob_parser + + def lazy_load( + self, + ) -> Iterator[Document]: + """Load documents lazily. 
Use this when working at a large scale."""
+        for blob in self.blob_loader.yield_blobs():
+            yield from self.blob_parser.lazy_parse(blob)
+
+    def load(self) -> List[Document]:
+        """Load all documents."""
+        return list(self.lazy_load())
+
+    def load_and_split(
+        self, text_splitter: Optional[TextSplitter] = None
+    ) -> List[Document]:
+        """Load all documents and split them into chunks."""
+        raise NotImplementedError(
+            "Loading and splitting is not yet implemented for generic loaders. "
+            "When implemented, it will be added via the initializer. "
+            "This method should not be used going forward."
+        )
+
+    @classmethod
+    def from_filesystem(
+        cls,
+        path: _PathLike,
+        *,
+        glob: str = "**/[!.]*",
+        exclude: Sequence[str] = (),
+        suffixes: Optional[Sequence[str]] = None,
+        show_progress: bool = False,
+        parser: Union[DEFAULT, BaseBlobParser] = "default",
+        parser_kwargs: Optional[dict] = None,
+    ) -> GenericLoader:
+        """Create a generic document loader using a filesystem blob loader.
+
+        Args:
+            path: The path to the directory to load documents from OR the path to a
+                single file to load. If this is a file, glob, exclude, suffixes
+                will be ignored.
+            glob: The glob pattern to use to find documents.
+            suffixes: The suffixes to use to filter documents. If None, all files
+                matching the glob will be loaded.
+            exclude: A list of patterns to exclude from the loader.
+            show_progress: Whether to show a progress bar or not (requires tqdm).
+                Proxies to the file system loader.
+            parser: A blob parser which knows how to parse blobs into documents;
+                will instantiate a default parser if not provided.
+                The default can be overridden by either passing a parser or
+                setting the class attribute `blob_parser` (the latter
+                should be used with inheritance).
+            parser_kwargs: Keyword arguments to pass to the parser.
+
+        Returns:
+            A generic document loader.
+        """
+        blob_loader = FileSystemBlobLoader(
+            path,
+            glob=glob,
+            exclude=exclude,
+            suffixes=suffixes,
+            show_progress=show_progress,
+        )
+        if isinstance(parser, str):
+            if parser == "default":
+                try:
+                    # If there is an implementation of get_parser on the class, use it.
+                    blob_parser = cls.get_parser(**(parser_kwargs or {}))
+                except NotImplementedError:
+                    # if not then use the global registry.
+                    blob_parser = get_parser(parser)
+            else:
+                blob_parser = get_parser(parser)
+        else:
+            blob_parser = parser
+        return cls(blob_loader, blob_parser)
+
+    @staticmethod
+    def get_parser(**kwargs: Any) -> BaseBlobParser:
+        """Override this method to associate a default parser with the class."""
+        raise NotImplementedError()
diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/generic.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/generic.py
new file mode 100644
index 00000000000..6b6b91b93ee
--- /dev/null
+++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/generic.py
@@ -0,0 +1,70 @@
+"""Code for generic / auxiliary parsers.
+
+This module contains some logic to help assemble more sophisticated parsers.
+"""
+from typing import Iterator, Mapping, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseBlobParser
+from langchain_community.document_loaders.blob_loaders.schema import Blob
+
+
+class MimeTypeBasedParser(BaseBlobParser):
+    """Parser that uses `mime`-types to parse a blob.
+ + This parser is useful for simple pipelines where the mime-type is sufficient + to determine how to parse a blob. + + To use, configure handlers based on mime-types and pass them to the initializer. + + Example: + + .. code-block:: python + + from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser + + parser = MimeTypeBasedParser( + handlers={ + "application/pdf": ..., + }, + fallback_parser=..., + ) + """ # noqa: E501 + + def __init__( + self, + handlers: Mapping[str, BaseBlobParser], + *, + fallback_parser: Optional[BaseBlobParser] = None, + ) -> None: + """Define a parser that uses mime-types to determine how to parse a blob. + + Args: + handlers: A mapping from mime-types to functions that take a blob, parse it + and return a document. + fallback_parser: A fallback_parser parser to use if the mime-type is not + found in the handlers. If provided, this parser will be + used to parse blobs with all mime-types not found in + the handlers. + If not provided, a ValueError will be raised if the + mime-type is not found in the handlers. + """ + self.handlers = handlers + self.fallback_parser = fallback_parser + + def lazy_parse(self, blob: Blob) -> Iterator[Document]: + """Load documents from a blob.""" + mimetype = blob.mimetype + + if mimetype is None: + raise ValueError(f"{blob} does not have a mimetype.") + + if mimetype in self.handlers: + handler = self.handlers[mimetype] + yield from handler.lazy_parse(blob) + else: + if self.fallback_parser is not None: + yield from self.fallback_parser.lazy_parse(blob) + else: + raise ValueError(f"Unsupported mime type: {mimetype}") diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py new file mode 100644 index 00000000000..b87ca900657 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseBlobParser +from langchain_community.document_loaders.blob_loaders import Blob +from langchain_community.document_loaders.parsers.language.cobol import CobolSegmenter +from langchain_community.document_loaders.parsers.language.javascript import ( + JavaScriptSegmenter, +) +from langchain_community.document_loaders.parsers.language.python import PythonSegmenter + +if TYPE_CHECKING: + from langchain.text_splitter import Language + +try: + from langchain.text_splitter import Language + LANGUAGE_EXTENSIONS: Dict[str, str] = { + "py": Language.PYTHON, + "js": Language.JS, + "cobol": Language.COBOL, + } + + LANGUAGE_SEGMENTERS: Dict[str, Any] = { + Language.PYTHON: PythonSegmenter, + Language.JS: JavaScriptSegmenter, + Language.COBOL: CobolSegmenter, + } +except ImportError: + LANGUAGE_EXTENSIONS = {} + LANGUAGE_SEGMENTERS = {} + + +class LanguageParser(BaseBlobParser): + """Parse using the respective programming language syntax. + + Each top-level function and class in the code is loaded into separate documents. + Furthermore, an extra document is generated, containing the remaining top-level code + that excludes the already segmented functions and classes. + + This approach can potentially improve the accuracy of QA models over source code. 
+ + Currently, the supported languages for code parsing are Python and JavaScript. + + The language used for parsing can be configured, along with the minimum number of + lines required to activate the splitting based on syntax. + + Examples: + + .. code-block:: python + + from langchain.text_splitter.Language + from langchain_community.document_loaders.generic import GenericLoader + from langchain_community.document_loaders.parsers import LanguageParser + + loader = GenericLoader.from_filesystem( + "./code", + glob="**/*", + suffixes=[".py", ".js"], + parser=LanguageParser() + ) + docs = loader.load() + + Example instantiations to manually select the language: + + .. code-block:: python + + from langchain.text_splitter import Language + + loader = GenericLoader.from_filesystem( + "./code", + glob="**/*", + suffixes=[".py"], + parser=LanguageParser(language=Language.PYTHON) + ) + + Example instantiations to set number of lines threshold: + + .. code-block:: python + + loader = GenericLoader.from_filesystem( + "./code", + glob="**/*", + suffixes=[".py"], + parser=LanguageParser(parser_threshold=200) + ) + """ + + def __init__(self, language: Optional[Language] = None, parser_threshold: int = 0): + """ + Language parser that split code using the respective language syntax. + + Args: + language: If None (default), it will try to infer language from source. + parser_threshold: Minimum lines needed to activate parsing (0 by default). + """ + self.language = language + self.parser_threshold = parser_threshold + + def lazy_parse(self, blob: Blob) -> Iterator[Document]: + code = blob.as_string() + + language = self.language or ( + LANGUAGE_EXTENSIONS.get(blob.source.rsplit(".", 1)[-1]) + if isinstance(blob.source, str) + else None + ) + + if language is None: + yield Document( + page_content=code, + metadata={ + "source": blob.source, + }, + ) + return + + if self.parser_threshold >= len(code.splitlines()): + yield Document( + page_content=code, + metadata={ + "source": blob.source, + "language": language, + }, + ) + return + + self.Segmenter = LANGUAGE_SEGMENTERS[language] + segmenter = self.Segmenter(blob.as_string()) + if not segmenter.is_valid(): + yield Document( + page_content=code, + metadata={ + "source": blob.source, + }, + ) + return + + for functions_classes in segmenter.extract_functions_classes(): + yield Document( + page_content=functions_classes, + metadata={ + "source": blob.source, + "content_type": "functions_classes", + "language": language, + }, + ) + yield Document( + page_content=segmenter.simplify_code(), + metadata={ + "source": blob.source, + "content_type": "simplified_code", + "language": language, + }, + ) diff --git a/.scripts/community_split/libs/community/langchain_community/document_loaders/telegram.py b/.scripts/community_split/libs/community/langchain_community/document_loaders/telegram.py new file mode 100644 index 00000000000..b4b796b6abc --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_loaders/telegram.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseLoader + +if TYPE_CHECKING: + import pandas as pd + from telethon.hints import EntityLike + + +def concatenate_rows(row: dict) -> str: + """Combine message information in a readable format ready to be used.""" + date = row["date"] + sender = 
row["from"] + text = row["text"] + return f"{sender} on {date}: {text}\n\n" + + +class TelegramChatFileLoader(BaseLoader): + """Load from `Telegram chat` dump.""" + + def __init__(self, path: str): + """Initialize with a path.""" + self.file_path = path + + def load(self) -> List[Document]: + """Load documents.""" + p = Path(self.file_path) + + with open(p, encoding="utf8") as f: + d = json.load(f) + + text = "".join( + concatenate_rows(message) + for message in d["messages"] + if message["type"] == "message" and isinstance(message["text"], str) + ) + metadata = {"source": str(p)} + + return [Document(page_content=text, metadata=metadata)] + + +def text_to_docs(text: Union[str, List[str]]) -> List[Document]: + """Convert a string or list of strings to a list of Documents with metadata.""" + from langchain.text_splitter import RecursiveCharacterTextSplitter + if isinstance(text, str): + # Take a single string as one page + text = [text] + page_docs = [Document(page_content=page) for page in text] + + # Add page numbers as metadata + for i, doc in enumerate(page_docs): + doc.metadata["page"] = i + 1 + + # Split pages into chunks + doc_chunks = [] + + for doc in page_docs: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=800, + separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""], + chunk_overlap=20, + ) + chunks = text_splitter.split_text(doc.page_content) + for i, chunk in enumerate(chunks): + doc = Document( + page_content=chunk, metadata={"page": doc.metadata["page"], "chunk": i} + ) + # Add sources a metadata + doc.metadata["source"] = f"{doc.metadata['page']}-{doc.metadata['chunk']}" + doc_chunks.append(doc) + return doc_chunks + + +class TelegramChatApiLoader(BaseLoader): + """Load `Telegram` chat json directory dump.""" + + def __init__( + self, + chat_entity: Optional[EntityLike] = None, + api_id: Optional[int] = None, + api_hash: Optional[str] = None, + username: Optional[str] = None, + file_path: str = "telegram_data.json", + ): + """Initialize with API parameters. + + Args: + chat_entity: The chat entity to fetch data from. + api_id: The API ID. + api_hash: The API hash. + username: The username. + file_path: The file path to save the data to. Defaults to + "telegram_data.json". + """ + self.chat_entity = chat_entity + self.api_id = api_id + self.api_hash = api_hash + self.username = username + self.file_path = file_path + + async def fetch_data_from_telegram(self) -> None: + """Fetch data from Telegram API and save it as a JSON file.""" + from telethon.sync import TelegramClient + + data = [] + async with TelegramClient(self.username, self.api_id, self.api_hash) as client: + async for message in client.iter_messages(self.chat_entity): + is_reply = message.reply_to is not None + reply_to_id = message.reply_to.reply_to_msg_id if is_reply else None + data.append( + { + "sender_id": message.sender_id, + "text": message.text, + "date": message.date.isoformat(), + "message.id": message.id, + "is_reply": is_reply, + "reply_to_id": reply_to_id, + } + ) + + with open(self.file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + + def _get_message_threads(self, data: pd.DataFrame) -> dict: + """Create a dictionary of message threads from the given data. 
+ + Args: + data (pd.DataFrame): A DataFrame containing the conversation \ + data with columns: + - message.sender_id + - text + - date + - message.id + - is_reply + - reply_to_id + + Returns: + dict: A dictionary where the key is the parent message ID and \ + the value is a list of message IDs in ascending order. + """ + + def find_replies(parent_id: int, reply_data: pd.DataFrame) -> List[int]: + """ + Recursively find all replies to a given parent message ID. + + Args: + parent_id (int): The parent message ID. + reply_data (pd.DataFrame): A DataFrame containing reply messages. + + Returns: + list: A list of message IDs that are replies to the parent message ID. + """ + # Find direct replies to the parent message ID + direct_replies = reply_data[reply_data["reply_to_id"] == parent_id][ + "message.id" + ].tolist() + + # Recursively find replies to the direct replies + all_replies = [] + for reply_id in direct_replies: + all_replies += [reply_id] + find_replies(reply_id, reply_data) + + return all_replies + + # Filter out parent messages + parent_messages = data[~data["is_reply"]] + + # Filter out reply messages and drop rows with NaN in 'reply_to_id' + reply_messages = data[data["is_reply"]].dropna(subset=["reply_to_id"]) + + # Convert 'reply_to_id' to integer + reply_messages["reply_to_id"] = reply_messages["reply_to_id"].astype(int) + + # Create a dictionary of message threads with parent message IDs as keys and \ + # lists of reply message IDs as values + message_threads = { + parent_id: [parent_id] + find_replies(parent_id, reply_messages) + for parent_id in parent_messages["message.id"] + } + + return message_threads + + def _combine_message_texts( + self, message_threads: Dict[int, List[int]], data: pd.DataFrame + ) -> str: + """ + Combine the message texts for each parent message ID based \ + on the list of message threads. + + Args: + message_threads (dict): A dictionary where the key is the parent message \ + ID and the value is a list of message IDs in ascending order. + data (pd.DataFrame): A DataFrame containing the conversation data: + - message.sender_id + - text + - date + - message.id + - is_reply + - reply_to_id + + Returns: + str: A combined string of message texts sorted by date. + """ + combined_text = "" + + # Iterate through sorted parent message IDs + for parent_id, message_ids in message_threads.items(): + # Get the message texts for the message IDs and sort them by date + message_texts = ( + data[data["message.id"].isin(message_ids)] + .sort_values(by="date")["text"] + .tolist() + ) + message_texts = [str(elem) for elem in message_texts] + + # Combine the message texts + combined_text += " ".join(message_texts) + ".\n" + + return combined_text.strip() + + def load(self) -> List[Document]: + """Load documents.""" + + if self.chat_entity is not None: + try: + import nest_asyncio + + nest_asyncio.apply() + asyncio.run(self.fetch_data_from_telegram()) + except ImportError: + raise ImportError( + """`nest_asyncio` package not found. + please install with `pip install nest_asyncio` + """ + ) + + p = Path(self.file_path) + + with open(p, encoding="utf8") as f: + d = json.load(f) + try: + import pandas as pd + except ImportError: + raise ImportError( + """`pandas` package not found. 
+ please install with `pip install pandas` + """ + ) + normalized_messages = pd.json_normalize(d) + df = pd.DataFrame(normalized_messages) + + message_threads = self._get_message_threads(df) + combined_texts = self._combine_message_texts(message_threads, df) + + return text_to_docs(combined_texts) diff --git a/.scripts/community_split/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py b/.scripts/community_split/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py new file mode 100644 index 00000000000..0e2b5d394c2 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_transformers/beautiful_soup_transformer.py @@ -0,0 +1,149 @@ +from typing import Any, Iterator, List, Sequence, cast + +from langchain_core.documents import BaseDocumentTransformer, Document + + +class BeautifulSoupTransformer(BaseDocumentTransformer): + """Transform HTML content by extracting specific tags and removing unwanted ones. + + Example: + .. code-block:: python + + from langchain_community.document_transformers import BeautifulSoupTransformer + + bs4_transformer = BeautifulSoupTransformer() + docs_transformed = bs4_transformer.transform_documents(docs) + """ # noqa: E501 + + def __init__(self) -> None: + """ + Initialize the transformer. + + This checks if the BeautifulSoup4 package is installed. + If not, it raises an ImportError. + """ + try: + import bs4 # noqa:F401 + except ImportError: + raise ImportError( + "BeautifulSoup4 is required for BeautifulSoupTransformer. " + "Please install it with `pip install beautifulsoup4`." + ) + + def transform_documents( + self, + documents: Sequence[Document], + unwanted_tags: List[str] = ["script", "style"], + tags_to_extract: List[str] = ["p", "li", "div", "a"], + remove_lines: bool = True, + **kwargs: Any, + ) -> Sequence[Document]: + """ + Transform a list of Document objects by cleaning their HTML content. + + Args: + documents: A sequence of Document objects containing HTML content. + unwanted_tags: A list of tags to be removed from the HTML. + tags_to_extract: A list of tags whose content will be extracted. + remove_lines: If set to True, unnecessary lines will be + removed from the HTML content. + + Returns: + A sequence of Document objects with transformed content. + """ + for doc in documents: + cleaned_content = doc.page_content + + cleaned_content = self.remove_unwanted_tags(cleaned_content, unwanted_tags) + + cleaned_content = self.extract_tags(cleaned_content, tags_to_extract) + + if remove_lines: + cleaned_content = self.remove_unnecessary_lines(cleaned_content) + + doc.page_content = cleaned_content + + return documents + + @staticmethod + def remove_unwanted_tags(html_content: str, unwanted_tags: List[str]) -> str: + """ + Remove unwanted tags from a given HTML content. + + Args: + html_content: The original HTML content string. + unwanted_tags: A list of tags to be removed from the HTML. + + Returns: + A cleaned HTML string with unwanted tags removed. + """ + from bs4 import BeautifulSoup + + soup = BeautifulSoup(html_content, "html.parser") + for tag in unwanted_tags: + for element in soup.find_all(tag): + element.decompose() + return str(soup) + + @staticmethod + def extract_tags(html_content: str, tags: List[str]) -> str: + """ + Extract specific tags from a given HTML content. + + Args: + html_content: The original HTML content string. + tags: A list of tags to be extracted from the HTML. 
+ + Returns: + A string combining the content of the extracted tags. + """ + from bs4 import BeautifulSoup + + soup = BeautifulSoup(html_content, "html.parser") + text_parts: List[str] = [] + for element in soup.find_all(): + if element.name in tags: + # Extract all navigable strings recursively from this element. + text_parts += get_navigable_strings(element) + + # To avoid duplicate text, remove all descendants from the soup. + element.decompose() + + return " ".join(text_parts) + + @staticmethod + def remove_unnecessary_lines(content: str) -> str: + """ + Clean up the content by removing unnecessary lines. + + Args: + content: A string, which may contain unnecessary lines or spaces. + + Returns: + A cleaned string with unnecessary lines removed. + """ + lines = content.split("\n") + stripped_lines = [line.strip() for line in lines] + non_empty_lines = [line for line in stripped_lines if line] + cleaned_content = " ".join(non_empty_lines) + return cleaned_content + + async def atransform_documents( + self, + documents: Sequence[Document], + **kwargs: Any, + ) -> Sequence[Document]: + raise NotImplementedError + + +def get_navigable_strings(element: Any) -> Iterator[str]: + from bs4 import NavigableString, Tag + + for child in cast(Tag, element).children: + if isinstance(child, Tag): + yield from get_navigable_strings(child) + elif isinstance(child, NavigableString): + if (element.name == "a") and (href := element.get("href")): + yield f"{child.strip()} ({href})" + else: + yield child.strip() diff --git a/.scripts/community_split/libs/community/langchain_community/document_transformers/openai_functions.py b/.scripts/community_split/libs/community/langchain_community/document_transformers/openai_functions.py new file mode 100644 index 00000000000..b796e6db87b --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/document_transformers/openai_functions.py @@ -0,0 +1,140 @@ +"""Document transformers that use OpenAI Functions models""" +from typing import Any, Dict, Optional, Sequence, Type, Union + +from langchain_core.documents import BaseDocumentTransformer, Document +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel + + +class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel): + """Extract metadata tags from document contents using OpenAI functions. + + Example: + .. code-block:: python + + from langchain_community.chat_models import ChatOpenAI + from langchain_community.document_transformers import OpenAIMetadataTagger + from langchain_core.documents import Document + + schema = { + "properties": { + "movie_title": { "type": "string" }, + "critic": { "type": "string" }, + "tone": { + "type": "string", + "enum": ["positive", "negative"] + }, + "rating": { + "type": "integer", + "description": "The number of stars the critic rated the movie" + } + }, + "required": ["movie_title", "critic", "tone"] + } + + # Must be an OpenAI model that supports functions + llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613") + tagging_chain = create_tagging_chain(schema, llm) + document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain) + original_documents = [ + Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."), + Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 
1 out of 5 stars.", metadata={"reliable": False}), + ] + + enhanced_documents = document_transformer.transform_documents(original_documents) + """ # noqa: E501 + + tagging_chain: Any + """The chain used to extract metadata from each document.""" + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Automatically extract and populate metadata + for each document according to the provided schema.""" + + new_documents = [] + + for document in documents: + extracted_metadata: Dict = self.tagging_chain.run(document.page_content) # type: ignore[assignment] # noqa: E501 + new_document = Document( + page_content=document.page_content, + metadata={**extracted_metadata, **document.metadata}, + ) + new_documents.append(new_document) + return new_documents + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + +def create_metadata_tagger( + metadata_schema: Union[Dict[str, Any], Type[BaseModel]], + llm: BaseLanguageModel, + prompt: Optional[ChatPromptTemplate] = None, + *, + tagging_chain_kwargs: Optional[Dict] = None, +) -> OpenAIMetadataTagger: + """Create a DocumentTransformer that uses an OpenAI function chain to automatically + tag documents with metadata based on their content and an input schema. + + Args: + metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary + is passed in, it's assumed to already be a valid JsonSchema. + For best results, pydantic.BaseModels should have docstrings describing what + the schema represents and descriptions for the parameters. + llm: Language model to use, assumed to support the OpenAI function-calling API. + Defaults to use "gpt-3.5-turbo-0613" + prompt: BasePromptTemplate to pass to the model. + + Returns: + An LLMChain that will pass the given function to the model. + + Example: + .. code-block:: python + + from langchain_community.chat_models import ChatOpenAI + from langchain_community.document_transformers import create_metadata_tagger + from langchain_core.documents import Document + + schema = { + "properties": { + "movie_title": { "type": "string" }, + "critic": { "type": "string" }, + "tone": { + "type": "string", + "enum": ["positive", "negative"] + }, + "rating": { + "type": "integer", + "description": "The number of stars the critic rated the movie" + } + }, + "required": ["movie_title", "critic", "tone"] + } + + # Must be an OpenAI model that supports functions + llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613") + + document_transformer = create_metadata_tagger(schema, llm) + original_documents = [ + Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."), + Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 
1 out of 5 stars.", metadata={"reliable": False}), + ] + + enhanced_documents = document_transformer.transform_documents(original_documents) + """ # noqa: E501 + from langchain.chains.openai_functions import create_tagging_chain + metadata_schema = ( + metadata_schema + if isinstance(metadata_schema, dict) + else metadata_schema.schema() + ) + _tagging_chain_kwargs = tagging_chain_kwargs or {} + tagging_chain = create_tagging_chain( + metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs + ) + return OpenAIMetadataTagger(tagging_chain=tagging_chain) diff --git a/.scripts/community_split/libs/community/langchain_community/embeddings/__init__.py b/.scripts/community_split/libs/community/langchain_community/embeddings/__init__.py new file mode 100644 index 00000000000..ce9cfc7aa0b --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/embeddings/__init__.py @@ -0,0 +1,161 @@ +"""**Embedding models** are wrappers around embedding models +from different APIs and services. + +**Embedding models** can be LLMs or not. + +**Class hierarchy:** + +.. code-block:: + + Embeddings --> Embeddings # Examples: OpenAIEmbeddings, HuggingFaceEmbeddings +""" + + +import logging +from typing import Any + +from langchain_community.embeddings.aleph_alpha import ( + AlephAlphaAsymmetricSemanticEmbedding, + AlephAlphaSymmetricSemanticEmbedding, +) +from langchain_community.embeddings.awa import AwaEmbeddings +from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings +from langchain_community.embeddings.baidu_qianfan_endpoint import ( + QianfanEmbeddingsEndpoint, +) +from langchain_community.embeddings.bedrock import BedrockEmbeddings +from langchain_community.embeddings.bookend import BookendEmbeddings +from langchain_community.embeddings.clarifai import ClarifaiEmbeddings +from langchain_community.embeddings.cohere import CohereEmbeddings +from langchain_community.embeddings.dashscope import DashScopeEmbeddings +from langchain_community.embeddings.databricks import DatabricksEmbeddings +from langchain_community.embeddings.deepinfra import DeepInfraEmbeddings +from langchain_community.embeddings.edenai import EdenAiEmbeddings +from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings +from langchain_community.embeddings.embaas import EmbaasEmbeddings +from langchain_community.embeddings.ernie import ErnieEmbeddings +from langchain_community.embeddings.fake import ( + DeterministicFakeEmbedding, + FakeEmbeddings, +) +from langchain_community.embeddings.fastembed import FastEmbedEmbeddings +from langchain_community.embeddings.google_palm import GooglePalmEmbeddings +from langchain_community.embeddings.gpt4all import GPT4AllEmbeddings +from langchain_community.embeddings.gradient_ai import GradientEmbeddings +from langchain_community.embeddings.huggingface import ( + HuggingFaceBgeEmbeddings, + HuggingFaceEmbeddings, + HuggingFaceInferenceAPIEmbeddings, + HuggingFaceInstructEmbeddings, +) +from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings +from langchain_community.embeddings.infinity import InfinityEmbeddings +from langchain_community.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings +from langchain_community.embeddings.jina import JinaEmbeddings +from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings +from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings +from langchain_community.embeddings.localai import LocalAIEmbeddings +from 
langchain_community.embeddings.minimax import MiniMaxEmbeddings +from langchain_community.embeddings.mlflow import MlflowEmbeddings +from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings +from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings +from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings +from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings +from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings +from langchain_community.embeddings.ollama import OllamaEmbeddings +from langchain_community.embeddings.openai import OpenAIEmbeddings +from langchain_community.embeddings.sagemaker_endpoint import ( + SagemakerEndpointEmbeddings, +) +from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings +from langchain_community.embeddings.self_hosted_hugging_face import ( + SelfHostedHuggingFaceEmbeddings, + SelfHostedHuggingFaceInstructEmbeddings, +) +from langchain_community.embeddings.sentence_transformer import ( + SentenceTransformerEmbeddings, +) +from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings +from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings +from langchain_community.embeddings.vertexai import VertexAIEmbeddings +from langchain_community.embeddings.voyageai import VoyageEmbeddings +from langchain_community.embeddings.xinference import XinferenceEmbeddings + +logger = logging.getLogger(__name__) + +__all__ = [ + "OpenAIEmbeddings", + "AzureOpenAIEmbeddings", + "ClarifaiEmbeddings", + "CohereEmbeddings", + "DatabricksEmbeddings", + "ElasticsearchEmbeddings", + "FastEmbedEmbeddings", + "HuggingFaceEmbeddings", + "HuggingFaceInferenceAPIEmbeddings", + "InfinityEmbeddings", + "GradientEmbeddings", + "JinaEmbeddings", + "LlamaCppEmbeddings", + "HuggingFaceHubEmbeddings", + "MlflowEmbeddings", + "MlflowAIGatewayEmbeddings", + "ModelScopeEmbeddings", + "TensorflowHubEmbeddings", + "SagemakerEndpointEmbeddings", + "HuggingFaceInstructEmbeddings", + "MosaicMLInstructorEmbeddings", + "SelfHostedEmbeddings", + "SelfHostedHuggingFaceEmbeddings", + "SelfHostedHuggingFaceInstructEmbeddings", + "FakeEmbeddings", + "DeterministicFakeEmbedding", + "AlephAlphaAsymmetricSemanticEmbedding", + "AlephAlphaSymmetricSemanticEmbedding", + "SentenceTransformerEmbeddings", + "GooglePalmEmbeddings", + "MiniMaxEmbeddings", + "VertexAIEmbeddings", + "BedrockEmbeddings", + "DeepInfraEmbeddings", + "EdenAiEmbeddings", + "DashScopeEmbeddings", + "EmbaasEmbeddings", + "OctoAIEmbeddings", + "SpacyEmbeddings", + "NLPCloudEmbeddings", + "GPT4AllEmbeddings", + "XinferenceEmbeddings", + "LocalAIEmbeddings", + "AwaEmbeddings", + "HuggingFaceBgeEmbeddings", + "ErnieEmbeddings", + "JavelinAIGatewayEmbeddings", + "OllamaEmbeddings", + "QianfanEmbeddingsEndpoint", + "JohnSnowLabsEmbeddings", + "VoyageEmbeddings", + "BookendEmbeddings", +] + + +# TODO: this is in here to maintain backwards compatibility +class HypotheticalDocumentEmbedder: + def __init__(self, *args: Any, **kwargs: Any): + logger.warning( + "Using a deprecated class. Please use " + "`from langchain.chains import HypotheticalDocumentEmbedder` instead" + ) + from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H + + return H(*args, **kwargs) # type: ignore + + @classmethod + def from_llm(cls, *args: Any, **kwargs: Any) -> Any: + logger.warning( + "Using a deprecated class. 
Please use " + "`from langchain.chains import HypotheticalDocumentEmbedder` instead" + ) + from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H + + return H.from_llm(*args, **kwargs) diff --git a/.scripts/community_split/libs/community/langchain_community/embeddings/huggingface.py b/.scripts/community_split/libs/community/langchain_community/embeddings/huggingface.py new file mode 100644 index 00000000000..84a568866f1 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/embeddings/huggingface.py @@ -0,0 +1,343 @@ +from typing import Any, Dict, List, Optional + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, Field + +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" +DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" +DEFAULT_BGE_MODEL = "BAAI/bge-large-en" +DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: " +DEFAULT_QUERY_INSTRUCTION = ( + "Represent the question for retrieving supporting documents: " +) +DEFAULT_QUERY_BGE_INSTRUCTION_EN = ( + "Represent this question for searching relevant passages: " +) +DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "δΈΊθΏ™δΈͺε₯ε­η”Ÿζˆθ‘¨η€Ίδ»₯η”¨δΊŽζ£€η΄’η›Έε…³ζ–‡η« οΌš" + + +class HuggingFaceEmbeddings(BaseModel, Embeddings): + """HuggingFace sentence_transformers embedding models. + + To use, you should have the ``sentence_transformers`` python package installed. + + Example: + .. code-block:: python + + from langchain_community.embeddings import HuggingFaceEmbeddings + + model_name = "sentence-transformers/all-mpnet-base-v2" + model_kwargs = {'device': 'cpu'} + encode_kwargs = {'normalize_embeddings': False} + hf = HuggingFaceEmbeddings( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs + ) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_MODEL_NAME + """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass when calling the `encode` method of the model.""" + multi_process: bool = False + """Run encode() on multiple GPUs.""" + + def __init__(self, **kwargs: Any): + """Initialize the sentence_transformer.""" + super().__init__(**kwargs) + try: + import sentence_transformers + + except ImportError as exc: + raise ImportError( + "Could not import sentence_transformers python package. " + "Please install it with `pip install sentence-transformers`." + ) from exc + + self.client = sentence_transformers.SentenceTransformer( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. 
+ """ + import sentence_transformers + + texts = list(map(lambda x: x.replace("\n", " "), texts)) + if self.multi_process: + pool = self.client.start_multi_process_pool() + embeddings = self.client.encode_multi_process(texts, pool) + sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) + else: + embeddings = self.client.encode(texts, **self.encode_kwargs) + + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] + + +class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): + """Wrapper around sentence_transformers embedding models. + + To use, you should have the ``sentence_transformers`` + and ``InstructorEmbedding`` python packages installed. + + Example: + .. code-block:: python + + from langchain_community.embeddings import HuggingFaceInstructEmbeddings + + model_name = "hkunlp/instructor-large" + model_kwargs = {'device': 'cpu'} + encode_kwargs = {'normalize_embeddings': True} + hf = HuggingFaceInstructEmbeddings( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs + ) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_INSTRUCT_MODEL + """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass when calling the `encode` method of the model.""" + embed_instruction: str = DEFAULT_EMBED_INSTRUCTION + """Instruction to use for embedding documents.""" + query_instruction: str = DEFAULT_QUERY_INSTRUCTION + """Instruction to use for embedding query.""" + + def __init__(self, **kwargs: Any): + """Initialize the sentence_transformer.""" + super().__init__(**kwargs) + try: + from InstructorEmbedding import INSTRUCTOR + + self.client = INSTRUCTOR( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) + except ImportError as e: + raise ImportError("Dependencies for InstructorEmbedding not found.") from e + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace instruct model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [[self.embed_instruction, text] for text in texts] + embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs) + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace instruct model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = [self.query_instruction, text] + embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0] + return embedding.tolist() + + +class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): + """HuggingFace BGE sentence_transformers embedding models. + + To use, you should have the ``sentence_transformers`` python package installed. + + Example: + .. 
code-block:: python + + from langchain_community.embeddings import HuggingFaceBgeEmbeddings + + model_name = "BAAI/bge-large-en" + model_kwargs = {'device': 'cpu'} + encode_kwargs = {'normalize_embeddings': True} + hf = HuggingFaceBgeEmbeddings( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs + ) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_BGE_MODEL + """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass when calling the `encode` method of the model.""" + query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN + """Instruction to use for embedding query.""" + + def __init__(self, **kwargs: Any): + """Initialize the sentence_transformer.""" + super().__init__(**kwargs) + try: + import sentence_transformers + + except ImportError as exc: + raise ImportError( + "Could not import sentence_transformers python package. " + "Please install it with `pip install sentence_transformers`." + ) from exc + + self.client = sentence_transformers.SentenceTransformer( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) + if "-zh" in self.model_name: + self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + texts = [t.replace("\n", " ") for t in texts] + embeddings = self.client.encode(texts, **self.encode_kwargs) + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + text = text.replace("\n", " ") + embedding = self.client.encode( + self.query_instruction + text, **self.encode_kwargs + ) + return embedding.tolist() + + +class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings): + """Embed texts using the HuggingFace API. + + Requires a HuggingFace Inference API key and a model name. + """ + + api_key: str + """Your API key for the HuggingFace Inference API.""" + model_name: str = "sentence-transformers/all-MiniLM-L6-v2" + """The name of the model to use for text embeddings.""" + api_url: Optional[str] = None + """Custom inference endpoint url. None for using default public url.""" + + @property + def _api_url(self) -> str: + return self.api_url or self._default_api_url + + @property + def _default_api_url(self) -> str: + return ( + "https://api-inference.huggingface.co" + "/pipeline" + "/feature-extraction" + f"/{self.model_name}" + ) + + @property + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_key}"} + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embedded texts as List[List[float]], where each inner List[float] + corresponds to a single input text. + + Example: + .. 
code-block:: python + + from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings + + hf_embeddings = HuggingFaceInferenceAPIEmbeddings( + api_key="your_api_key", + model_name="sentence-transformers/all-MiniLM-l6-v2" + ) + texts = ["Hello, world!", "How are you?"] + hf_embeddings.embed_documents(texts) + """ # noqa: E501 + response = requests.post( + self._api_url, + headers=self._headers, + json={ + "inputs": texts, + "options": {"wait_for_model": True, "use_cache": True}, + }, + ) + return response.json() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/.scripts/community_split/libs/community/langchain_community/embeddings/johnsnowlabs.py b/.scripts/community_split/libs/community/langchain_community/embeddings/johnsnowlabs.py new file mode 100644 index 00000000000..f183efe87b5 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/embeddings/johnsnowlabs.py @@ -0,0 +1,92 @@ +import os +import sys +from typing import Any, List + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra + + +class JohnSnowLabsEmbeddings(BaseModel, Embeddings): + """JohnSnowLabs embedding models + + To use, you should have the ``johnsnowlabs`` python package installed. + Example: + .. code-block:: python + + from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings + + embedding = JohnSnowLabsEmbeddings(model='embed_sentence.bert') + output = embedding.embed_query("foo bar") + """ # noqa: E501 + + model: Any = "embed_sentence.bert" + + def __init__( + self, + model: Any = "embed_sentence.bert", + hardware_target: str = "cpu", + **kwargs: Any, + ): + """Initialize the johnsnowlabs model.""" + super().__init__(**kwargs) + # 1) Check imports + try: + from johnsnowlabs import nlp + from nlu.pipe.pipeline import NLUPipeline + except ImportError as exc: + raise ImportError( + "Could not import johnsnowlabs python package. " + "Please install it with `pip install johnsnowlabs`." + ) from exc + + # 2) Start a Spark Session + try: + os.environ["PYSPARK_PYTHON"] = sys.executable + os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable + nlp.start(hardware_target=hardware_target) + except Exception as exc: + raise Exception("Failure starting Spark Session") from exc + + # 3) Load the model + try: + if isinstance(model, str): + self.model = nlp.load(model) + elif isinstance(model, NLUPipeline): + self.model = model + else: + self.model = nlp.to_nlu_pipe(model) + except Exception as exc: + raise Exception("Failure loading model") from exc + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a JohnSnowLabs transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + + df = self.model.predict(texts, output_level="document") + emb_col = None + for c in df.columns: + if "embedding" in c: + emb_col = c + return [vec.tolist() for vec in df[emb_col].tolist()] + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a JohnSnowLabs transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. 
+ """ + return self.embed_documents([text])[0] diff --git a/.scripts/community_split/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py b/.scripts/community_split/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py new file mode 100644 index 00000000000..0b706532cf2 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/embeddings/self_hosted_hugging_face.py @@ -0,0 +1,168 @@ +import importlib +import logging +from typing import Any, Callable, List, Optional + +from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings + +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" +DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" +DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: " +DEFAULT_QUERY_INSTRUCTION = ( + "Represent the question for retrieving supporting documents: " +) + +logger = logging.getLogger(__name__) + + +def _embed_documents(client: Any, *args: Any, **kwargs: Any) -> List[List[float]]: + """Inference function to send to the remote hardware. + + Accepts a sentence_transformer model_id and + returns a list of embeddings for each document in the batch. + """ + return client.encode(*args, **kwargs) + + +def load_embedding_model(model_id: str, instruct: bool = False, device: int = 0) -> Any: + """Load the embedding model.""" + if not instruct: + import sentence_transformers + + client = sentence_transformers.SentenceTransformer(model_id) + else: + from InstructorEmbedding import INSTRUCTOR + + client = INSTRUCTOR(model_id) + + if importlib.util.find_spec("torch") is not None: + import torch + + cuda_device_count = torch.cuda.device_count() + if device < -1 or (device >= cuda_device_count): + raise ValueError( + f"Got device=={device}, " + f"device is required to be within [-1, {cuda_device_count})" + ) + if device < 0 and cuda_device_count > 0: + logger.warning( + "Device has %d GPUs available. " + "Provide device={deviceId} to `from_model_id` to use available" + "GPUs for execution. deviceId is -1 for CPU and " + "can be a positive integer associated with CUDA device id.", + cuda_device_count, + ) + + client = client.to(device) + return client + + +class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings): + """HuggingFace embedding models on self-hosted remote hardware. + + Supported hardware includes auto-launched instances on AWS, GCP, Azure, + and Lambda, as well as servers specified + by IP address and SSH credentials (such as on-prem, or another cloud + like Paperspace, Coreweave, etc.). + + To use, you should have the ``runhouse`` python package installed. + + Example: + .. 
code-block:: python + + from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings + import runhouse as rh + model_name = "sentence-transformers/all-mpnet-base-v2" + gpu = rh.cluster(name="rh-a10x", instance_type="A100:1") + hf = SelfHostedHuggingFaceEmbeddings(model_name=model_name, hardware=gpu) + """ + + client: Any #: :meta private: + model_id: str = DEFAULT_MODEL_NAME + """Model name to use.""" + model_reqs: List[str] = ["./", "sentence_transformers", "torch"] + """Requirements to install on hardware to inference the model.""" + hardware: Any + """Remote hardware to send the inference function to.""" + model_load_fn: Callable = load_embedding_model + """Function to load the model remotely on the server.""" + load_fn_kwargs: Optional[dict] = None + """Keyword arguments to pass to the model load function.""" + inference_fn: Callable = _embed_documents + """Inference function to extract the embeddings.""" + + def __init__(self, **kwargs: Any): + """Initialize the remote inference function.""" + load_fn_kwargs = kwargs.pop("load_fn_kwargs", {}) + load_fn_kwargs["model_id"] = load_fn_kwargs.get("model_id", DEFAULT_MODEL_NAME) + load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", False) + load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0) + super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs) + + +class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings): + """HuggingFace InstructEmbedding models on self-hosted remote hardware. + + Supported hardware includes auto-launched instances on AWS, GCP, Azure, + and Lambda, as well as servers specified + by IP address and SSH credentials (such as on-prem, or another + cloud like Paperspace, Coreweave, etc.). + + To use, you should have the ``runhouse`` python package installed. + + Example: + .. code-block:: python + + from langchain_community.embeddings import SelfHostedHuggingFaceInstructEmbeddings + import runhouse as rh + model_name = "hkunlp/instructor-large" + gpu = rh.cluster(name='rh-a10x', instance_type='A100:1') + hf = SelfHostedHuggingFaceInstructEmbeddings( + model_name=model_name, hardware=gpu) + """ # noqa: E501 + + model_id: str = DEFAULT_INSTRUCT_MODEL + """Model name to use.""" + embed_instruction: str = DEFAULT_EMBED_INSTRUCTION + """Instruction to use for embedding documents.""" + query_instruction: str = DEFAULT_QUERY_INSTRUCTION + """Instruction to use for embedding query.""" + model_reqs: List[str] = ["./", "InstructorEmbedding", "torch"] + """Requirements to install on hardware to inference the model.""" + + def __init__(self, **kwargs: Any): + """Initialize the remote inference function.""" + load_fn_kwargs = kwargs.pop("load_fn_kwargs", {}) + load_fn_kwargs["model_id"] = load_fn_kwargs.get( + "model_id", DEFAULT_INSTRUCT_MODEL + ) + load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", True) + load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0) + super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace instruct model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [] + for text in texts: + instruction_pairs.append([self.embed_instruction, text]) + embeddings = self.client(self.pipeline_ref, instruction_pairs) + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace instruct model. 
+ + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = [self.query_instruction, text] + embedding = self.client(self.pipeline_ref, [instruction_pair])[0] + return embedding.tolist() diff --git a/.scripts/community_split/libs/community/langchain_community/llms/anthropic.py b/.scripts/community_split/libs/community/langchain_community/llms/anthropic.py new file mode 100644 index 00000000000..be832cf1368 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/llms/anthropic.py @@ -0,0 +1,351 @@ +import re +import warnings +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, +) + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk +from langchain_core.prompt_values import PromptValue +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + check_package_version, + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str + + +class _AnthropicCommon(BaseLanguageModel): + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + model: str = Field(default="claude-2", alias="model_name") + """Model name to use.""" + + max_tokens_to_sample: int = Field(default=256, alias="max_tokens") + """Denotes the number of tokens to predict per generation.""" + + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" + + top_k: Optional[int] = None + """Number of most likely tokens to consider at each step.""" + + top_p: Optional[float] = None + """Total probability mass of tokens to consider at each step.""" + + streaming: bool = False + """Whether to stream the results.""" + + default_request_timeout: Optional[float] = None + """Timeout for requests to Anthropic Completion API. Default is 600 seconds.""" + + anthropic_api_url: Optional[str] = None + + anthropic_api_key: Optional[SecretStr] = None + + HUMAN_PROMPT: Optional[str] = None + AI_PROMPT: Optional[str] = None + count_tokens: Optional[Callable[[str], int]] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["anthropic_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") + ) + # Get custom api url from environment. 
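+        # Precedence: an explicit `anthropic_api_url` kwarg wins, then the
+        # ANTHROPIC_API_URL environment variable, then the public default below.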
+ values["anthropic_api_url"] = get_from_dict_or_env( + values, + "anthropic_api_url", + "ANTHROPIC_API_URL", + default="https://api.anthropic.com", + ) + + try: + import anthropic + + check_package_version("anthropic", gte_version="0.3") + values["client"] = anthropic.Anthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"].get_secret_value(), + timeout=values["default_request_timeout"], + ) + values["async_client"] = anthropic.AsyncAnthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"].get_secret_value(), + timeout=values["default_request_timeout"], + ) + values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT + values["AI_PROMPT"] = anthropic.AI_PROMPT + values["count_tokens"] = values["client"].count_tokens + + except ImportError: + raise ImportError( + "Could not import anthropic python package. " + "Please it install it with `pip install anthropic`." + ) + return values + + @property + def _default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling Anthropic API.""" + d = { + "max_tokens_to_sample": self.max_tokens_to_sample, + "model": self.model, + } + if self.temperature is not None: + d["temperature"] = self.temperature + if self.top_k is not None: + d["top_k"] = self.top_k + if self.top_p is not None: + d["top_p"] = self.top_p + return {**d, **self.model_kwargs} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{}, **self._default_params} + + def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: + if not self.HUMAN_PROMPT or not self.AI_PROMPT: + raise NameError("Please ensure the anthropic package is loaded") + + if stop is None: + stop = [] + + # Never want model to invent new turns of Human / Assistant dialog. + stop.extend([self.HUMAN_PROMPT]) + + return stop + + +class Anthropic(LLM, _AnthropicCommon): + """Anthropic large language models. + + To use, you should have the ``anthropic`` python package installed, and the + environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass + it as a named parameter to the constructor. + + Example: + .. code-block:: python + + import anthropic + from langchain_community.llms import Anthropic + + model = Anthropic(model="", anthropic_api_key="my-api-key") + + # Simplest invocation, automatically wrapped with HUMAN_PROMPT + # and AI_PROMPT. + response = model("What are the biggest risks facing humanity?") + + # Or if you want to use the chat mode, build a few-shot-prompt, or + # put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT: + raw_prompt = "What are the biggest risks facing humanity?" + prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" + response = model(prompt) + """ + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + + @root_validator() + def raise_warning(cls, values: Dict) -> Dict: + """Raise warning that this class is deprecated.""" + warnings.warn( + "This Anthropic LLM is deprecated. 
" + "Please use `from langchain_community.chat_models import ChatAnthropic` " + "instead" + ) + return values + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "anthropic-llm" + + def _wrap_prompt(self, prompt: str) -> str: + if not self.HUMAN_PROMPT or not self.AI_PROMPT: + raise NameError("Please ensure the anthropic package is loaded") + + if prompt.startswith(self.HUMAN_PROMPT): + return prompt # Already wrapped. + + # Guard against common errors in specifying wrong number of newlines. + corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) + if n_subs == 1: + return corrected_prompt + + # As a last resort, wrap the prompt ourselves to emulate instruct-style. + return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + r"""Call out to Anthropic's completion endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + prompt = "What are the biggest risks facing humanity?" + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + response = model(prompt) + + """ + if self.streaming: + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + response = self.client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + **params, + ) + return response.completion + + def convert_prompt(self, prompt: PromptValue) -> str: + return self._wrap_prompt(prompt.to_string()) + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Anthropic's completion endpoint asynchronously.""" + if self.streaming: + completion = "" + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + response = await self.async_client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + **params, + ) + return response.completion + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + r"""Call Anthropic completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Anthropic. + Example: + .. code-block:: python + + prompt = "Write a poem about a stream." 
+ prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = anthropic.stream(prompt) + for token in generator: + yield token + """ + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + for token in self.client.completions.create( + prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + ): + chunk = GenerationChunk(text=token.completion) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + r"""Call Anthropic completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Anthropic. + Example: + .. code-block:: python + prompt = "Write a poem about a stream." + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = anthropic.stream(prompt) + for token in generator: + yield token + """ + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + async for token in await self.async_client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, + ): + chunk = GenerationChunk(text=token.completion) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + if not self.count_tokens: + raise NameError("Please ensure the anthropic package is loaded") + return self.count_tokens(text) diff --git a/.scripts/community_split/libs/community/langchain_community/llms/cloudflare_workersai.py b/.scripts/community_split/libs/community/langchain_community/llms/cloudflare_workersai.py new file mode 100644 index 00000000000..840acdbdb81 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/llms/cloudflare_workersai.py @@ -0,0 +1,126 @@ +import json +import logging +from typing import Any, Dict, Iterator, List, Optional + +import requests +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk + +logger = logging.getLogger(__name__) + + +class CloudflareWorkersAI(LLM): + """Langchain LLM class to help to access Cloudflare Workers AI service. + + To use, you must provide an API token and + account ID to access Cloudflare Workers AI, and + pass it as a named parameter to the constructor. + + Example: + .. 
code-block:: python + + from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI + + my_account_id = "my_account_id" + my_api_token = "my_secret_api_token" + llm_model = "@cf/meta/llama-2-7b-chat-int8" + + cf_ai = CloudflareWorkersAI( + account_id=my_account_id, + api_token=my_api_token, + model=llm_model + ) + """ # noqa: E501 + + account_id: str + api_token: str + model: str = "@cf/meta/llama-2-7b-chat-int8" + base_url: str = "https://api.cloudflare.com/client/v4/accounts" + streaming: bool = False + endpoint_url: str = "" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the Cloudflare Workers AI class.""" + super().__init__(**kwargs) + + self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}" + + @property + def _llm_type(self) -> str: + """Return type of LLM.""" + return "cloudflare" + + @property + def _default_params(self) -> Dict[str, Any]: + """Default parameters""" + return {} + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Identifying parameters""" + return { + "account_id": self.account_id, + "api_token": self.api_token, + "model": self.model, + "base_url": self.base_url, + } + + def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response: + """Call Cloudflare Workers API""" + headers = {"Authorization": f"Bearer {self.api_token}"} + data = {"prompt": prompt, "stream": self.streaming, **params} + response = requests.post(self.endpoint_url, headers=headers, json=data) + return response + + def _process_response(self, response: requests.Response) -> str: + """Process API response""" + if response.ok: + data = response.json() + return data["result"]["response"] + else: + raise ValueError(f"Request failed with status {response.status_code}") + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Streaming prediction""" + original_steaming: bool = self.streaming + self.streaming = True + _response_prefix_count = len("data: ") + _response_stream_end = b"data: [DONE]" + for chunk in self._call_api(prompt, kwargs).iter_lines(): + if chunk == _response_stream_end: + break + if len(chunk) > _response_prefix_count: + try: + data = json.loads(chunk[_response_prefix_count:]) + except Exception as e: + logger.debug(chunk) + raise e + if data is not None and "response" in data: + yield GenerationChunk(text=data["response"]) + if run_manager: + run_manager.on_llm_new_token(data["response"]) + logger.debug("stream end") + self.streaming = original_steaming + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Regular prediction""" + if self.streaming: + return "".join( + [c.text for c in self._stream(prompt, stop, run_manager, **kwargs)] + ) + else: + response = self._call_api(prompt, kwargs) + return self._process_response(response) diff --git a/.scripts/community_split/libs/community/langchain_community/retrievers/__init__.py b/.scripts/community_split/libs/community/langchain_community/retrievers/__init__.py new file mode 100644 index 00000000000..75b4ef67536 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/retrievers/__init__.py @@ -0,0 +1,106 @@ +"""**Retriever** class returns Documents given a text **query**. + +It is more general than a vector store. 
A retriever does not need to be able to +store documents, only to return (or retrieve) it. Vector stores can be used as +the backbone of a retriever, but there are other types of retrievers as well. + +**Class hierarchy:** + +.. code-block:: + + BaseRetriever --> Retriever # Examples: ArxivRetriever, MergerRetriever + +**Main helpers:** + +.. code-block:: + + Document, Serializable, Callbacks, + CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun +""" + +from langchain_community.retrievers.arcee import ArceeRetriever +from langchain_community.retrievers.arxiv import ArxivRetriever +from langchain_community.retrievers.azure_cognitive_search import ( + AzureCognitiveSearchRetriever, +) +from langchain_community.retrievers.bedrock import AmazonKnowledgeBasesRetriever +from langchain_community.retrievers.bm25 import BM25Retriever +from langchain_community.retrievers.chaindesk import ChaindeskRetriever +from langchain_community.retrievers.chatgpt_plugin_retriever import ( + ChatGPTPluginRetriever, +) +from langchain_community.retrievers.cohere_rag_retriever import CohereRagRetriever +from langchain_community.retrievers.docarray import DocArrayRetriever +from langchain_community.retrievers.elastic_search_bm25 import ( + ElasticSearchBM25Retriever, +) +from langchain_community.retrievers.embedchain import EmbedchainRetriever +from langchain_community.retrievers.google_cloud_documentai_warehouse import ( + GoogleDocumentAIWarehouseRetriever, +) +from langchain_community.retrievers.google_vertex_ai_search import ( + GoogleCloudEnterpriseSearchRetriever, + GoogleVertexAIMultiTurnSearchRetriever, + GoogleVertexAISearchRetriever, +) +from langchain_community.retrievers.kay import KayAiRetriever +from langchain_community.retrievers.kendra import AmazonKendraRetriever +from langchain_community.retrievers.knn import KNNRetriever +from langchain_community.retrievers.llama_index import ( + LlamaIndexGraphRetriever, + LlamaIndexRetriever, +) +from langchain_community.retrievers.metal import MetalRetriever +from langchain_community.retrievers.milvus import MilvusRetriever +from langchain_community.retrievers.outline import OutlineRetriever +from langchain_community.retrievers.pinecone_hybrid_search import ( + PineconeHybridSearchRetriever, +) +from langchain_community.retrievers.pubmed import PubMedRetriever +from langchain_community.retrievers.remote_retriever import RemoteLangChainRetriever +from langchain_community.retrievers.svm import SVMRetriever +from langchain_community.retrievers.tavily_search_api import TavilySearchAPIRetriever +from langchain_community.retrievers.tfidf import TFIDFRetriever +from langchain_community.retrievers.weaviate_hybrid_search import ( + WeaviateHybridSearchRetriever, +) +from langchain_community.retrievers.wikipedia import WikipediaRetriever +from langchain_community.retrievers.zep import ZepRetriever +from langchain_community.retrievers.zilliz import ZillizRetriever + +__all__ = [ + "AmazonKendraRetriever", + "AmazonKnowledgeBasesRetriever", + "ArceeRetriever", + "ArxivRetriever", + "AzureCognitiveSearchRetriever", + "ChatGPTPluginRetriever", + "ChaindeskRetriever", + "CohereRagRetriever", + "ElasticSearchBM25Retriever", + "EmbedchainRetriever", + "GoogleDocumentAIWarehouseRetriever", + "GoogleCloudEnterpriseSearchRetriever", + "GoogleVertexAIMultiTurnSearchRetriever", + "GoogleVertexAISearchRetriever", + "KayAiRetriever", + "KNNRetriever", + "LlamaIndexGraphRetriever", + "LlamaIndexRetriever", + "MetalRetriever", + "MilvusRetriever", + "OutlineRetriever", 
+ "PineconeHybridSearchRetriever", + "PubMedRetriever", + "RemoteLangChainRetriever", + "SVMRetriever", + "TavilySearchAPIRetriever", + "TFIDFRetriever", + "BM25Retriever", + "VespaRetriever", + "WeaviateHybridSearchRetriever", + "WikipediaRetriever", + "ZepRetriever", + "ZillizRetriever", + "DocArrayRetriever", +] diff --git a/.scripts/community_split/libs/community/langchain_community/storage/__init__.py b/.scripts/community_split/libs/community/langchain_community/storage/__init__.py new file mode 100644 index 00000000000..7af3d5b3000 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/storage/__init__.py @@ -0,0 +1,19 @@ +"""Implementations of key-value stores and storage helpers. + +Module provides implementations of various key-value stores that conform +to a simple key-value interface. + +The primary goal of these storages is to support implementation of caching. +""" + +from langchain_community.storage.redis import RedisStore +from langchain_community.storage.upstash_redis import ( + UpstashRedisByteStore, + UpstashRedisStore, +) + +__all__ = [ + "RedisStore", + "UpstashRedisByteStore", + "UpstashRedisStore", +] diff --git a/.scripts/community_split/libs/community/langchain_community/tools/amadeus/closest_airport.py b/.scripts/community_split/libs/community/langchain_community/tools/amadeus/closest_airport.py new file mode 100644 index 00000000000..4e8b90a1b2a --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/amadeus/closest_airport.py @@ -0,0 +1,50 @@ +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field + +from langchain_community.chat_models import ChatOpenAI +from langchain_community.tools.amadeus.base import AmadeusBaseTool + + +class ClosestAirportSchema(BaseModel): + """Schema for the AmadeusClosestAirport tool.""" + + location: str = Field( + description=( + " The location for which you would like to find the nearest airport " + " along with optional details such as country, state, region, or " + " province, allowing for easy processing and identification of " + " the closest airport. Examples of the format are the following:\n" + " Cali, Colombia\n " + " Lincoln, Nebraska, United States\n" + " New York, United States\n" + " Sydney, New South Wales, Australia\n" + " Rome, Lazio, Italy\n" + " Toronto, Ontario, Canada\n" + ) + ) + + +class AmadeusClosestAirport(AmadeusBaseTool): + """Tool for finding the closest airport to a particular location.""" + + name: str = "closest_airport" + description: str = ( + "Use this tool to find the closest airport to a particular location." + ) + args_schema: Type[ClosestAirportSchema] = ClosestAirportSchema + + def _run( + self, + location: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + content = ( + f" What is the nearest airport to {location}? Please respond with the " + " airport's International Air Transport Association (IATA) Location " + ' Identifier in the following JSON format. 
JSON: "iataCode": "IATA ' + ' Location Identifier" ' + ) + + return ChatOpenAI(temperature=0).predict(content) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/clickup/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/clickup/tool.py new file mode 100644 index 00000000000..93988dd7d59 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/clickup/tool.py @@ -0,0 +1,42 @@ +""" +This tool allows agents to interact with the clickup library +and operate on a Clickup instance. +To use this tool, you must first set as environment variables: + client_secret + client_id + code + +Below is a sample script that uses the Clickup tool: + +```python +from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit +from langchain_community.utilities.clickup import ClickupAPIWrapper + +clickup = ClickupAPIWrapper() +toolkit = ClickupToolkit.from_clickup_api_wrapper(clickup) +``` +""" +from typing import Optional + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool + +from langchain_community.utilities.clickup import ClickupAPIWrapper + + +class ClickupAction(BaseTool): + """Tool that queries the Clickup API.""" + + api_wrapper: ClickupAPIWrapper = Field(default_factory=ClickupAPIWrapper) + mode: str + name: str = "" + description: str = "" + + def _run( + self, + instructions: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the Clickup API to run an operation.""" + return self.api_wrapper.run(self.mode, instructions) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/jira/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/jira/tool.py new file mode 100644 index 00000000000..dc57b13dc20 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/jira/tool.py @@ -0,0 +1,44 @@ +""" +This tool allows agents to interact with the atlassian-python-api library +and operate on a Jira instance. 
For more information on the +atlassian-python-api library, see https://atlassian-python-api.readthedocs.io/jira.html + +To use this tool, you must first set as environment variables: + JIRA_API_TOKEN + JIRA_USERNAME + JIRA_INSTANCE_URL + +Below is a sample script that uses the Jira tool: + +```python +from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit +from langchain_community.utilities.jira import JiraAPIWrapper + +jira = JiraAPIWrapper() +toolkit = JiraToolkit.from_jira_api_wrapper(jira) +``` +""" +from typing import Optional + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool + +from langchain_community.utilities.jira import JiraAPIWrapper + + +class JiraAction(BaseTool): + """Tool that queries the Atlassian Jira API.""" + + api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper) + mode: str + name: str = "" + description: str = "" + + def _run( + self, + instructions: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the Atlassian Jira API to run an operation.""" + return self.api_wrapper.run(self.mode, instructions) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/powerbi/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/powerbi/tool.py new file mode 100644 index 00000000000..2ee4e4f5129 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/powerbi/tool.py @@ -0,0 +1,276 @@ +"""Tools for interacting with a Power BI dataset.""" +import logging +from time import perf_counter +from typing import Any, Dict, Optional, Tuple + +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.pydantic_v1 import Field, validator +from langchain_core.tools import BaseTool +from langchain_community.chat_models.openai import _import_tiktoken + +from langchain_community.tools.powerbi.prompt import ( + BAD_REQUEST_RESPONSE, + DEFAULT_FEWSHOT_EXAMPLES, + RETRY_RESPONSE, +) +from langchain_community.utilities.powerbi import PowerBIDataset, json_to_md + +logger = logging.getLogger(__name__) + + +class QueryPowerBITool(BaseTool): + """Tool for querying a Power BI Dataset.""" + + name: str = "query_powerbi" + description: str = """ + Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification. + + Example Input: "How many rows are in table1?" 
+ """ # noqa: E501 + llm_chain: Any + powerbi: PowerBIDataset = Field(exclude=True) + examples: Optional[str] = DEFAULT_FEWSHOT_EXAMPLES + session_cache: Dict[str, Any] = Field(default_factory=dict, exclude=True) + max_iterations: int = 5 + output_token_limit: int = 4000 + tiktoken_model_name: Optional[str] = None # "cl100k_base" + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @validator("llm_chain") + def validate_llm_chain_input_variables( # pylint: disable=E0213 + cls, llm_chain: Any + ) -> Any: + """Make sure the LLM chain has the correct input variables.""" + for var in llm_chain.prompt.input_variables: + if var not in ["tool_input", "tables", "schemas", "examples"]: + raise ValueError( + "LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: C0301 E501 # pylint: disable=C0301 + llm_chain.prompt.input_variables, + ) + return llm_chain + + def _check_cache(self, tool_input: str) -> Optional[str]: + """Check if the input is present in the cache. + + If the value is a bad request, overwrite with the escalated version, + if not present return None.""" + if tool_input not in self.session_cache: + return None + return self.session_cache[tool_input] + + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: + """Execute the query, return the results or an error message.""" + if cache := self._check_cache(tool_input): + logger.debug("Found cached result for %s: %s", tool_input, cache) + return cache + + try: + logger.info("Running PBI Query Tool with input: %s", tool_input) + query = self.llm_chain.predict( + tool_input=tool_input, + tables=self.powerbi.get_table_names(), + schemas=self.powerbi.get_schemas(), + examples=self.examples, + callbacks=run_manager.get_child() if run_manager else None, + ) + except Exception as exc: # pylint: disable=broad-except + self.session_cache[tool_input] = f"Error on call to LLM: {exc}" + return self.session_cache[tool_input] + if query == "I cannot answer this": + self.session_cache[tool_input] = query + return self.session_cache[tool_input] + logger.info("PBI Query:\n%s", query) + start_time = perf_counter() + pbi_result = self.powerbi.run(command=query) + end_time = perf_counter() + logger.debug("PBI Result: %s", pbi_result) + logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}") + result, error = self._parse_output(pbi_result) + if error is not None and "TokenExpired" in error: + self.session_cache[ + tool_input + ] = "Authentication token expired or invalid, please try reauthenticate." + return self.session_cache[tool_input] + + iterations = kwargs.get("iterations", 0) + if error and iterations < self.max_iterations: + return self._run( + tool_input=RETRY_RESPONSE.format( + tool_input=tool_input, query=query, error=error + ), + run_manager=run_manager, + iterations=iterations + 1, + ) + + self.session_cache[tool_input] = ( + result if result else BAD_REQUEST_RESPONSE.format(error=error) + ) + return self.session_cache[tool_input] + + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: + """Execute the query, return the results or an error message.""" + if cache := self._check_cache(tool_input): + logger.debug("Found cached result for %s: %s", tool_input, cache) + return f"{cache}, from cache, you have already asked this question." 
+ try: + logger.info("Running PBI Query Tool with input: %s", tool_input) + query = await self.llm_chain.apredict( + tool_input=tool_input, + tables=self.powerbi.get_table_names(), + schemas=self.powerbi.get_schemas(), + examples=self.examples, + callbacks=run_manager.get_child() if run_manager else None, + ) + except Exception as exc: # pylint: disable=broad-except + self.session_cache[tool_input] = f"Error on call to LLM: {exc}" + return self.session_cache[tool_input] + + if query == "I cannot answer this": + self.session_cache[tool_input] = query + return self.session_cache[tool_input] + logger.info("PBI Query: %s", query) + start_time = perf_counter() + pbi_result = await self.powerbi.arun(command=query) + end_time = perf_counter() + logger.debug("PBI Result: %s", pbi_result) + logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}") + result, error = self._parse_output(pbi_result) + if error is not None and ("TokenExpired" in error or "TokenError" in error): + self.session_cache[ + tool_input + ] = "Authentication token expired or invalid, please try to reauthenticate or check the scope of the credential." # noqa: E501 + return self.session_cache[tool_input] + + iterations = kwargs.get("iterations", 0) + if error and iterations < self.max_iterations: + return await self._arun( + tool_input=RETRY_RESPONSE.format( + tool_input=tool_input, query=query, error=error + ), + run_manager=run_manager, + iterations=iterations + 1, + ) + + self.session_cache[tool_input] = ( + result if result else BAD_REQUEST_RESPONSE.format(error=error) + ) + return self.session_cache[tool_input] + + def _parse_output( + self, pbi_result: Dict[str, Any] + ) -> Tuple[Optional[str], Optional[Any]]: + """Parse the output of the query to a markdown table.""" + if "results" in pbi_result: + rows = pbi_result["results"][0]["tables"][0]["rows"] + if len(rows) == 0: + logger.info("0 records in result, query was valid.") + return ( + None, + "0 rows returned, this might be correct, but please validate if all filter values were correct?", # noqa: E501 + ) + result = json_to_md(rows) + too_long, length = self._result_too_large(result) + if too_long: + return ( + f"Result too large, please try to be more specific or use the `TOPN` function. The result is {length} tokens long, the limit is {self.output_token_limit} tokens.", # noqa: E501 + None, + ) + return result, None + + if "error" in pbi_result: + if ( + "pbi.error" in pbi_result["error"] + and "details" in pbi_result["error"]["pbi.error"] + ): + return None, pbi_result["error"]["pbi.error"]["details"][0]["detail"] + return None, pbi_result["error"] + return None, pbi_result + + def _result_too_large(self, result: str) -> Tuple[bool, int]: + """Tokenize the output of the query.""" + if self.tiktoken_model_name: + tiktoken_ = _import_tiktoken() + encoding = tiktoken_.encoding_for_model(self.tiktoken_model_name) + length = len(encoding.encode(result)) + logger.info("Result length: %s", length) + return length > self.output_token_limit, length + return False, 0 + + +class InfoPowerBITool(BaseTool): + """Tool for getting metadata about a PowerBI Dataset.""" + + name: str = "schema_powerbi" + description: str = """ + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. + Be sure that the tables actually exist by calling list_tables_powerbi first! 
+ + Example Input: "table1, table2, table3" + """ # noqa: E501 + powerbi: PowerBIDataset = Field(exclude=True) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.powerbi.get_table_info(tool_input.split(", ")) + + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.powerbi.aget_table_info(tool_input.split(", ")) + + +class ListPowerBITool(BaseTool): + """Tool for getting tables names.""" + + name: str = "list_tables_powerbi" + description: str = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301 + powerbi: PowerBIDataset = Field(exclude=True) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def _run( + self, + tool_input: Optional[str] = None, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the names of the tables.""" + return ", ".join(self.powerbi.get_table_names()) + + async def _arun( + self, + tool_input: Optional[str] = None, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + """Get the names of the tables.""" + return ", ".join(self.powerbi.get_table_names()) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/spark_sql/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/spark_sql/tool.py new file mode 100644 index 00000000000..4a07000249d --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/spark_sql/tool.py @@ -0,0 +1,130 @@ +# flake8: noqa +"""Tools for interacting with Spark SQL.""" +from typing import Any, Dict, Optional + +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.prompts import PromptTemplate +from langchain_community.utilities.spark_sql import SparkSQL +from langchain_core.tools import BaseTool +from langchain_community.tools.spark_sql.prompt import QUERY_CHECKER + + +class BaseSparkSQLTool(BaseModel): + """Base tool for interacting with Spark SQL.""" + + db: SparkSQL = Field(exclude=True) + + class Config(BaseTool.Config): + pass + + +class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for querying a Spark SQL.""" + + name: str = "query_sql_db" + description: str = """ + Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL. + If the query is not correct, an error message will be returned. + If an error is returned, rewrite the query, check the query, and try again. + """ + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + return self.db.run_no_throw(query) + + +class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for getting metadata about a Spark SQL.""" + + name: str = "schema_sql_db" + description: str = """ + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. 
+ Be sure that the tables actually exist by calling list_tables_sql_db first! + + Example Input: "table1, table2, table3" + """ + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw(table_names.split(", ")) + + +class ListSparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for getting tables names.""" + + name: str = "list_tables_sql_db" + description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL." + + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for a specific table.""" + return ", ".join(self.db.get_usable_table_names()) + + +class QueryCheckerTool(BaseSparkSQLTool, BaseTool): + """Use an LLM to check if a query is correct. + Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: Any = Field(init=False) + name: str = "query_checker_sql_db" + description: str = """ + Use this tool to double check if your query is correct before executing it. + Always use this tool before executing a query with query_sql_db! + """ + + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + from langchain.chains.llm import LLMChain + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, input_variables=["query"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["query"]: + raise ValueError( + "LLM chain for QueryCheckerTool need to use ['query'] as input_variables " + "for the embedded prompt" + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.predict( + query=query, callbacks=run_manager.get_child() if run_manager else None + ) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.llm_chain.apredict( + query=query, callbacks=run_manager.get_child() if run_manager else None + ) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/sql_database/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/sql_database/tool.py new file mode 100644 index 00000000000..3e0d6509b99 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/sql_database/tool.py @@ -0,0 +1,134 @@ +# flake8: noqa +"""Tools for interacting with a SQL database.""" +from typing import Any, Dict, Optional + +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.prompts import PromptTemplate +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_core.tools import BaseTool +from langchain_community.tools.sql_database.prompt import QUERY_CHECKER + + +class BaseSQLDatabaseTool(BaseModel): + """Base tool for interacting with a SQL database.""" + + db: SQLDatabase = Field(exclude=True) + + class Config(BaseTool.Config): + pass + + +class QuerySQLDataBaseTool(BaseSQLDatabaseTool, 
BaseTool): + """Tool for querying a SQL database.""" + + name: str = "sql_db_query" + description: str = """ + Input to this tool is a detailed and correct SQL query, output is a result from the database. + If the query is not correct, an error message will be returned. + If an error is returned, rewrite the query, check the query, and try again. + """ + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + return self.db.run_no_throw(query) + + +class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting metadata about a SQL database.""" + + name: str = "sql_db_schema" + description: str = """ + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. + + Example Input: "table1, table2, table3" + """ + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw( + [t.strip() for t in table_names.split(",")] + ) + + +class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting tables names.""" + + name: str = "sql_db_list_tables" + description: str = "Input is an empty string, output is a comma separated list of tables in the database." + + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for a specific table.""" + return ", ".join(self.db.get_usable_table_names()) + + +class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): + """Use an LLM to check if a query is correct. + Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: Any = Field(init=False) + name: str = "sql_db_query_checker" + description: str = """ + Use this tool to double check if your query is correct before executing it. + Always use this tool before executing a query with sql_db_query! 
+ """ + + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + from langchain.chains.llm import LLMChain + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, input_variables=["dialect", "query"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["dialect", "query"]: + raise ValueError( + "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.predict( + query=query, + dialect=self.db.dialect, + callbacks=run_manager.get_child() if run_manager else None, + ) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.llm_chain.apredict( + query=query, + dialect=self.db.dialect, + callbacks=run_manager.get_child() if run_manager else None, + ) diff --git a/.scripts/community_split/libs/community/langchain_community/tools/zapier/tool.py b/.scripts/community_split/libs/community/langchain_community/tools/zapier/tool.py new file mode 100644 index 00000000000..3d5f3955546 --- /dev/null +++ b/.scripts/community_split/libs/community/langchain_community/tools/zapier/tool.py @@ -0,0 +1,215 @@ +"""[DEPRECATED] + +## Zapier Natural Language Actions API +\ +Full docs here: https://nla.zapier.com/start/ + +**Zapier Natural Language Actions** gives you access to the 5k+ apps, 20k+ actions +on Zapier's platform through a natural language API interface. + +NLA supports apps like Gmail, Salesforce, Trello, Slack, Asana, HubSpot, Google Sheets, +Microsoft Teams, and thousands more apps: https://zapier.com/apps + +Zapier NLA handles ALL the underlying API auth and translation from +natural language --> underlying API call --> return simplified output for LLMs +The key idea is you, or your users, expose a set of actions via an oauth-like setup +window, which you can then query and execute via a REST API. + +NLA offers both API Key and OAuth for signing NLA API requests. + +1. Server-side (API Key): for quickly getting started, testing, and production scenarios + where LangChain will only use actions exposed in the developer's Zapier account + (and will use the developer's connected accounts on Zapier.com) + +2. User-facing (Oauth): for production scenarios where you are deploying an end-user + facing application and LangChain needs access to end-user's exposed actions and + connected accounts on Zapier.com + +This quick start will focus on the server-side use case for brevity. +Review [full docs](https://nla.zapier.com/start/) for user-facing oauth developer +support. + +Typically, you'd use SequentialChain, here's a basic example: + + 1. Use NLA to find an email in Gmail + 2. Use LLMChain to generate a draft reply to (1) + 3. Use NLA to send the draft reply (2) to someone in Slack via direct message + +In code, below: + +```python + +import os + +# get from https://platform.openai.com/ +os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "") + +# get from https://nla.zapier.com/docs/authentication/ +os.environ["ZAPIER_NLA_API_KEY"] = os.environ.get("ZAPIER_NLA_API_KEY", "") + +from langchain_community.agent_toolkits import ZapierToolkit +from langchain_community.utilities.zapier import ZapierNLAWrapper + +## step 0. 
expose gmail 'find email' and slack 'send channel message' actions + +# first go here, log in, expose (enable) the two actions: +# https://nla.zapier.com/demo/start +# -- for this example, can leave all fields "Have AI guess" +# in an oauth scenario, you'd get your own id (instead of 'demo') +# which you route your users through first + +zapier = ZapierNLAWrapper() +## To leverage OAuth you may pass the value `nla_oauth_access_token` to +## the ZapierNLAWrapper. If you do this there is no need to initialize +## the ZAPIER_NLA_API_KEY env variable +# zapier = ZapierNLAWrapper(zapier_nla_oauth_access_token="TOKEN_HERE") +toolkit = ZapierToolkit.from_zapier_nla_wrapper(zapier) +``` + +""" +from typing import Any, Dict, Optional + +from langchain_core._api import warn_deprecated +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.tools import BaseTool + +from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT +from langchain_community.utilities.zapier import ZapierNLAWrapper + + +class ZapierNLARunAction(BaseTool): + """ + Args: + action_id: a specific action ID (from list actions) of the action to execute + (the set api_key must be associated with the action owner) + instructions: a natural language instruction string for using the action + (eg. "get the latest email from Mike Knoop" for "Gmail: find email" action) + params: a dict, optional. Any params provided will *override* AI guesses + from `instructions` (see "understanding the AI guessing flow" here: + https://nla.zapier.com/docs/using-the-api#ai-guessing) + + """ + + api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) + action_id: str + params: Optional[dict] = None + base_prompt: str = BASE_ZAPIER_TOOL_PROMPT + zapier_description: str + params_schema: Dict[str, str] = Field(default_factory=dict) + name: str = "" + description: str = "" + + @root_validator + def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]: + zapier_description = values["zapier_description"] + params_schema = values["params_schema"] + if "instructions" in params_schema: + del params_schema["instructions"] + + # Ensure base prompt (if overridden) contains necessary input fields + necessary_fields = {"{zapier_description}", "{params}"} + if not all(field in values["base_prompt"] for field in necessary_fields): + raise ValueError( + "Your custom base Zapier prompt must contain input fields for " + "{zapier_description} and {params}." + ) + + values["name"] = zapier_description + values["description"] = values["base_prompt"].format( + zapier_description=zapier_description, + params=str(list(params_schema.keys())), + ) + return values + + def _run( + self, instructions: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + """Use the Zapier NLA tool to return a list of all exposed user actions.""" + warn_deprecated( + since="0.0.319", + message=( + "This tool will be deprecated on 2023-11-17. See " + "https://nla.zapier.com/sunset/ for details" + ), + ) + return self.api_wrapper.run_as_str(self.action_id, instructions, self.params) + + async def _arun( + self, + instructions: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + """Use the Zapier NLA tool to return a list of all exposed user actions.""" + warn_deprecated( + since="0.0.319", + message=( + "This tool will be deprecated on 2023-11-17. 
See " + "https://nla.zapier.com/sunset/ for details" + ), + ) + return await self.api_wrapper.arun_as_str( + self.action_id, + instructions, + self.params, + ) + + +ZapierNLARunAction.__doc__ = ( + ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore +) + + +# other useful actions + + +class ZapierNLAListActions(BaseTool): + """ + Args: + None + + """ + + name: str = "ZapierNLA_list_actions" + description: str = BASE_ZAPIER_TOOL_PROMPT + ( + "This tool returns a list of the user's exposed actions." + ) + api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) + + def _run( + self, + _: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the Zapier NLA tool to return a list of all exposed user actions.""" + warn_deprecated( + since="0.0.319", + message=( + "This tool will be deprecated on 2023-11-17. See " + "https://nla.zapier.com/sunset/ for details" + ), + ) + return self.api_wrapper.list_as_str() + + async def _arun( + self, + _: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + """Use the Zapier NLA tool to return a list of all exposed user actions.""" + warn_deprecated( + since="0.0.319", + message=( + "This tool will be deprecated on 2023-11-17. See " + "https://nla.zapier.com/sunset/ for details" + ), + ) + return await self.api_wrapper.alist_as_str() + + +ZapierNLAListActions.__doc__ = ( + ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore +) diff --git a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py similarity index 87% rename from libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py rename to .scripts/community_split/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py index 8605e8419ed..3dc6fcea19a 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -3,20 +3,12 @@ import asyncio import os from aiohttp import ClientSession +from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group +from langchain_core.tracers.context import tracing_v2_enabled, tracing_enabled from langchain_core.prompts import PromptTemplate -from langchain.agents import AgentType, initialize_agent, load_tools -from langchain.callbacks import tracing_enabled -from langchain.callbacks.manager import ( - atrace_as_chain_group, - trace_as_chain_group, - tracing_v2_enabled, -) -from langchain.chains import LLMChain -from langchain.chains.constitutional_ai.base import ConstitutionalChain -from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple -from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI +from langchain_community.chat_models import ChatOpenAI +from langchain_community.llms import OpenAI questions = [ ( @@ -40,6 +32,7 @@ questions = [ def test_tracing_sequential() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING"] = "true" for q in questions[:3]: @@ -52,6 +45,7 @@ def test_tracing_sequential() -> None: def test_tracing_session_env_var() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING"] = "true" os.environ["LANGCHAIN_SESSION"] = "my_session" @@ -66,6 +60,7 @@ def 
test_tracing_session_env_var() -> None: async def test_tracing_concurrent() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING"] = "true" aiosession = ClientSession() llm = OpenAI(temperature=0) @@ -79,6 +74,7 @@ async def test_tracing_concurrent() -> None: async def test_tracing_concurrent_bw_compat_environ() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_HANDLER"] = "langchain" if "LANGCHAIN_TRACING" in os.environ: del os.environ["LANGCHAIN_TRACING"] @@ -96,6 +92,7 @@ async def test_tracing_concurrent_bw_compat_environ() -> None: def test_tracing_context_manager() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = OpenAI(temperature=0) tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( @@ -111,6 +108,7 @@ def test_tracing_context_manager() -> None: async def test_tracing_context_manager_async() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = OpenAI(temperature=0) async_tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( @@ -130,6 +128,7 @@ async def test_tracing_context_manager_async() -> None: async def test_tracing_v2_environment_variable() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING_V2"] = "true" aiosession = ClientSession() @@ -144,6 +143,7 @@ async def test_tracing_v2_environment_variable() -> None: def test_tracing_v2_context_manager() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = ChatOpenAI(temperature=0) tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( @@ -158,6 +158,9 @@ def test_tracing_v2_context_manager() -> None: def test_tracing_v2_chain_with_tags() -> None: + from langchain.chains.llm import LLMChain + from langchain.chains.constitutional_ai.base import ConstitutionalChain + from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple llm = OpenAI(temperature=0) chain = ConstitutionalChain.from_llm( llm, @@ -177,6 +180,7 @@ def test_tracing_v2_chain_with_tags() -> None: def test_tracing_v2_agent_with_metadata() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING_V2"] = "true" llm = OpenAI(temperature=0) chat = ChatOpenAI(temperature=0) @@ -192,6 +196,7 @@ def test_tracing_v2_agent_with_metadata() -> None: async def test_tracing_v2_async_agent_with_metadata() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_TRACING_V2"] = "true" llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"}) chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"}) @@ -210,6 +215,7 @@ async def test_tracing_v2_async_agent_with_metadata() -> None: def test_trace_as_group() -> None: + from langchain.chains.llm import LLMChain llm = OpenAI(temperature=0.9) prompt = PromptTemplate( input_variables=["product"], @@ -228,6 +234,7 @@ def test_trace_as_group() -> None: def test_trace_as_group_with_env_set() -> None: + from langchain.chains.llm import LLMChain os.environ["LANGCHAIN_TRACING_V2"] = "true" llm = OpenAI(temperature=0.9) prompt = PromptTemplate( @@ -251,6 +258,7 @@ def test_trace_as_group_with_env_set() -> None: async def test_trace_as_group_async() -> None: + from langchain.chains.llm import LLMChain llm = OpenAI(temperature=0.9) prompt = PromptTemplate( input_variables=["product"], diff 
--git a/libs/langchain/tests/integration_tests/callbacks/test_openai_callback.py b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_openai_callback.py similarity index 91% rename from libs/langchain/tests/integration_tests/callbacks/test_openai_callback.py rename to .scripts/community_split/libs/community/tests/integration_tests/callbacks/test_openai_callback.py index e2bef107b45..f0908bbf8d0 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_openai_callback.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_openai_callback.py @@ -1,9 +1,9 @@ """Integration tests for the langchain tracer module.""" import asyncio -from langchain.agents import AgentType, initialize_agent, load_tools -from langchain.callbacks import get_openai_callback -from langchain.llms import OpenAI + +from langchain_community.callbacks import get_openai_callback +from langchain_community.llms import OpenAI async def test_openai_callback() -> None: @@ -51,6 +51,7 @@ def test_openai_callback_batch_llm() -> None: def test_openai_callback_agent() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = OpenAI(temperature=0) tools = load_tools(["serpapi", "llm-math"], llm=llm) agent = initialize_agent( diff --git a/libs/langchain/tests/integration_tests/callbacks/test_streamlit_callback.py b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py similarity index 73% rename from libs/langchain/tests/integration_tests/callbacks/test_streamlit_callback.py rename to .scripts/community_split/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py index 8008c5cc978..1ffe61dbdcf 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_streamlit_callback.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_streamlit_callback.py @@ -2,19 +2,18 @@ import pytest -from langchain.agents import AgentType, initialize_agent, load_tools - # Import the internal StreamlitCallbackHandler from its module - and not from -# the `langchain.callbacks.streamlit` package - so that we don't end up using +# the `langchain_community.callbacks.streamlit` package - so that we don't end up using # Streamlit's externally-provided callback handler. 
-from langchain.callbacks.streamlit.streamlit_callback_handler import ( +from langchain_community.callbacks.streamlit.streamlit_callback_handler import ( StreamlitCallbackHandler, ) -from langchain.llms import OpenAI +from langchain_community.llms import OpenAI @pytest.mark.requires("streamlit") def test_streamlit_callback_agent() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools import streamlit as st streamlit_callback = StreamlitCallbackHandler(st.container()) diff --git a/libs/langchain/tests/integration_tests/callbacks/test_wandb_tracer.py b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py similarity index 87% rename from libs/langchain/tests/integration_tests/callbacks/test_wandb_tracer.py rename to .scripts/community_split/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py index ec5a5791f2d..02f022c62ad 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_wandb_tracer.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/callbacks/test_wandb_tracer.py @@ -3,10 +3,9 @@ import asyncio import os from aiohttp import ClientSession +from langchain_community.callbacks import wandb_tracing_enabled -from langchain.agents import AgentType, initialize_agent, load_tools -from langchain.callbacks.manager import wandb_tracing_enabled -from langchain.llms import OpenAI +from langchain_community.llms import OpenAI questions = [ ( @@ -30,6 +29,7 @@ questions = [ def test_tracing_sequential() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_WANDB_TRACING"] = "true" os.environ["WANDB_PROJECT"] = "langchain-tracing" @@ -46,6 +46,7 @@ def test_tracing_sequential() -> None: def test_tracing_session_env_var() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_WANDB_TRACING"] = "true" llm = OpenAI(temperature=0) @@ -60,6 +61,7 @@ def test_tracing_session_env_var() -> None: async def test_tracing_concurrent() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools os.environ["LANGCHAIN_WANDB_TRACING"] = "true" aiosession = ClientSession() llm = OpenAI(temperature=0) @@ -77,6 +79,7 @@ async def test_tracing_concurrent() -> None: def test_tracing_context_manager() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = OpenAI(temperature=0) tools = load_tools( ["llm-math", "serpapi"], @@ -94,6 +97,7 @@ def test_tracing_context_manager() -> None: async def test_tracing_context_manager_async() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools llm = OpenAI(temperature=0) async_tools = load_tools( ["llm-math", "serpapi"], diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/.scripts/community_split/libs/community/tests/integration_tests/chat_models/test_openai.py similarity index 72% rename from libs/langchain/tests/integration_tests/chat_models/test_openai.py rename to .scripts/community_split/libs/community/tests/integration_tests/chat_models/test_openai.py index 8edf4f22c84..2ea93c6d9e2 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/chat_models/test_openai.py @@ -1,25 +1,18 @@ """Test ChatOpenAI wrapper.""" -from typing import Any, List, Optional, Union +from typing import Any, Optional import pytest -from langchain_core.messages import BaseMessage, HumanMessage, 
SystemMessage +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ( ChatGeneration, - ChatGenerationChunk, ChatResult, - GenerationChunk, LLMResult, ) -from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.prompts import ChatPromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field -from langchain.callbacks.base import AsyncCallbackHandler -from langchain.callbacks.manager import CallbackManager -from langchain.chains.openai_functions import ( - create_openai_fn_chain, -) -from langchain.chat_models.openai import ChatOpenAI -from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser +from langchain_community.chat_models.openai import ChatOpenAI from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -210,106 +203,6 @@ async def test_async_chat_openai_streaming() -> None: assert generation.text == generation.message.content -@pytest.mark.scheduled -async def test_async_chat_openai_streaming_with_function() -> None: - """Test ChatOpenAI wrapper with multiple completions.""" - - class MyCustomAsyncHandler(AsyncCallbackHandler): - def __init__(self) -> None: - super().__init__() - self._captured_tokens: List[str] = [] - self._captured_chunks: List[ - Optional[Union[ChatGenerationChunk, GenerationChunk]] - ] = [] - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None, - **kwargs: Any, - ) -> Any: - self._captured_tokens.append(token) - self._captured_chunks.append(chunk) - - json_schema = { - "title": "Person", - "description": "Identifying information about a person.", - "type": "object", - "properties": { - "name": { - "title": "Name", - "description": "The person's name", - "type": "string", - }, - "age": { - "title": "Age", - "description": "The person's age", - "type": "integer", - }, - "fav_food": { - "title": "Fav Food", - "description": "The person's favorite food", - "type": "string", - }, - }, - "required": ["name", "age"], - } - - callback_handler = MyCustomAsyncHandler() - callback_manager = CallbackManager([callback_handler]) - - chat = ChatOpenAI( - max_tokens=10, - n=1, - callback_manager=callback_manager, - streaming=True, - ) - - prompt_msgs = [ - SystemMessage( - content="You are a world class algorithm for " - "extracting information in structured formats." - ), - HumanMessage( - content="Use the given format to extract " - "information from the following input:" - ), - HumanMessagePromptTemplate.from_template("{input}"), - HumanMessage(content="Tips: Make sure to answer in the correct format"), - ] - prompt = ChatPromptTemplate(messages=prompt_msgs) - - function: Any = { - "name": "output_formatter", - "description": ( - "Output formatter. Should always be used to format your response to the" - " user." 
- ), - "parameters": json_schema, - } - chain = create_openai_fn_chain( - [function], - chat, - prompt, - output_parser=None, - ) - - message = HumanMessage(content="Sally is 13 years old") - response = await chain.agenerate([{"input": message}]) - - assert isinstance(response, LLMResult) - assert len(response.generations) == 1 - for generations in response.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - assert len(callback_handler._captured_tokens) > 0 - assert len(callback_handler._captured_chunks) > 0 - assert all([chunk is not None for chunk in callback_handler._captured_chunks]) - @pytest.mark.scheduled async def test_async_chat_openai_bind_functions() -> None: @@ -337,7 +230,7 @@ async def test_async_chat_openai_bind_functions() -> None: ] ) - chain = prompt | chat | JsonOutputFunctionsParser(args_only=True) + chain = prompt | chat message = HumanMessage(content="Sally is 13 years old") response = await chain.abatch([{"input": message}]) @@ -345,9 +238,7 @@ async def test_async_chat_openai_bind_functions() -> None: assert isinstance(response, list) assert len(response) == 1 for generation in response: - assert isinstance(generation, dict) - assert "name" in generation - assert "age" in generation + assert isinstance(generation, AIMessage) def test_chat_openai_extra_kwargs() -> None: diff --git a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/.scripts/community_split/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py similarity index 89% rename from libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py rename to .scripts/community_split/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py index fe22694f97d..88bfc66a382 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py @@ -2,6 +2,7 @@ from typing import Any +from langchain_core.callbacks import CallbackManager from langchain_core.messages import ( AIMessage, BaseMessage, @@ -11,11 +12,7 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.callbacks.manager import CallbackManager -from langchain.chains.openai_functions import ( - create_openai_fn_chain, -) -from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint +from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler _FUNCTIONS: Any = [ @@ -185,18 +182,12 @@ def test_functions_call_thoughts() -> None: ] prompt = ChatPromptTemplate(messages=prompt_msgs) - chain = create_openai_fn_chain( - _FUNCTIONS, - chat, - prompt, - output_parser=None, - ) + chain = prompt | chat.bind(functions=_FUNCTIONS) message = HumanMessage(content="What's the temperature in Shanghai today?") - response = chain.generate([{"input": message}]) - assert isinstance(response.generations[0][0], ChatGeneration) - assert isinstance(response.generations[0][0].message, AIMessage) - assert "function_call" in response.generations[0][0].message.additional_kwargs + response = chain.batch([{"input": message}]) + assert 
isinstance(response[0], AIMessage) + assert "function_call" in response[0].additional_kwargs def test_functions_call() -> None: @@ -223,11 +214,6 @@ def test_functions_call() -> None: ), ] ) - llm_chain = create_openai_fn_chain( - _FUNCTIONS, - chat, - prompt, - output_parser=None, - ) - resp = llm_chain.generate([{}]) - assert isinstance(resp, LLMResult) + chain = prompt | chat.bind(functions=_FUNCTIONS) + resp = chain.invoke({}) + assert isinstance(resp, AIMessage) diff --git a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_language.py b/.scripts/community_split/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py similarity index 90% rename from libs/langchain/tests/integration_tests/document_loaders/parsers/test_language.py rename to .scripts/community_split/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py index c0de58d683d..c28789c7cd3 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_language.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/document_loaders/parsers/test_language.py @@ -2,10 +2,9 @@ from pathlib import Path import pytest -from langchain.document_loaders.concurrent import ConcurrentLoader -from langchain.document_loaders.generic import GenericLoader -from langchain.document_loaders.parsers import LanguageParser -from langchain.text_splitter import Language +from langchain_community.document_loaders.concurrent import ConcurrentLoader +from langchain_community.document_loaders.generic import GenericLoader +from langchain_community.document_loaders.parsers import LanguageParser def test_language_loader_for_python() -> None: @@ -55,7 +54,7 @@ def test_language_loader_for_python_with_parser_threshold() -> None: loader = GenericLoader.from_filesystem( file_path, glob="hello_world.py", - parser=LanguageParser(language=Language.PYTHON, parser_threshold=1000), + parser=LanguageParser(language="python", parser_threshold=1000), ) docs = loader.load() @@ -127,7 +126,7 @@ def test_language_loader_for_javascript_with_parser_threshold() -> None: loader = GenericLoader.from_filesystem( file_path, glob="hello_world.js", - parser=LanguageParser(language=Language.JS, parser_threshold=1000), + parser=LanguageParser(language="js", parser_threshold=1000), ) docs = loader.load() @@ -140,7 +139,7 @@ def test_concurrent_language_loader_for_javascript_with_parser_threshold() -> No loader = ConcurrentLoader.from_filesystem( file_path, glob="hello_world.js", - parser=LanguageParser(language=Language.JS, parser_threshold=1000), + parser=LanguageParser(language="js", parser_threshold=1000), ) docs = loader.load() @@ -153,7 +152,7 @@ def test_concurrent_language_loader_for_python_with_parser_threshold() -> None: loader = ConcurrentLoader.from_filesystem( file_path, glob="hello_world.py", - parser=LanguageParser(language=Language.PYTHON, parser_threshold=1000), + parser=LanguageParser(language="python", parser_threshold=1000), ) docs = loader.load() diff --git a/.scripts/community_split/libs/community/tests/integration_tests/llms/test_fireworks.py b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_fireworks.py new file mode 100644 index 00000000000..d7839b50645 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_fireworks.py @@ -0,0 +1,136 @@ +"""Test Fireworks AI API Wrapper.""" +from typing import Generator + +import pytest +from langchain_core.outputs import LLMResult + +from 
langchain_community.llms.fireworks import Fireworks + +@pytest.fixture +def llm() -> Fireworks: + return Fireworks(model_kwargs={"temperature": 0, "max_tokens": 512}) + + +@pytest.mark.scheduled +def test_fireworks_call(llm: Fireworks) -> None: + """Test valid call to fireworks.""" + output = llm("How is the weather in New York today?") + assert isinstance(output, str) + + +@pytest.mark.scheduled +def test_fireworks_model_param() -> None: + """Tests model parameters for Fireworks""" + llm = Fireworks(model="foo") + assert llm.model == "foo" + + +@pytest.mark.scheduled +def test_fireworks_invoke(llm: Fireworks) -> None: + """Tests completion with invoke""" + output = llm.invoke("How is the weather in New York today?", stop=[","]) + assert isinstance(output, str) + assert output[-1] == "," + + +@pytest.mark.scheduled +async def test_fireworks_ainvoke(llm: Fireworks) -> None: + """Tests completion with invoke""" + output = await llm.ainvoke("How is the weather in New York today?", stop=[","]) + assert isinstance(output, str) + assert output[-1] == "," + + +@pytest.mark.scheduled +def test_fireworks_batch(llm: Fireworks) -> None: + """Tests completion with invoke""" + llm = Fireworks() + output = llm.batch( + [ + "How is the weather in New York today?", + "How is the weather in New York today?", + ], + stop=[","], + ) + for token in output: + assert isinstance(token, str) + assert token[-1] == "," + + +@pytest.mark.scheduled +async def test_fireworks_abatch(llm: Fireworks) -> None: + """Tests completion with invoke""" + output = await llm.abatch( + [ + "How is the weather in New York today?", + "How is the weather in New York today?", + ], + stop=[","], + ) + for token in output: + assert isinstance(token, str) + assert token[-1] == "," + + +@pytest.mark.scheduled +def test_fireworks_multiple_prompts( + llm: Fireworks, +) -> None: + """Test completion with multiple prompts.""" + output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) + assert len(output.generations) == 2 + + +@pytest.mark.scheduled +def test_fireworks_streaming(llm: Fireworks) -> None: + """Test stream completion.""" + generator = llm.stream("Who's the best quarterback in the NFL?") + assert isinstance(generator, Generator) + + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.scheduled +def test_fireworks_streaming_stop_words(llm: Fireworks) -> None: + """Test stream completion with stop words.""" + generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","]) + assert isinstance(generator, Generator) + + last_token = "" + for token in generator: + last_token = token + assert isinstance(token, str) + assert last_token[-1] == "," + + +@pytest.mark.scheduled +async def test_fireworks_streaming_async(llm: Fireworks) -> None: + """Test stream completion.""" + + last_token = "" + async for token in llm.astream( + "Who's the best quarterback in the NFL?", stop=[","] + ): + last_token = token + assert isinstance(token, str) + assert last_token[-1] == "," + + +@pytest.mark.scheduled +async def test_fireworks_async_agenerate(llm: Fireworks) -> None: + """Test async.""" + output = await llm.agenerate(["What is the best city to live in California?"]) + assert isinstance(output, LLMResult) + + +@pytest.mark.scheduled +async def test_fireworks_multiple_prompts_async_agenerate(llm: Fireworks) -> None: + output = await llm.agenerate( + ["How is the weather in New York today?", "I'm pickle 
rick"] + ) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) + assert len(output.generations) == 2 diff --git a/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_opaqueprompts.py similarity index 87% rename from libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py rename to .scripts/community_split/libs/community/tests/integration_tests/llms/test_opaqueprompts.py index e6418f426e7..69d851765db 100644 --- a/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_opaqueprompts.py @@ -1,12 +1,10 @@ +import langchain_community.utilities.opaqueprompts as op from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnableParallel -import langchain.utilities.opaqueprompts as op -from langchain.chains.llm import LLMChain -from langchain.llms import OpenAI -from langchain.llms.opaqueprompts import OpaquePrompts -from langchain.memory import ConversationBufferWindowMemory +from langchain_community.llms import OpenAI +from langchain_community.llms.opaqueprompts import OpaquePrompts prompt_template = """ As an AI assistant, you will answer questions according to given context. @@ -45,13 +43,8 @@ Question: ```{question}``` def test_opaqueprompts() -> None: - chain = LLMChain( - prompt=PromptTemplate.from_template(prompt_template), - llm=OpaquePrompts(llm=OpenAI()), - memory=ConversationBufferWindowMemory(k=2), - ) - - output = chain.run( + chain = PromptTemplate.from_template(prompt_template) | OpaquePrompts(llm=OpenAI()) + output = chain.invoke( { "question": "Write a text message to remind John to do password reset \ for his website through his email to stay secure." 
diff --git a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_symblai_nebula.py similarity index 85% rename from libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py rename to .scripts/community_split/libs/community/tests/integration_tests/llms/test_symblai_nebula.py index 6f623bd6e28..b1b2eb3b535 100644 --- a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_symblai_nebula.py @@ -1,8 +1,5 @@ """Test Nebula API wrapper.""" -from langchain_core.prompts.prompt import PromptTemplate - -from langchain.chains.llm import LLMChain -from langchain.llms.symblai_nebula import Nebula +from langchain_community.llms.symblai_nebula import Nebula def test_symblai_nebula_call() -> None: @@ -41,7 +38,5 @@ Rhea: Thanks, bye!""" instruction = """Identify the main objectives mentioned in this conversation.""" - prompt = PromptTemplate.from_template(template="{instruction}\n{conversation}") - llm_chain = LLMChain(prompt=prompt, llm=llm) - output = llm_chain.run(instruction=instruction, conversation=conversation) + output = llm.invoke(f"{instruction}\n{conversation}") assert isinstance(output, str) diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_vertexai.py similarity index 80% rename from libs/langchain/tests/integration_tests/llms/test_vertexai.py rename to .scripts/community_split/libs/community/tests/integration_tests/llms/test_vertexai.py index 1408d71ffa7..ae5f776b4ba 100644 --- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/llms/test_vertexai.py @@ -9,12 +9,9 @@ import os from typing import Optional import pytest -from langchain_core.documents import Document from langchain_core.outputs import LLMResult -from pytest_mock import MockerFixture -from langchain.chains.summarize import load_summarize_chain -from langchain.llms import VertexAI, VertexAIModelGarden +from langchain_community.llms import VertexAI, VertexAIModelGarden def test_vertex_initialization() -> None: @@ -152,31 +149,3 @@ def test_vertex_call_count_tokens() -> None: llm = VertexAI() output = llm.get_num_tokens("How are you?") assert output == 4 - - -@pytest.mark.requires("google.cloud.aiplatform") -def test_get_num_tokens_be_called_when_using_mapreduce_chain( - mocker: MockerFixture, -) -> None: - from vertexai.language_models._language_models import CountTokensResponse - - m1 = mocker.patch( - "vertexai.preview.language_models._PreviewTextGenerationModel.count_tokens", - return_value=CountTokensResponse( - total_tokens=2, - total_billable_characters=2, - _count_tokens_response={"total_tokens": 2, "total_billable_characters": 2}, - ), - ) - llm = VertexAI() - chain = load_summarize_chain( - llm, - chain_type="map_reduce", - return_intermediate_steps=False, - ) - doc = Document(page_content="Hi") - output = chain({"input_documents": [doc]}) - assert isinstance(output["output_text"], str) - m1.assert_called_once() - assert llm._llm_type == "vertexai" - assert llm.model_name == llm.client._model_id diff --git a/libs/langchain/tests/integration_tests/utilities/test_arxiv.py b/.scripts/community_split/libs/community/tests/integration_tests/utilities/test_arxiv.py similarity index 96% rename from libs/langchain/tests/integration_tests/utilities/test_arxiv.py 
rename to .scripts/community_split/libs/community/tests/integration_tests/utilities/test_arxiv.py index b1ca9de7330..536ba323bfb 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_arxiv.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/utilities/test_arxiv.py @@ -3,11 +3,10 @@ from typing import Any, List import pytest from langchain_core.documents import Document +from langchain_core.tools import BaseTool -from langchain.agents.load_tools import load_tools -from langchain.tools import ArxivQueryRun -from langchain.tools.base import BaseTool -from langchain.utilities import ArxivAPIWrapper +from langchain_community.tools import ArxivQueryRun +from langchain_community.utilities import ArxivAPIWrapper @pytest.fixture @@ -142,6 +141,7 @@ def test_load_returns_full_set_of_metadata() -> None: def _load_arxiv_from_universal_entry(**kwargs: Any) -> BaseTool: + from langchain.agents.load_tools import load_tools tools = load_tools(["arxiv"], **kwargs) assert len(tools) == 1, "loaded more than 1 tool" return tools[0] diff --git a/libs/langchain/tests/integration_tests/utilities/test_pubmed.py b/.scripts/community_split/libs/community/tests/integration_tests/utilities/test_pubmed.py similarity index 96% rename from libs/langchain/tests/integration_tests/utilities/test_pubmed.py rename to .scripts/community_split/libs/community/tests/integration_tests/utilities/test_pubmed.py index 52c8d8cdf88..bed2681e681 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_pubmed.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/utilities/test_pubmed.py @@ -3,11 +3,10 @@ from typing import Any, List import pytest from langchain_core.documents import Document +from langchain_core.tools import BaseTool -from langchain.agents.load_tools import load_tools -from langchain.tools import PubmedQueryRun -from langchain.tools.base import BaseTool -from langchain.utilities import PubMedAPIWrapper +from langchain_community.tools import PubmedQueryRun +from langchain_community.utilities import PubMedAPIWrapper xmltodict = pytest.importorskip("xmltodict") @@ -135,6 +134,7 @@ def test_load_returns_full_set_of_metadata() -> None: def _load_pubmed_from_universal_entry(**kwargs: Any) -> BaseTool: + from langchain.agents.load_tools import load_tools tools = load_tools(["pubmed"], **kwargs) assert len(tools) == 1, "loaded more than 1 tool" return tools[0] diff --git a/libs/langchain/tests/integration_tests/vectorstores/conftest.py b/.scripts/community_split/libs/community/tests/integration_tests/vectorstores/conftest.py similarity index 54% rename from libs/langchain/tests/integration_tests/vectorstores/conftest.py rename to .scripts/community_split/libs/community/tests/integration_tests/vectorstores/conftest.py index 5899e4ec146..a2fc9053128 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/conftest.py +++ b/.scripts/community_split/libs/community/tests/integration_tests/vectorstores/conftest.py @@ -1,14 +1,9 @@ import os -from typing import Generator, List, Union +from typing import Union import pytest -from langchain_core.documents import Document from vcr.request import Request -from langchain.document_loaders import TextLoader -from langchain.embeddings import OpenAIEmbeddings -from langchain.text_splitter import CharacterTextSplitter - # Those environment variables turn on Deep Lake pytest mode. # It significantly makes tests run much faster. 
# Need to run before `import deeplake` @@ -47,35 +42,3 @@ def vcr_config() -> dict: ], "ignore_localhost": True, } - - -# Define a fixture that yields a generator object returning a list of documents -@pytest.fixture(scope="function") -def documents() -> Generator[List[Document], None, None]: - """Return a generator that yields a list of documents.""" - - # Create a CharacterTextSplitter object for splitting the documents into chunks - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - - # Load the documents from a file located in the fixtures directory - documents = TextLoader( - os.path.join(os.path.dirname(__file__), "fixtures", "sharks.txt") - ).load() - - # Yield the documents split into chunks - yield text_splitter.split_documents(documents) - - -@pytest.fixture(scope="function") -def texts() -> Generator[List[str], None, None]: - # Load the documents from a file located in the fixtures directory - documents = TextLoader( - os.path.join(os.path.dirname(__file__), "fixtures", "sharks.txt") - ).load() - - yield [doc.page_content for doc in documents] - - -@pytest.fixture(scope="module") -def embedding_openai() -> OpenAIEmbeddings: - return OpenAIEmbeddings() diff --git a/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_callback_manager.py b/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_callback_manager.py new file mode 100644 index 00000000000..f63c0ab8cc7 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_callback_manager.py @@ -0,0 +1,85 @@ +"""Test CallbackManager.""" +from unittest.mock import patch + +import pytest +from langchain_community.callbacks import get_openai_callback +from langchain_core.callbacks.manager import trace_as_chain_group, CallbackManager +from langchain_core.outputs import LLMResult +from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers +from langchain_community.llms.openai import BaseOpenAI + + +def test_callback_manager_configure_context_vars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test callback manager configuration.""" + monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true") + monkeypatch.setenv("LANGCHAIN_TRACING", "false") + with patch.object(LangChainTracer, "_update_run_single"): + with patch.object(LangChainTracer, "_persist_run_single"): + with trace_as_chain_group("test") as group_manager: + assert len(group_manager.handlers) == 1 + tracer = group_manager.handlers[0] + assert isinstance(tracer, LangChainTracer) + + with get_openai_callback() as cb: + # This is a new empty callback handler + assert cb.successful_requests == 0 + assert cb.total_tokens == 0 + + # configure adds this openai cb but doesn't modify the group manager + mngr = CallbackManager.configure(group_manager) + assert mngr.handlers == [tracer, cb] + assert group_manager.handlers == [tracer] + + response = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 2, + "completion_tokens": 1, + "total_tokens": 3, + }, + "model_name": BaseOpenAI.__fields__["model_name"].default, + }, + ) + mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) + + # The callback handler has been updated + assert cb.successful_requests == 1 + assert cb.total_tokens == 3 + assert cb.prompt_tokens == 2 + assert cb.completion_tokens == 1 + assert cb.total_cost > 0 + + with get_openai_callback() as cb: + # This is a new empty callback handler + assert cb.successful_requests == 0 + assert cb.total_tokens == 0 + + # configure adds 
this openai cb but doesn't modify the group manager + mngr = CallbackManager.configure(group_manager) + assert mngr.handlers == [tracer, cb] + assert group_manager.handlers == [tracer] + + response = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 2, + "completion_tokens": 1, + "total_tokens": 3, + }, + "model_name": BaseOpenAI.__fields__["model_name"].default, + }, + ) + mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) + + # The callback handler has been updated + assert cb.successful_requests == 1 + assert cb.total_tokens == 3 + assert cb.prompt_tokens == 2 + assert cb.completion_tokens == 1 + assert cb.total_cost > 0 + wait_for_all_tracers() + assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore diff --git a/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_imports.py b/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_imports.py new file mode 100644 index 00000000000..cfe420b3e82 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/callbacks/test_imports.py @@ -0,0 +1,31 @@ +from langchain_community.callbacks import __all__ + +EXPECTED_ALL = [ + "AimCallbackHandler", + "ArgillaCallbackHandler", + "ArizeCallbackHandler", + "PromptLayerCallbackHandler", + "ArthurCallbackHandler", + "ClearMLCallbackHandler", + "CometCallbackHandler", + "ContextCallbackHandler", + "HumanApprovalCallbackHandler", + "InfinoCallbackHandler", + "MlflowCallbackHandler", + "LLMonitorCallbackHandler", + "OpenAICallbackHandler", + "LLMThoughtLabeler", + "StreamlitCallbackHandler", + "WandbCallbackHandler", + "WhyLabsCallbackHandler", + "get_openai_callback", + "wandb_tracing_enabled", + "FlyteCallbackHandler", + "SageMakerCallbackHandler", + "LabelStudioCallbackHandler", + "TrubricsCallbackHandler", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/chat_loaders/test_slack.py b/.scripts/community_split/libs/community/tests/unit_tests/chat_loaders/test_slack.py similarity index 88% rename from libs/langchain/tests/unit_tests/chat_loaders/test_slack.py rename to .scripts/community_split/libs/community/tests/unit_tests/chat_loaders/test_slack.py index cdf569d6090..afa4843ace4 100644 --- a/libs/langchain/tests/unit_tests/chat_loaders/test_slack.py +++ b/.scripts/community_split/libs/community/tests/unit_tests/chat_loaders/test_slack.py @@ -1,12 +1,11 @@ import pathlib -from langchain.chat_loaders import slack, utils +from langchain_community.chat_loaders import slack, utils def test_slack_chat_loader() -> None: chat_path = ( pathlib.Path(__file__).parents[2] - / "integration_tests" / "examples" / "slack_export.zip" ) diff --git a/.scripts/community_split/libs/community/tests/unit_tests/chat_models/test_bedrock.py b/.scripts/community_split/libs/community/tests/unit_tests/chat_models/test_bedrock.py new file mode 100644 index 00000000000..6767bad3792 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/chat_models/test_bedrock.py @@ -0,0 +1,54 @@ +"""Test Anthropic Chat API wrapper.""" +from typing import List +from unittest.mock import MagicMock + +import pytest + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) + +from langchain_community.chat_models import BedrockChat +from langchain_community.chat_models.meta import convert_messages_to_prompt_llama + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + 
([HumanMessage(content="Hello")], "[INST] Hello [/INST]"), + ( + [HumanMessage(content="Hello"), AIMessage(content="Answer:")], + "[INST] Hello [/INST]\nAnswer:", + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Answer:"), + ], + "<> You're an assistant <>\n[INST] Hello [/INST]\nAnswer:", + ), + ], +) +def test_formatting(messages: List[BaseMessage], expected: str) -> None: + result = convert_messages_to_prompt_llama(messages) + assert result == expected + + +def test_anthropic_bedrock() -> None: + client = MagicMock() + respbody = MagicMock( + read=MagicMock( + return_value=MagicMock( + decode=MagicMock(return_value=b'{"completion":"Hi back"}') + ) + ) + ) + client.invoke_model.return_value = {"body": respbody} + model = BedrockChat(model_id="anthropic.claude-v2", client=client) + + # should not throw an error + model.invoke("hello there") diff --git a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py b/.scripts/community_split/libs/community/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py similarity index 86% rename from libs/langchain/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py rename to .scripts/community_split/libs/community/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py index a7b90ffccd2..84802d8bc12 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py +++ b/.scripts/community_split/libs/community/tests/unit_tests/document_loaders/parsers/test_pdf_parsers.py @@ -1,17 +1,25 @@ """Tests for the various PDF parsers.""" +from pathlib import Path from typing import Iterator import pytest -from langchain.document_loaders.base import BaseBlobParser -from langchain.document_loaders.blob_loaders import Blob -from langchain.document_loaders.parsers.pdf import ( +from langchain_community.document_loaders.base import BaseBlobParser +from langchain_community.document_loaders.blob_loaders import Blob +from langchain_community.document_loaders.parsers.pdf import ( PDFMinerParser, PyMuPDFParser, PyPDFium2Parser, PyPDFParser, ) -from tests.data import HELLO_PDF, LAYOUT_PARSER_PAPER_PDF + +_THIS_DIR = Path(__file__).parents[3] + +_EXAMPLES_DIR = _THIS_DIR / "examples" + +# Paths to test PDF files +HELLO_PDF = _EXAMPLES_DIR / "hello.pdf" +LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf" def _assert_with_parser(parser: BaseBlobParser, splits_by_page: bool = True) -> None: diff --git a/.scripts/community_split/libs/community/tests/unit_tests/embeddings/test_imports.py b/.scripts/community_split/libs/community/tests/unit_tests/embeddings/test_imports.py new file mode 100644 index 00000000000..d33d98e493b --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -0,0 +1,60 @@ +from langchain_community.embeddings import __all__ + +EXPECTED_ALL = [ + "OpenAIEmbeddings", + "AzureOpenAIEmbeddings", + "ClarifaiEmbeddings", + "CohereEmbeddings", + "DatabricksEmbeddings", + "ElasticsearchEmbeddings", + "FastEmbedEmbeddings", + "HuggingFaceEmbeddings", + "HuggingFaceInferenceAPIEmbeddings", + "InfinityEmbeddings", + "GradientEmbeddings", + "JinaEmbeddings", + "LlamaCppEmbeddings", + "HuggingFaceHubEmbeddings", + "MlflowAIGatewayEmbeddings", + "MlflowEmbeddings", + "ModelScopeEmbeddings", + "TensorflowHubEmbeddings", + "SagemakerEndpointEmbeddings", + "HuggingFaceInstructEmbeddings", + "MosaicMLInstructorEmbeddings", + "SelfHostedEmbeddings", + 
"SelfHostedHuggingFaceEmbeddings", + "SelfHostedHuggingFaceInstructEmbeddings", + "FakeEmbeddings", + "DeterministicFakeEmbedding", + "AlephAlphaAsymmetricSemanticEmbedding", + "AlephAlphaSymmetricSemanticEmbedding", + "SentenceTransformerEmbeddings", + "GooglePalmEmbeddings", + "MiniMaxEmbeddings", + "VertexAIEmbeddings", + "BedrockEmbeddings", + "DeepInfraEmbeddings", + "EdenAiEmbeddings", + "DashScopeEmbeddings", + "EmbaasEmbeddings", + "OctoAIEmbeddings", + "SpacyEmbeddings", + "NLPCloudEmbeddings", + "GPT4AllEmbeddings", + "XinferenceEmbeddings", + "LocalAIEmbeddings", + "AwaEmbeddings", + "HuggingFaceBgeEmbeddings", + "ErnieEmbeddings", + "JavelinAIGatewayEmbeddings", + "OllamaEmbeddings", + "QianfanEmbeddingsEndpoint", + "JohnSnowLabsEmbeddings", + "VoyageEmbeddings", + "BookendEmbeddings", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/.scripts/community_split/libs/community/tests/unit_tests/llms/test_openai.py b/.scripts/community_split/libs/community/tests/unit_tests/llms/test_openai.py new file mode 100644 index 00000000000..a14cc9651cb --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/llms/test_openai.py @@ -0,0 +1,56 @@ +import os + +import pytest + +from langchain_community.llms.openai import OpenAI +from langchain_community.utils.openai import is_openai_v1 + +os.environ["OPENAI_API_KEY"] = "foo" + + +def _openai_v1_installed() -> bool: + try: + return is_openai_v1() + except Exception as _: + return False + + +@pytest.mark.requires("openai") +def test_openai_model_param() -> None: + llm = OpenAI(model="foo") + assert llm.model_name == "foo" + llm = OpenAI(model_name="foo") + assert llm.model_name == "foo" + + +@pytest.mark.requires("openai") +def test_openai_model_kwargs() -> None: + llm = OpenAI(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("openai") +def test_openai_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + OpenAI(model_kwargs={"model_name": "foo"}) + + +@pytest.mark.requires("openai") +def test_openai_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = OpenAI(foo="bar") + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.fixture +def mock_completion() -> dict: + return { + "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ", + "object": "text_completion", + "created": 1689989000, + "model": "text-davinci-003", + "choices": [ + {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } diff --git a/libs/langchain/tests/unit_tests/retrievers/test_imports.py b/.scripts/community_split/libs/community/tests/unit_tests/retrievers/test_imports.py similarity index 76% rename from libs/langchain/tests/unit_tests/retrievers/test_imports.py rename to .scripts/community_split/libs/community/tests/unit_tests/retrievers/test_imports.py index a26d7d48918..04ebf72d5ea 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_imports.py +++ b/.scripts/community_split/libs/community/tests/unit_tests/retrievers/test_imports.py @@ -1,4 +1,4 @@ -from langchain.retrievers import __all__ +from langchain_community.retrievers import __all__ EXPECTED_ALL = [ "AmazonKendraRetriever", @@ -7,7 +7,6 @@ EXPECTED_ALL = [ "ArxivRetriever", "AzureCognitiveSearchRetriever", "ChatGPTPluginRetriever", - "ContextualCompressionRetriever", "ChaindeskRetriever", "CohereRagRetriever", "ElasticSearchBM25Retriever", @@ -20,31 
+19,22 @@ EXPECTED_ALL = [ "KNNRetriever", "LlamaIndexGraphRetriever", "LlamaIndexRetriever", - "MergerRetriever", "MetalRetriever", "MilvusRetriever", - "MultiQueryRetriever", "OutlineRetriever", "PineconeHybridSearchRetriever", "PubMedRetriever", "RemoteLangChainRetriever", "SVMRetriever", - "SelfQueryRetriever", "TavilySearchAPIRetriever", "TFIDFRetriever", "BM25Retriever", - "TimeWeightedVectorStoreRetriever", "VespaRetriever", "WeaviateHybridSearchRetriever", "WikipediaRetriever", "ZepRetriever", "ZillizRetriever", "DocArrayRetriever", - "RePhraseQueryRetriever", - "WebResearchRetriever", - "EnsembleRetriever", - "ParentDocumentRetriever", - "MultiVectorRetriever", ] diff --git a/.scripts/community_split/libs/community/tests/unit_tests/storage/test_imports.py b/.scripts/community_split/libs/community/tests/unit_tests/storage/test_imports.py new file mode 100644 index 00000000000..68c47b76d69 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/storage/test_imports.py @@ -0,0 +1,11 @@ +from langchain_community.storage import __all__ + +EXPECTED_ALL = [ + "RedisStore", + "UpstashRedisByteStore", + "UpstashRedisStore", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/tools/test_exported.py b/.scripts/community_split/libs/community/tests/unit_tests/tools/test_exported.py similarity index 80% rename from libs/langchain/tests/unit_tests/tools/test_exported.py rename to .scripts/community_split/libs/community/tests/unit_tests/tools/test_exported.py index 78eb7e50461..9ca12287cb9 100644 --- a/libs/langchain/tests/unit_tests/tools/test_exported.py +++ b/.scripts/community_split/libs/community/tests/unit_tests/tools/test_exported.py @@ -1,9 +1,10 @@ from typing import List, Type -import langchain.tools -from langchain.tools import _DEPRECATED_TOOLS -from langchain.tools import __all__ as tools_all -from langchain.tools.base import BaseTool, StructuredTool +from langchain_core.tools import BaseTool, StructuredTool + +import langchain_community.tools +from langchain_community.tools import _DEPRECATED_TOOLS +from langchain_community.tools import __all__ as tools_all _EXCLUDE = { BaseTool, @@ -17,7 +18,7 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT if tool_class_name in _DEPRECATED_TOOLS: continue # Resolve the str to the class - tool_class = getattr(langchain.tools, tool_class_name) + tool_class = getattr(langchain_community.tools, tool_class_name) if isinstance(tool_class, type) and issubclass(tool_class, BaseTool): if tool_class in _EXCLUDE: continue diff --git a/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_faiss.py b/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_faiss.py new file mode 100644 index 00000000000..c8ef467efe5 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_faiss.py @@ -0,0 +1,728 @@ +"""Test FAISS functionality.""" +import datetime +import math +import tempfile + +import pytest + +from typing import Union + +from langchain_core.documents import Document + +from langchain_community.docstore.base import Docstore +from langchain_community.docstore.in_memory import InMemoryDocstore +from langchain_community.vectorstores.faiss import FAISS +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings + + +_PAGE_CONTENT = """This is a page about LangChain. + +It is a really cool framework. 
+ +What isn't there to love about langchain? + +Made in 2022.""" + + +class FakeDocstore(Docstore): + """Fake docstore for testing purposes.""" + + def search(self, search: str) -> Union[str, Document]: + """Return the fake document.""" + document = Document(page_content=_PAGE_CONTENT) + return document + + + +@pytest.mark.requires("faiss") +def test_faiss() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +async def test_faiss_afrom_texts() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +def test_faiss_vector_sim() -> None: + """Test vector similarity.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.similarity_search_by_vector(query_vec, k=1) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +async def test_faiss_async_vector_sim() -> None: + """Test vector similarity.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.asimilarity_search_by_vector(query_vec, k=1) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +def test_faiss_vector_sim_with_score_threshold() -> None: + """Test vector similarity.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = FakeEmbeddings().embed_query(text="foo") + output = 
docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +async def test_faiss_vector_async_sim_with_score_threshold() -> None: + """Test vector similarity.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.asimilarity_search_by_vector( + query_vec, k=2, score_threshold=0.2 + ) + assert output == [Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +def test_similarity_search_with_score_by_vector() -> None: + """Test vector similarity with score by vector.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo") + + +@pytest.mark.requires("faiss") +async def test_similarity_async_search_with_score_by_vector() -> None: + """Test vector similarity with score by vector.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.asimilarity_search_with_score_by_vector(query_vec, k=1) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo") + + +@pytest.mark.requires("faiss") +def test_similarity_search_with_score_by_vector_with_score_threshold() -> None: + """Test vector similarity with score by vector.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.similarity_search_with_score_by_vector( + query_vec, + k=2, + score_threshold=0.2, + ) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo") + assert output[0][1] < 0.2 + + +@pytest.mark.requires("faiss") +async def test_sim_asearch_with_score_by_vector_with_score_threshold() -> None: + """Test vector similarity with score by vector.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + index_to_id = 
docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.asimilarity_search_with_score_by_vector( + query_vec, + k=2, + score_threshold=0.2, + ) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo") + assert output[0][1] < 0.2 + + +@pytest.mark.requires("faiss") +def test_faiss_mmr() -> None: + texts = ["foo", "foo", "fou", "foy"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + query_vec = FakeEmbeddings().embed_query(text="foo") + # make sure we can have k > docstore size + output = docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1 + ) + assert len(output) == len(texts) + assert output[0][0] == Document(page_content="foo") + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo") + + +@pytest.mark.requires("faiss") +async def test_faiss_async_mmr() -> None: + texts = ["foo", "foo", "fou", "foy"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + query_vec = await FakeEmbeddings().aembed_query(text="foo") + # make sure we can have k > docstore size + output = await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1 + ) + assert len(output) == len(texts) + assert output[0][0] == Document(page_content="foo") + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo") + + +@pytest.mark.requires("faiss") +def test_faiss_mmr_with_metadatas() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1 + ) + assert len(output) == len(texts) + assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + + +@pytest.mark.requires("faiss") +async def test_faiss_async_mmr_with_metadatas() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1 + ) + assert len(output) == len(texts) + assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + + +@pytest.mark.requires("faiss") +def test_faiss_mmr_with_metadatas_and_filter() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter={"page": 1} + ) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo", 
metadata={"page": 1}) + assert output[0][1] == 0.0 + + +@pytest.mark.requires("faiss") +async def test_faiss_async_mmr_with_metadatas_and_filter() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter={"page": 1} + ) + assert len(output) == 1 + assert output[0][0] == Document(page_content="foo", metadata={"page": 1}) + assert output[0][1] == 0.0 + + +@pytest.mark.requires("faiss") +def test_faiss_mmr_with_metadatas_and_list_filter() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter={"page": [0, 1, 2]} + ) + assert len(output) == 3 + assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + + +@pytest.mark.requires("faiss") +async def test_faiss_async_mmr_with_metadatas_and_list_filter() -> None: + texts = ["foo", "foo", "fou", "foy"] + metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + query_vec = await FakeEmbeddings().aembed_query(text="foo") + output = await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter={"page": [0, 1, 2]} + ) + assert len(output) == 3 + assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) + assert output[0][1] == 0.0 + assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + + +@pytest.mark.requires("faiss") +def test_faiss_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + +@pytest.mark.requires("faiss") +async def test_faiss_async_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + 
} + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + +@pytest.mark.requires("faiss") +def test_faiss_with_metadatas_and_filter() -> None: + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foo", k=1, filter={"page": 1}) + assert output == [Document(page_content="bar", metadata={"page": 1})] + + +@pytest.mark.requires("faiss") +async def test_faiss_async_with_metadatas_and_filter() -> None: + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = await docsearch.asimilarity_search("foo", k=1, filter={"page": 1}) + assert output == [Document(page_content="bar", metadata={"page": 1})] + + +@pytest.mark.requires("faiss") +def test_faiss_with_metadatas_and_list_filter() -> None: + texts = ["foo", "bar", "baz", "foo", "qux"] + metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + docsearch.index_to_docstore_id[3]: Document( + page_content="foo", metadata={"page": 3} + ), + docsearch.index_to_docstore_id[4]: Document( + page_content="qux", metadata={"page": 3} + ), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]}) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + +@pytest.mark.requires("faiss") +async def test_faiss_async_with_metadatas_and_list_filter() -> None: + texts = ["foo", "bar", "baz", "foo", "qux"] + metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), + 
docsearch.index_to_docstore_id[3]: Document( + page_content="foo", metadata={"page": 3} + ), + docsearch.index_to_docstore_id[4]: Document( + page_content="qux", metadata={"page": 3} + ), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = await docsearch.asimilarity_search("foor", k=1, filter={"page": [0, 1, 2]}) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + +@pytest.mark.requires("faiss") +def test_faiss_search_not_found() -> None: + """Test what happens when document is not found.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + # Get rid of the docstore to purposefully induce errors. + docsearch.docstore = InMemoryDocstore({}) + with pytest.raises(ValueError): + docsearch.similarity_search("foo") + + +@pytest.mark.requires("faiss") +async def test_faiss_async_search_not_found() -> None: + """Test what happens when document is not found.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + # Get rid of the docstore to purposefully induce errors. + docsearch.docstore = InMemoryDocstore({}) + with pytest.raises(ValueError): + await docsearch.asimilarity_search("foo") + + +@pytest.mark.requires("faiss") +def test_faiss_add_texts() -> None: + """Test end to end adding of texts.""" + # Create initial doc store. + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + # Test adding a similar document as before. + docsearch.add_texts(["foo"]) + output = docsearch.similarity_search("foo", k=2) + assert output == [Document(page_content="foo"), Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +async def test_faiss_async_add_texts() -> None: + """Test end to end adding of texts.""" + # Create initial doc store. + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + # Test adding a similar document as before. 
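+    # Note (test helper behavior): FakeEmbeddings gives the first text of every
+    # batch the same vector it gives queries, so the added "foo" should tie with
+    # the original and both copies come back from a k=2 search.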
+ await docsearch.aadd_texts(["foo"]) + output = await docsearch.asimilarity_search("foo", k=2) + assert output == [Document(page_content="foo"), Document(page_content="foo")] + + +@pytest.mark.requires("faiss") +def test_faiss_add_texts_not_supported() -> None: + """Test adding of texts to a docstore that doesn't support it.""" + docsearch = FAISS(FakeEmbeddings(), None, FakeDocstore(), {}) + with pytest.raises(ValueError): + docsearch.add_texts(["foo"]) + + +@pytest.mark.requires("faiss") +async def test_faiss_async_add_texts_not_supported() -> None: + """Test adding of texts to a docstore that doesn't support it.""" + docsearch = FAISS(FakeEmbeddings(), None, FakeDocstore(), {}) + with pytest.raises(ValueError): + await docsearch.aadd_texts(["foo"]) + + +@pytest.mark.requires("faiss") +def test_faiss_local_save_load() -> None: + """Test end to end serialization.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") + with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: + docsearch.save_local(temp_folder) + new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings()) + assert new_docsearch.index is not None + + +@pytest.mark.requires("faiss") +async def test_faiss_async_local_save_load() -> None: + """Test end to end serialization.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings()) + temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") + with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: + docsearch.save_local(temp_folder) + new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings()) + assert new_docsearch.index is not None + + +@pytest.mark.requires("faiss") +def test_faiss_similarity_search_with_relevance_scores() -> None: + """Test the similarity search with normalized similarities.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts( + texts, + FakeEmbeddings(), + relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2), + ) + outputs = docsearch.similarity_search_with_relevance_scores("foo", k=1) + output, score = outputs[0] + assert output == Document(page_content="foo") + assert score == 1.0 + + +@pytest.mark.requires("faiss") +async def test_faiss_async_similarity_search_with_relevance_scores() -> None: + """Test the similarity search with normalized similarities.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts( + texts, + FakeEmbeddings(), + relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2), + ) + outputs = await docsearch.asimilarity_search_with_relevance_scores("foo", k=1) + output, score = outputs[0] + assert output == Document(page_content="foo") + assert score == 1.0 + + +@pytest.mark.requires("faiss") +def test_faiss_similarity_search_with_relevance_scores_with_threshold() -> None: + """Test the similarity search with normalized similarities with score threshold.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts( + texts, + FakeEmbeddings(), + relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2), + ) + outputs = docsearch.similarity_search_with_relevance_scores( + "foo", k=2, score_threshold=0.5 + ) + assert len(outputs) == 1 + output, score = outputs[0] + assert output == Document(page_content="foo") + assert score == 1.0 + + +@pytest.mark.requires("faiss") +async def test_faiss_asimilarity_search_with_relevance_scores_with_threshold() -> None: + """Test the 
similarity search with normalized similarities with score threshold.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts( + texts, + FakeEmbeddings(), + relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2), + ) + outputs = await docsearch.asimilarity_search_with_relevance_scores( + "foo", k=2, score_threshold=0.5 + ) + assert len(outputs) == 1 + output, score = outputs[0] + assert output == Document(page_content="foo") + assert score == 1.0 + + +@pytest.mark.requires("faiss") +def test_faiss_invalid_normalize_fn() -> None: + """Test the similarity search with normalized similarities.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts( + texts, FakeEmbeddings(), relevance_score_fn=lambda _: 2.0 + ) + with pytest.warns(Warning, match="scores must be between"): + docsearch.similarity_search_with_relevance_scores("foo", k=1) + + +@pytest.mark.requires("faiss") +async def test_faiss_async_invalid_normalize_fn() -> None: + """Test the similarity search with normalized similarities.""" + texts = ["foo", "bar", "baz"] + docsearch = await FAISS.afrom_texts( + texts, FakeEmbeddings(), relevance_score_fn=lambda _: 2.0 + ) + with pytest.warns(Warning, match="scores must be between"): + await docsearch.asimilarity_search_with_relevance_scores("foo", k=1) + + +@pytest.mark.requires("faiss") +def test_missing_normalize_score_fn() -> None: + """Test doesn't perform similarity search without a valid distance strategy.""" + texts = ["foo", "bar", "baz"] + faiss_instance = FAISS.from_texts(texts, FakeEmbeddings(), distance_strategy="fake") + with pytest.raises(ValueError): + faiss_instance.similarity_search_with_relevance_scores("foo", k=2) + + +@pytest.mark.requires("faiss") +async def test_async_missing_normalize_score_fn() -> None: + """Test doesn't perform similarity search without a valid distance strategy.""" + texts = ["foo", "bar", "baz"] + faiss_instance = await FAISS.afrom_texts( + texts, FakeEmbeddings(), distance_strategy="fake" + ) + with pytest.raises(ValueError): + await faiss_instance.asimilarity_search_with_relevance_scores("foo", k=2) + + +@pytest.mark.requires("faiss") +def test_delete() -> None: + """Test the similarity search with normalized similarities.""" + ids = ["a", "b", "c"] + docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings(), ids=ids) + docsearch.delete(ids[1:2]) + + result = docsearch.similarity_search("bar", k=2) + assert sorted([d.page_content for d in result]) == ["baz", "foo"] + assert docsearch.index_to_docstore_id == {0: ids[0], 1: ids[2]} + + +@pytest.mark.requires("faiss") +async def test_async_delete() -> None: + """Test the similarity search with normalized similarities.""" + ids = ["a", "b", "c"] + docsearch = await FAISS.afrom_texts( + ["foo", "bar", "baz"], FakeEmbeddings(), ids=ids + ) + docsearch.delete(ids[1:2]) + + result = await docsearch.asimilarity_search("bar", k=2) + assert sorted([d.page_content for d in result]) == ["baz", "foo"] + assert docsearch.index_to_docstore_id == {0: ids[0], 1: ids[2]} diff --git a/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_imports.py b/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_imports.py new file mode 100644 index 00000000000..95cbcc11d96 --- /dev/null +++ b/.scripts/community_split/libs/community/tests/unit_tests/vectorstores/test_imports.py @@ -0,0 +1,13 @@ +from langchain_community import vectorstores +from langchain_core.vectorstores import VectorStore + + +def test_all_imports() -> None: + """Simple 
test to make sure all things can be imported.""" + for cls in vectorstores.__all__: + if cls not in [ + "AlibabaCloudOpenSearchSettings", + "ClickhouseSettings", + "MyScaleSettings", + ]: + assert issubclass(getattr(vectorstores, cls), VectorStore) diff --git a/.scripts/community_split/libs/core/langchain_core/load/load.py b/.scripts/community_split/libs/core/langchain_core/load/load.py new file mode 100644 index 00000000000..188aa45d61d --- /dev/null +++ b/.scripts/community_split/libs/core/langchain_core/load/load.py @@ -0,0 +1,144 @@ +import importlib +import json +import os +from typing import Any, Dict, List, Optional + +from langchain_core.load.mapping import SERIALIZABLE_MAPPING +from langchain_core.load.serializable import Serializable + +DEFAULT_NAMESPACES = ["langchain", "langchain_core", "langchain_community"] + + +class Reviver: + """Reviver for JSON objects.""" + + def __init__( + self, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, + ) -> None: + self.secrets_map = secrets_map or dict() + # By default only support langchain, but user can pass in additional namespaces + self.valid_namespaces = ( + [*DEFAULT_NAMESPACES, *valid_namespaces] + if valid_namespaces + else DEFAULT_NAMESPACES + ) + + def __call__(self, value: Dict[str, Any]) -> Any: + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "secret" + and value.get("id", None) is not None + ): + [key] = value["id"] + if key in self.secrets_map: + return self.secrets_map[key] + else: + if key in os.environ and os.environ[key]: + return os.environ[key] + raise KeyError(f'Missing key "{key}" in load(secrets_map)') + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "not_implemented" + and value.get("id", None) is not None + ): + raise NotImplementedError( + "Trying to load an object that doesn't implement " + f"serialization: {value}" + ) + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "constructor" + and value.get("id", None) is not None + ): + [*namespace, name] = value["id"] + + if namespace[0] not in self.valid_namespaces: + raise ValueError(f"Invalid namespace: {value}") + + # The root namespace "langchain" is not a valid identifier. + if len(namespace) == 1 and namespace[0] == "langchain": + raise ValueError(f"Invalid namespace: {value}") + + # Get the importable path + key = tuple(namespace + [name]) + if key not in SERIALIZABLE_MAPPING: + raise ValueError( + "Trying to deserialize something that cannot " + "be deserialized in current version of langchain-core: " + f"{key}" + ) + import_path = SERIALIZABLE_MAPPING[key] + # Split into module and name + import_dir, import_obj = import_path[:-1], import_path[-1] + # Import module + mod = importlib.import_module(".".join(import_dir)) + # Import class + cls = getattr(mod, import_obj) + + # The class must be a subclass of Serializable. + if not issubclass(cls, Serializable): + raise ValueError(f"Invalid namespace: {value}") + + # We don't need to recurse on kwargs + # as json.loads will do that for us. + kwargs = value.get("kwargs", dict()) + return cls(**kwargs) + + return value + + +def loads( + text: str, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: + """Revive a LangChain class from a JSON string. + Equivalent to `load(json.loads(text))`. + + Args: + text: The string to load. + secrets_map: A map of secrets to load. 
+ valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. + + Returns: + Revived LangChain objects. + """ + return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces)) + + +def load( + obj: Any, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: + """Revive a LangChain class from a JSON object. Use this if you already + have a parsed JSON object, eg. from `json.load` or `orjson.loads`. + + Args: + obj: The object to load. + secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. + + Returns: + Revived LangChain objects. + """ + reviver = Reviver(secrets_map, valid_namespaces) + + def _load(obj: Any) -> Any: + if isinstance(obj, dict): + # Need to revive leaf nodes before reviving this node + loaded_obj = {k: _load(v) for k, v in obj.items()} + return reviver(loaded_obj) + if isinstance(obj, list): + return [_load(o) for o in obj] + return obj + + return _load(obj) diff --git a/.scripts/community_split/libs/core/langchain_core/utils/__init__.py b/.scripts/community_split/libs/core/langchain_core/utils/__init__.py new file mode 100644 index 00000000000..6491a85f17f --- /dev/null +++ b/.scripts/community_split/libs/core/langchain_core/utils/__init__.py @@ -0,0 +1,49 @@ +""" +**Utility functions** for LangChain. + +These functions do not depend on any other LangChain module. +""" + +from langchain_core.utils.env import get_from_dict_or_env, get_from_env +from langchain_core.utils.formatting import StrictFormatter, formatter +from langchain_core.utils.input import ( + get_bolded_text, + get_color_mapping, + get_colored_text, + print_text, +) +from langchain_core.utils.loading import try_load_from_hub +from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value +from langchain_core.utils.utils import ( + build_extra_kwargs, + check_package_version, + convert_to_secret_str, + get_pydantic_field_names, + guard_import, + mock_now, + raise_for_status_with_text, + xor_args, +) + +__all__ = [ + "StrictFormatter", + "check_package_version", + "convert_to_secret_str", + "formatter", + "get_bolded_text", + "get_color_mapping", + "get_colored_text", + "get_pydantic_field_names", + "guard_import", + "mock_now", + "print_text", + "raise_for_status_with_text", + "xor_args", + "try_load_from_hub", + "build_extra_kwargs", + "get_from_env", + "get_from_dict_or_env", + "stringify_dict", + "comma_list", + "stringify_value", +] diff --git a/.scripts/community_split/libs/core/langchain_core/utils/env.py b/.scripts/community_split/libs/core/langchain_core/utils/env.py new file mode 100644 index 00000000000..b1579e07b7d --- /dev/null +++ b/.scripts/community_split/libs/core/langchain_core/utils/env.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + + +def env_var_is_set(env_var: str) -> bool: + """Check if an environment variable is set. + + Args: + env_var (str): The name of the environment variable. + + Returns: + bool: True if the environment variable is set, False otherwise. 
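+
+    For example, a value of "false" or "0" is treated as not set, while any
+    other non-empty value (e.g. "true") counts as set.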
+ """ + return env_var in os.environ and os.environ[env_var] not in ( + "", + "0", + "false", + "False", + ) + + +def get_from_dict_or_env( + data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None +) -> str: + """Get a value from a dictionary or an environment variable.""" + if key in data and data[key]: + return data[key] + else: + return get_from_env(key, env_key, default=default) + + +def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: + """Get a value from a dictionary or an environment variable.""" + if env_key in os.environ and os.environ[env_key]: + return os.environ[env_key] + elif default is not None: + return default + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) diff --git a/.scripts/community_split/libs/core/tests/unit_tests/utils/test_imports.py b/.scripts/community_split/libs/core/tests/unit_tests/utils/test_imports.py new file mode 100644 index 00000000000..9ebbb7e1e21 --- /dev/null +++ b/.scripts/community_split/libs/core/tests/unit_tests/utils/test_imports.py @@ -0,0 +1,28 @@ +from langchain_core.utils import __all__ + +EXPECTED_ALL = [ + "StrictFormatter", + "check_package_version", + "convert_to_secret_str", + "formatter", + "get_bolded_text", + "get_color_mapping", + "get_colored_text", + "get_pydantic_field_names", + "guard_import", + "mock_now", + "print_text", + "raise_for_status_with_text", + "xor_args", + "try_load_from_hub", + "build_extra_kwargs", + "get_from_dict_or_env", + "get_from_env", + "stringify_dict", + "comma_list", + "stringify_value" +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/.scripts/community_split/libs/langchain/langchain/callbacks/__init__.py b/.scripts/community_split/libs/langchain/langchain/callbacks/__init__.py new file mode 100644 index 00000000000..422c4b54b25 --- /dev/null +++ b/.scripts/community_split/libs/langchain/langchain/callbacks/__init__.py @@ -0,0 +1,83 @@ +"""**Callback handlers** allow listening to events in LangChain. + +**Class hierarchy:** + +.. 
code-block:: + + BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler +""" + +from langchain_core.callbacks import StdOutCallbackHandler, StreamingStdOutCallbackHandler +from langchain_core.tracers.langchain import LangChainTracer +from langchain_core.tracers.context import ( + collect_runs, + tracing_enabled, + tracing_v2_enabled, +) + +from langchain_community.callbacks.aim_callback import AimCallbackHandler +from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler +from langchain_community.callbacks.arize_callback import ArizeCallbackHandler +from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler +from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler +from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler +from langchain_community.callbacks.context_callback import ContextCallbackHandler +from langchain.callbacks.file import FileCallbackHandler +from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler +from langchain_community.callbacks.human import HumanApprovalCallbackHandler +from langchain_community.callbacks.infino_callback import InfinoCallbackHandler +from langchain_community.callbacks.labelstudio_callback import LabelStudioCallbackHandler +from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler +from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler +from langchain_community.callbacks.openai_info import OpenAICallbackHandler +from langchain_community.callbacks.promptlayer_callback import PromptLayerCallbackHandler +from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler +from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler +from langchain.callbacks.streaming_stdout_final_only import ( + FinalStreamingStdOutCallbackHandler, +) +from langchain_community.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler +from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler +from langchain_community.callbacks.wandb_callback import WandbCallbackHandler +from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler + +from langchain_community.callbacks.manager import ( + get_openai_callback, + wandb_tracing_enabled, +) + + +__all__ = [ + "AimCallbackHandler", + "ArgillaCallbackHandler", + "ArizeCallbackHandler", + "PromptLayerCallbackHandler", + "ArthurCallbackHandler", + "ClearMLCallbackHandler", + "CometCallbackHandler", + "ContextCallbackHandler", + "FileCallbackHandler", + "HumanApprovalCallbackHandler", + "InfinoCallbackHandler", + "MlflowCallbackHandler", + "LLMonitorCallbackHandler", + "OpenAICallbackHandler", + "StdOutCallbackHandler", + "AsyncIteratorCallbackHandler", + "StreamingStdOutCallbackHandler", + "FinalStreamingStdOutCallbackHandler", + "LLMThoughtLabeler", + "LangChainTracer", + "StreamlitCallbackHandler", + "WandbCallbackHandler", + "WhyLabsCallbackHandler", + "get_openai_callback", + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "wandb_tracing_enabled", + "FlyteCallbackHandler", + "SageMakerCallbackHandler", + "LabelStudioCallbackHandler", + "TrubricsCallbackHandler", +] diff --git a/.scripts/community_split/libs/langchain/langchain/callbacks/manager.py b/.scripts/community_split/libs/langchain/langchain/callbacks/manager.py new file mode 100644 index 00000000000..6c61a4cdb6c --- /dev/null +++ 
b/.scripts/community_split/libs/langchain/langchain/callbacks/manager.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainGroup, + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForLLMRun, + AsyncCallbackManagerForRetrieverRun, + AsyncCallbackManagerForToolRun, + AsyncParentRunManager, + AsyncRunManager, + BaseRunManager, + CallbackManager, + CallbackManagerForChainGroup, + CallbackManagerForChainRun, + CallbackManagerForLLMRun, + CallbackManagerForRetrieverRun, + CallbackManagerForToolRun, + Callbacks, + ParentRunManager, + RunManager, + ahandle_event, + atrace_as_chain_group, + handle_event, + trace_as_chain_group, +) +from langchain_core.tracers.context import ( + collect_runs, + tracing_enabled, + tracing_v2_enabled, +) +from langchain_core.utils.env import env_var_is_set +from langchain_community.callbacks.manager import ( + get_openai_callback, + wandb_tracing_enabled, +) + + +__all__ = [ + "BaseRunManager", + "RunManager", + "ParentRunManager", + "AsyncRunManager", + "AsyncParentRunManager", + "CallbackManagerForLLMRun", + "AsyncCallbackManagerForLLMRun", + "CallbackManagerForChainRun", + "AsyncCallbackManagerForChainRun", + "CallbackManagerForToolRun", + "AsyncCallbackManagerForToolRun", + "CallbackManagerForRetrieverRun", + "AsyncCallbackManagerForRetrieverRun", + "CallbackManager", + "CallbackManagerForChainGroup", + "AsyncCallbackManager", + "AsyncCallbackManagerForChainGroup", + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "atrace_as_chain_group", + "trace_as_chain_group", + "handle_event", + "ahandle_event", + "Callbacks", + "env_var_is_set", + "get_openai_callback", + "wandb_tracing_enabled", +] diff --git a/.scripts/community_split/libs/langchain/tests/unit_tests/callbacks/test_manager.py b/.scripts/community_split/libs/langchain/tests/unit_tests/callbacks/test_manager.py new file mode 100644 index 00000000000..8ee369e81fd --- /dev/null +++ b/.scripts/community_split/libs/langchain/tests/unit_tests/callbacks/test_manager.py @@ -0,0 +1,36 @@ +from langchain.callbacks.manager import __all__ + +EXPECTED_ALL = [ + "BaseRunManager", + "RunManager", + "ParentRunManager", + "AsyncRunManager", + "AsyncParentRunManager", + "CallbackManagerForLLMRun", + "AsyncCallbackManagerForLLMRun", + "CallbackManagerForChainRun", + "AsyncCallbackManagerForChainRun", + "CallbackManagerForToolRun", + "AsyncCallbackManagerForToolRun", + "CallbackManagerForRetrieverRun", + "AsyncCallbackManagerForRetrieverRun", + "CallbackManager", + "CallbackManagerForChainGroup", + "AsyncCallbackManager", + "AsyncCallbackManagerForChainGroup", + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "atrace_as_chain_group", + "trace_as_chain_group", + "handle_event", + "ahandle_event", + "env_var_is_set", + "Callbacks", + "get_openai_callback", + "wandb_tracing_enabled", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/.scripts/community_split/libs/langchain/tests/unit_tests/chains/test_llm.py b/.scripts/community_split/libs/langchain/tests/unit_tests/chains/test_llm.py new file mode 100644 index 00000000000..0179cd135f2 --- /dev/null +++ b/.scripts/community_split/libs/langchain/tests/unit_tests/chains/test_llm.py @@ -0,0 +1,75 @@ +"""Test LLM chain.""" +from tempfile import TemporaryDirectory +from typing import Dict, List, Union +from unittest.mock import patch + +import pytest +from langchain_core.output_parsers import 
BaseOutputParser +from langchain_core.prompts import PromptTemplate + +from langchain.chains.llm import LLMChain +from tests.unit_tests.llms.fake_llm import FakeLLM + + +class FakeOutputParser(BaseOutputParser): + """Fake output parser class for testing.""" + + def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]: + """Parse by splitting.""" + return text.split() + + +@pytest.fixture +def fake_llm_chain() -> LLMChain: + """Fake LLM chain for testing purposes.""" + prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:") + return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") + + +@patch( + "langchain_community.llms.loading.get_type_to_cls_dict", + lambda: {"fake": lambda: FakeLLM}, +) +def test_serialization(fake_llm_chain: LLMChain) -> None: + """Test serialization.""" + from langchain.chains.loading import load_chain + + with TemporaryDirectory() as temp_dir: + file = temp_dir + "/llm.json" + fake_llm_chain.save(file) + loaded_chain = load_chain(file) + assert loaded_chain == fake_llm_chain + + +def test_missing_inputs(fake_llm_chain: LLMChain) -> None: + """Test error is raised if inputs are missing.""" + with pytest.raises(ValueError): + fake_llm_chain({"foo": "bar"}) + + +def test_valid_call(fake_llm_chain: LLMChain) -> None: + """Test valid call of LLM chain.""" + output = fake_llm_chain({"bar": "baz"}) + assert output == {"bar": "baz", "text1": "foo"} + + # Test with stop words. + output = fake_llm_chain({"bar": "baz", "stop": ["foo"]}) + # Response should be `bar` now. + assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"} + + +def test_predict_method(fake_llm_chain: LLMChain) -> None: + """Test predict method works.""" + output = fake_llm_chain.predict(bar="baz") + assert output == "foo" + + +def test_predict_and_parse() -> None: + """Test parsing ability.""" + prompt = PromptTemplate( + input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser() + ) + llm = FakeLLM(queries={"foo": "foo bar"}) + chain = LLMChain(prompt=prompt, llm=llm) + output = chain.predict_and_parse(foo="foo") + assert output == ["foo", "bar"] diff --git a/.scripts/community_split/libs/langchain/tests/unit_tests/load/test_serializable.py b/.scripts/community_split/libs/langchain/tests/unit_tests/load/test_serializable.py new file mode 100644 index 00000000000..cb6c4f9ff4c --- /dev/null +++ b/.scripts/community_split/libs/langchain/tests/unit_tests/load/test_serializable.py @@ -0,0 +1,57 @@ +import importlib +import pkgutil + +from langchain_core.load.mapping import SERIALIZABLE_MAPPING + + +def import_all_modules(package_name: str) -> dict: + package = importlib.import_module(package_name) + classes: dict = {} + + for attribute_name in dir(package): + attribute = getattr(package, attribute_name) + if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type): + if ( + isinstance(attribute.is_lc_serializable(), bool) # type: ignore + and attribute.is_lc_serializable() # type: ignore + ): + key = tuple(attribute.lc_id()) # type: ignore + value = tuple(attribute.__module__.split(".") + [attribute.__name__]) + if key in classes and classes[key] != value: + raise ValueError + classes[key] = value + if hasattr(package, "__path__"): + for loader, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package_name + "." 
+ ): + if module_name not in ( + "langchain.chains.llm_bash", + "langchain.chains.llm_symbolic_math", + "langchain.tools.python", + "langchain.vectorstores._pgvector_data_models", + # TODO: why does this error? + "langchain.agents.agent_toolkits.openapi.planner", + ): + importlib.import_module(module_name) + new_classes = import_all_modules(module_name) + for k, v in new_classes.items(): + if k in classes and classes[k] != v: + raise ValueError + classes[k] = v + return classes + + +def test_serializable_mapping() -> None: + serializable_modules = import_all_modules("langchain") + missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules) + assert missing == set() + extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING) + assert extra == set() + + for k, import_path in serializable_modules.items(): + import_dir, import_obj = import_path[:-1], import_path[-1] + # Import module + mod = importlib.import_module(".".join(import_dir)) + # Import class + cls = getattr(mod, import_obj) + assert list(k) == cls.lc_id() diff --git a/.scripts/community_split/libs/langchain/tests/unit_tests/test_dependencies.py b/.scripts/community_split/libs/langchain/tests/unit_tests/test_dependencies.py new file mode 100644 index 00000000000..872b01d6213 --- /dev/null +++ b/.scripts/community_split/libs/langchain/tests/unit_tests/test_dependencies.py @@ -0,0 +1,113 @@ +"""A unit test meant to catch accidental introduction of non-optional dependencies.""" +from pathlib import Path +from typing import Any, Dict, Mapping + +import pytest +import toml + +HERE = Path(__file__).parent + +PYPROJECT_TOML = HERE / "../../pyproject.toml" + + +@pytest.fixture() +def poetry_conf() -> Dict[str, Any]: + """Load the pyproject.toml file.""" + with open(PYPROJECT_TOML) as f: + return toml.load(f)["tool"]["poetry"] + + +def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: + """A test that checks if a new non-optional dependency is being introduced. + + If this test is triggered, it means that a contributor is trying to introduce a new + required dependency. This should be avoided in most situations. + """ + # Get the dependencies from the [tool.poetry.dependencies] section + dependencies = poetry_conf["dependencies"] + + is_required = { + package_name: isinstance(requirements, str) + or not requirements.get("optional", False) + for package_name, requirements in dependencies.items() + } + required_dependencies = [ + package_name for package_name, required in is_required.items() if required + ] + + assert sorted(required_dependencies) == sorted( + [ + "PyYAML", + "SQLAlchemy", + "aiohttp", + "async-timeout", + "dataclasses-json", + "jsonpatch", + "langchain-core", + "langsmith", + "numpy", + "pydantic", + "python", + "requests", + "tenacity", + "langchain-community", + ] + ) + + unrequired_dependencies = [ + package_name for package_name, required in is_required.items() if not required + ] + in_extras = [dep for group in poetry_conf["extras"].values() for dep in group] + assert set(unrequired_dependencies) == set(in_extras) + + +def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: + """Check if someone is attempting to add additional test dependencies. + + Only dependencies associated with test running infrastructure should be added + to the test group; e.g., pytest, pytest-cov etc. + + Examples of dependencies that should NOT be included: boto3, azure, postgres, etc. 
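+
+    Such dependencies should instead be declared as optional (extras), which
+    keeps them out of the default install.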
+ """ + + test_group_deps = sorted(poetry_conf["group"]["test"]["dependencies"]) + + assert test_group_deps == sorted( + [ + "duckdb-engine", + "freezegun", + "langchain-core", + "lark", + "pandas", + "pytest", + "pytest-asyncio", + "pytest-cov", + "pytest-dotenv", + "pytest-mock", + "pytest-socket", + "pytest-watcher", + "responses", + "syrupy", + "requests-mock", + ] + ) + + +def test_imports() -> None: + """Test that you can import all top level things okay.""" + from langchain_core.prompts import BasePromptTemplate # noqa: F401 + + from langchain.agents import OpenAIFunctionsAgent # noqa: F401 + from langchain.callbacks import OpenAICallbackHandler # noqa: F401 + from langchain.chains import LLMChain # noqa: F401 + from langchain.chat_models import ChatOpenAI # noqa: F401 + from langchain.document_loaders import BSHTMLLoader # noqa: F401 + from langchain.embeddings import OpenAIEmbeddings # noqa: F401 + from langchain.llms import OpenAI # noqa: F401 + from langchain.retrievers import VespaRetriever # noqa: F401 + from langchain.tools import DuckDuckGoSearchResults # noqa: F401 + from langchain.utilities import ( + SearchApiAPIWrapper, # noqa: F401 + SerpAPIWrapper, # noqa: F401 + ) + from langchain.vectorstores import FAISS # noqa: F401 diff --git a/.scripts/community_split/script_integrations.sh b/.scripts/community_split/script_integrations.sh new file mode 100755 index 00000000000..56e5222261f --- /dev/null +++ b/.scripts/community_split/script_integrations.sh @@ -0,0 +1,313 @@ +#!/bin/bash + +cd libs + +# cleanup anything existing +git checkout master -- langchain/{langchain,tests} +git checkout master -- core/{langchain_core,tests} +git checkout master -- experimental/{langchain_experimental,tests} +rm -rf community/{langchain_community,tests} + +# make new dirs +mkdir -p community/langchain_community +touch community/langchain_community/__init__.py +touch community/langchain_community/py.typed +touch community/README.md +mkdir -p community/tests +touch community/tests/__init__.py +mkdir community/tests/unit_tests +touch community/tests/unit_tests/__init__.py +mkdir community/tests/integration_tests/ +touch community/tests/integration_tests/__init__.py +mkdir -p community/langchain_community/utils +touch community/langchain_community/utils/__init__.py +mkdir -p community/tests/unit_tests/utils +touch community/tests/unit_tests/utils/__init__.py +mkdir -p community/langchain_community/indexes +touch community/langchain_community/indexes/__init__.py +mkdir community/tests/unit_tests/indexes +touch community/tests/unit_tests/indexes/__init__.py + +# import core stuff from core +cd langchain + +git grep -l 'from langchain.pydantic_v1' | xargs sed -i '' 's/from langchain.pydantic_v1/from langchain_core.pydantic_v1/g' +git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g' +git grep -l 'from langchain.chat_models.base' | xargs sed -i '' 's/from langchain.chat_models.base/from langchain_core.language_models.chat_models/g' +git grep -l 'from langchain.llms.base' | xargs sed -i '' 's/from langchain.llms.base/from langchain_core.language_models.llms/g' +git grep -l 'from langchain.embeddings.base' | xargs sed -i '' 's/from langchain.embeddings.base/from langchain_core.embeddings/g' +git grep -l 'from langchain.vectorstores.base' | xargs sed -i '' 's/from langchain.vectorstores.base/from langchain_core.vectorstores/g' +git grep -l 'from langchain.agents.tools' | xargs sed -i '' 's/from langchain.agents.tools/from 
langchain_core.tools/g' +git grep -l 'from langchain.schema.output' | xargs sed -i '' 's/from langchain.schema.output/from langchain_core.outputs/g' +git grep -l 'from langchain.schema.messages' | xargs sed -i '' 's/from langchain.schema.messages/from langchain_core.messages/g' +git grep -l 'from langchain.schema.embeddings' | xargs sed -i '' 's/from langchain.schema.embeddings/from langchain_core.embeddings/g' + +# mv stuff to community +cd .. + +mv langchain/langchain/adapters community/langchain_community +mv langchain/langchain/callbacks community/langchain_community/callbacks +mv langchain/langchain/chat_loaders community/langchain_community +mv langchain/langchain/chat_models community/langchain_community +mv langchain/langchain/document_loaders community/langchain_community +mv langchain/langchain/docstore community/langchain_community +mv langchain/langchain/document_transformers community/langchain_community +mv langchain/langchain/embeddings community/langchain_community +mv langchain/langchain/graphs community/langchain_community +mv langchain/langchain/llms community/langchain_community +mv langchain/langchain/memory/chat_message_histories community/langchain_community +mv langchain/langchain/retrievers community/langchain_community +mv langchain/langchain/storage community/langchain_community +mv langchain/langchain/tools community/langchain_community +mv langchain/langchain/utilities community/langchain_community +mv langchain/langchain/vectorstores community/langchain_community + +mv langchain/langchain/agents/agent_toolkits community/langchain_community +mv langchain/langchain/cache.py community/langchain_community +mv langchain/langchain/indexes/base.py community/langchain_community/indexes +mv langchain/langchain/indexes/_sql_record_manager.py community/langchain_community/indexes +mv langchain/langchain/utils/{math,openai,openai_functions}.py community/langchain_community/utils + +# mv stuff to core +mv langchain/langchain/utils/json_schema.py core/langchain_core/utils +mv langchain/langchain/utils/html.py core/langchain_core/utils +mv langchain/langchain/utils/strings.py core/langchain_core/utils +cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py +rm langchain/langchain/utils/env.py + +# mv unit tests to community +mv langchain/tests/unit_tests/chat_loaders community/tests/unit_tests +mv langchain/tests/unit_tests/document_loaders community/tests/unit_tests +mv langchain/tests/unit_tests/docstore community/tests/unit_tests +mv langchain/tests/unit_tests/document_transformers community/tests/unit_tests +mv langchain/tests/unit_tests/embeddings community/tests/unit_tests +mv langchain/tests/unit_tests/graphs community/tests/unit_tests +mv langchain/tests/unit_tests/llms community/tests/unit_tests +mv langchain/tests/unit_tests/chat_models community/tests/unit_tests +mv langchain/tests/unit_tests/memory/chat_message_histories community/tests/unit_tests +mv langchain/tests/unit_tests/storage community/tests/unit_tests +mv langchain/tests/unit_tests/tools community/tests/unit_tests +mv langchain/tests/unit_tests/utilities community/tests/unit_tests +mv langchain/tests/unit_tests/vectorstores community/tests/unit_tests +mv langchain/tests/unit_tests/retrievers community/tests/unit_tests +mv langchain/tests/unit_tests/callbacks community/tests/unit_tests +mv langchain/tests/unit_tests/indexes/test_sql_record_manager.py community/tests/unit_tests/indexes +mv langchain/tests/unit_tests/utils/test_math.py community/tests/unit_tests/utils + +# cp some test 
helpers back to langchain +mkdir -p langchain/tests/unit_tests/llms +cp {community,langchain}/tests/unit_tests/llms/fake_llm.py +cp {community,langchain}/tests/unit_tests/llms/fake_chat_model.py +mkdir -p langchain/tests/unit_tests/callbacks +cp {community,langchain}/tests/unit_tests/callbacks/fake_callback_handler.py + +# mv unit tests to core +mv langchain/tests/unit_tests/utils/test_json_schema.py core/tests/unit_tests/utils +mv langchain/tests/unit_tests/utils/test_html.py core/tests/unit_tests/utils + +# mv integration tests to community +mv langchain/tests/integration_tests/document_loaders community/tests/integration_tests +mv langchain/tests/integration_tests/embeddings community/tests/integration_tests +mv langchain/tests/integration_tests/graphs community/tests/integration_tests +mv langchain/tests/integration_tests/llms community/tests/integration_tests +mv langchain/tests/integration_tests/chat_models community/tests/integration_tests +mv langchain/tests/integration_tests/memory/chat_message_histories community/tests/integration_tests +mv langchain/tests/integration_tests/storage community/tests/integration_tests +mv langchain/tests/integration_tests/tools community/tests/integration_tests +mv langchain/tests/integration_tests/utilities community/tests/integration_tests +mv langchain/tests/integration_tests/vectorstores community/tests/integration_tests +mv langchain/tests/integration_tests/retrievers community/tests/integration_tests +mv langchain/tests/integration_tests/adapters community/tests/integration_tests +mv langchain/tests/integration_tests/callbacks community/tests/integration_tests +mv langchain/tests/integration_tests/{test_kuzu,test_nebulagraph}.py community/tests/integration_tests/graphs +touch community/tests/integration_tests/{chat_message_histories,tools}/__init__.py + +# import new core stuff from core (everywhere) +git grep -l 'from langchain.utils.json_schema' | xargs sed -i '' 's/from langchain.utils.json_schema/from langchain_core.utils.json_schema/g' +git grep -l 'from langchain.utils.html' | xargs sed -i '' 's/from langchain.utils.html/from langchain_core.utils.html/g' +git grep -l 'from langchain.utils.strings' | xargs sed -i '' 's/from langchain.utils.strings/from langchain_core.utils.strings/g' +git grep -l 'from langchain.utils.env' | xargs sed -i '' 's/from langchain.utils.env/from langchain_core.utils.env/g' + +git add community +cd community + +# import core stuff from core +git grep -l 'from langchain.pydantic_v1' | xargs sed -i '' 's/from langchain.pydantic_v1/from langchain_core.pydantic_v1/g' +git grep -l 'from langchain.callbacks.base' | xargs sed -i '' 's/from langchain.callbacks.base/from langchain_core.callbacks/g' +git grep -l 'from langchain.callbacks.stdout' | xargs sed -i '' 's/from langchain.callbacks.stdout/from langchain_core.callbacks/g' +git grep -l 'from langchain.callbacks.streaming_stdout' | xargs sed -i '' 's/from langchain.callbacks.streaming_stdout/from langchain_core.callbacks/g' +git grep -l 'from langchain.callbacks.manager' | xargs sed -i '' 's/from langchain.callbacks.manager/from langchain_core.callbacks/g' +git grep -l 'from langchain.callbacks.tracers.base' | xargs sed -i '' 's/from langchain.callbacks.tracers.base/from langchain_core.tracers/g' +git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g' +git grep -l 'from langchain.agents.tools' | xargs sed -i '' 's/from langchain.agents.tools/from langchain_core.tools/g' +git grep -l 'from 
langchain.schema.output' | xargs sed -i '' 's/from langchain.schema.output/from langchain_core.outputs/g' +git grep -l 'from langchain.schema.messages' | xargs sed -i '' 's/from langchain.schema.messages/from langchain_core.messages/g' +git grep -l 'from langchain.schema import BaseRetriever' | xargs sed -i '' 's/from langchain.schema\ import\ BaseRetriever/from langchain_core.retrievers import BaseRetriever/g' +git grep -l 'from langchain.schema import Document' | xargs sed -i '' 's/from langchain.schema\ import\ Document/from langchain_core.documents import Document/g' + +# import openai stuff from openai +git grep -l 'from langchain.utils.math' | xargs sed -i '' 's/from langchain.utils.math/from langchain_community.utils.math/g' +git grep -l 'from langchain.utils.openai_functions' | xargs sed -i '' 's/from langchain.utils.openai_functions/from langchain_community.utils.openai_functions/g' +git grep -l 'from langchain.utils.openai' | xargs sed -i '' 's/from langchain.utils.openai/from langchain_community.utils.openai/g' +git grep -l 'from langchain.utils' | xargs sed -i '' 's/from langchain.utils/from langchain_core.utils/g' +git grep -l 'from langchain\.' | xargs sed -i '' 's/from langchain\./from langchain_community./g' +git grep -l 'from langchain_community.memory.chat_message_histories' | xargs sed -i '' 's/from langchain_community.memory.chat_message_histories/from langchain_community.chat_message_histories/g' +git grep -l 'from langchain_community.agents.agent_toolkits' | xargs sed -i '' 's/from langchain_community.agents.agent_toolkits/from langchain_community.agent_toolkits/g' + +sed -i '' 's/from\ langchain.chat_models\ import\ ChatOpenAI/from langchain_openai.chat_models import ChatOpenAI/g' langchain_community/chat_models/promptlayer_openai.py + +git grep -l 'from langchain_community\.text_splitter' | xargs sed -i '' 's/from langchain_community\.text_splitter/from langchain.text_splitter/g' +git grep -l 'from langchain_community\.chains' | xargs sed -i '' 's/from langchain_community\.chains/from langchain.chains/g' +git grep -l 'from langchain_community\.agents' | xargs sed -i '' 's/from langchain_community\.agents/from langchain.agents/g' +git grep -l 'from langchain_community\.memory' | xargs sed -i '' 's/from langchain_community\.memory/from langchain.memory/g' +git grep -l 'langchain\.__version__' | xargs sed -i '' 's/langchain\.__version__/langchain_community.__version__/g' +git grep -l 'langchain\.document_loaders' | xargs sed -i '' 's/langchain\.document_loaders/langchain_community.document_loaders/g' +git grep -l 'langchain\.callbacks' | xargs sed -i '' 's/langchain\.callbacks/langchain_community.callbacks/g' +git grep -l 'langchain\.tools' | xargs sed -i '' 's/langchain\.tools/langchain_community.tools/g' +git grep -l 'langchain\.llms' | xargs sed -i '' 's/langchain\.llms/langchain_community.llms/g' +git grep -l 'import langchain$' | xargs sed -i '' 's/import\ langchain$/import\ langchain_community/g' +git grep -l 'from\ langchain\ ' | xargs sed -i '' 's/from\ langchain\ /from\ langchain_community\ /g' +git grep -l 'langchain_core.language_models.llmsten' | xargs sed -i '' 's/langchain_core.language_models.llmsten/langchain_community.llms.baseten/g' + +# update all moved langchain files to re-export classes and functions +cd ../langchain +git checkout master -- langchain + +python ../../.scripts/community_split/update_imports.py langchain/chat_loaders langchain_community.chat_loaders +python ../../.scripts/community_split/update_imports.py langchain/callbacks 
langchain_community.callbacks +python ../../.scripts/community_split/update_imports.py langchain/document_loaders langchain_community.document_loaders +python ../../.scripts/community_split/update_imports.py langchain/docstore langchain_community.docstore +python ../../.scripts/community_split/update_imports.py langchain/document_transformers langchain_community.document_transformers +python ../../.scripts/community_split/update_imports.py langchain/embeddings langchain_community.embeddings +python ../../.scripts/community_split/update_imports.py langchain/graphs langchain_community.graphs +python ../../.scripts/community_split/update_imports.py langchain/llms langchain_community.llms +python ../../.scripts/community_split/update_imports.py langchain/chat_models langchain_community.chat_models +python ../../.scripts/community_split/update_imports.py langchain/memory/chat_message_histories langchain_community.chat_message_histories +python ../../.scripts/community_split/update_imports.py langchain/storage langchain_community.storage +python ../../.scripts/community_split/update_imports.py langchain/tools langchain_community.tools +python ../../.scripts/community_split/update_imports.py langchain/utilities langchain_community.utilities +python ../../.scripts/community_split/update_imports.py langchain/vectorstores langchain_community.vectorstores +python ../../.scripts/community_split/update_imports.py langchain/retrievers langchain_community.retrievers +python ../../.scripts/community_split/update_imports.py langchain/adapters langchain_community.adapters +python ../../.scripts/community_split/update_imports.py langchain/agents/agent_toolkits langchain_community.agent_toolkits +python ../../.scripts/community_split/update_imports.py langchain/cache.py langchain_community.cache +python ../../.scripts/community_split/update_imports.py langchain/utils/math.py langchain_community.utils.math +python ../../.scripts/community_split/update_imports.py langchain/utils/json_schema.py langchain_core.utils.json_schema +python ../../.scripts/community_split/update_imports.py langchain/utils/html.py langchain_core.utils.html +python ../../.scripts/community_split/update_imports.py langchain/utils/env.py langchain_core.utils.env +python ../../.scripts/community_split/update_imports.py langchain/utils/strings.py langchain_core.utils.strings +python ../../.scripts/community_split/update_imports.py langchain/utils/openai.py langchain_community.utils.openai +python ../../.scripts/community_split/update_imports.py langchain/utils/openai_functions.py langchain_community.utils.openai_functions + +# update core and openai imports +git grep -l 'from langchain.llms.base ' | xargs sed -i '' 's/from langchain.llms.base /from langchain_core.language_models.llms /g' +git grep -l 'from langchain.chat_models.base ' | xargs sed -i '' 's/from langchain.chat_models.base /from langchain_core.language_models.chat_models /g' +git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g' + +git grep -l 'langchain_core.language_models.llmsten' | xargs sed -i '' 's/langchain_core.language_models.llmsten/langchain_community.llms.baseten/g' + +cd .. 
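For orientation, this is roughly what the `update_imports.py` invocations above do to each moved module (the script itself appears later in this diff): the file body is replaced with re-exports from the new package plus an `__all__` list, so existing imports from `langchain` keep resolving. The module and class names below are hypothetical, chosen only to illustrate the shape of the rewrite.

```python
# Hypothetical langchain/llms/foo.py after the rewrite: the implementation now
# lives in langchain-community, and this stub only re-exports it so that
# `from langchain.llms.foo import FooLLM` continues to work.
from langchain_community.llms.foo import FooLLM

__all__ = ["FooLLM"]
```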
+ +mv community/langchain_community/utilities/loading.py langchain/langchain/utilities +mv community/langchain_community/utilities/asyncio.py langchain/langchain/utilities + +#git add partners +git add core + +# rm files from community that just export core classes +rm community/langchain_community/{chat_models,llms,tools,embeddings,vectorstores,callbacks}/base.py +rm community/tests/unit_tests/{chat_models,llms,tools,callbacks}/test_base.py +rm community/tests/unit_tests/callbacks/test_manager.py +rm community/langchain_community/callbacks/{stdout,streaming_stdout}.py +rm community/langchain_community/callbacks/tracers/{base,evaluation,langchain,langchain_v1,log_stream,root_listeners,run_collector,schemas,stdout}.py + +# keep export tests in langchain +git checkout master -- langchain/tests/unit_tests/{chat_models,llms,tools,callbacks,document_loaders}/test_base.py +git checkout master -- langchain/tests/unit_tests/{callbacks,docstore,document_loaders,document_transformers,embeddings,graphs,llms,chat_models,storage,tools,utilities,vectorstores}/test_imports.py +git checkout master -- langchain/tests/unit_tests/callbacks/test_manager.py +git checkout master -- langchain/tests/unit_tests/document_loaders/blob_loaders/test_public_api.py +git checkout master -- langchain/tests/unit_tests/document_loaders/parsers/test_public_api.py +git checkout master -- langchain/tests/unit_tests/vectorstores/test_public_api.py +git checkout master -- langchain/tests/unit_tests/schema + +# keep some non-integration stuff in langchain. rm from community and add back to langchain +rm community/langchain_community/retrievers/{multi_query,multi_vector,contextual_compression,ensemble,merger_retriever,parent_document_retriever,re_phraser,web_research,time_weighted_retriever}.py +rm -r community/langchain_community/retrievers/{self_query,document_compressors} +rm community/tests/unit_tests/retrievers/test_{ensemble,multi_query,multi_vector,parent_document,time_weighted_retriever,web_research}.py +rm community/tests/integration_tests/retrievers/test_{contextual_compression,merger_retriever}.py +rm -r community/tests/unit_tests/retrievers/{self_query,document_compressors} +rm -r community/tests/integration_tests/retrievers/document_compressors + +rm community/langchain_community/agent_toolkits/{pandas,python,spark}/__init__.py +rm community/langchain_community/tools/python/__init__.py + +rm -r community/langchain_community/agent_toolkits/conversational_retrieval/ +rm -r community/langchain_community/agent_toolkits/vectorstore/ +rm community/langchain_community/callbacks/tracers/logging.py +rm community/langchain_community/callbacks/{file,streaming_aiter_final_only,streaming_aiter,streaming_stdout_final_only}.py +rm community/langchain_community/embeddings/cache.py +rm community/langchain_community/storage/{encoder_backed,file_system,in_memory,_lc_store}.py +rm community/langchain_community/tools/retriever.py +rm community/tests/unit_tests/callbacks/tracers/test_logging.py +rm community/tests/unit_tests/embeddings/test_caching.py +rm community/tests/unit_tests/storage/test_{filesystem,in_memory,lc_store}.py + +git checkout master -- langchain/langchain/retrievers/{multi_query,multi_vector,self_query/base,contextual_compression,ensemble,merger_retriever,parent_document_retriever,re_phraser,web_research,time_weighted_retriever}.py +git checkout master -- langchain/langchain/retrievers/{self_query,document_compressors} +git checkout master -- 
langchain/tests/unit_tests/retrievers/test_{ensemble,multi_query,multi_vector,parent_document,time_weighted_retriever,web_research}.py +git checkout master -- langchain/tests/integration_tests/retrievers/test_{contextual_compression,merger_retriever}.py +git checkout master -- langchain/tests/unit_tests/retrievers/{self_query,document_compressors} +git checkout master -- langchain/tests/integration_tests/retrievers/document_compressors +touch langchain/tests/unit_tests/{llms,chat_models,tools,callbacks,runnables,document_loaders,docstore,document_transformers,embeddings,graphs,storage,utilities,vectorstores,retrievers}/__init__.py +touch langchain/tests/unit_tests/document_loaders/{blob_loaders,parsers}/__init__.py +mv {community,langchain}/tests/unit_tests/retrievers/sequential_retriever.py + +git checkout master -- langchain/langchain/agents/agent_toolkits/conversational_retrieval/ +git checkout master -- langchain/langchain/agents/agent_toolkits/vectorstore/ +git checkout master -- langchain/langchain/callbacks/tracers/logging.py +git checkout master -- langchain/langchain/callbacks/{file,streaming_aiter_final_only,streaming_aiter,streaming_stdout_final_only}.py +git checkout master -- langchain/langchain/embeddings/cache.py +git checkout master -- langchain/langchain/storage/{encoder_backed,file_system,in_memory,_lc_store}.py +git checkout master -- langchain/langchain/tools/retriever.py +git checkout master -- langchain/tests/unit_tests/callbacks/tracers/{test_logging,__init__}.py +git checkout master -- langchain/tests/unit_tests/embeddings/{__init__,test_caching}.py +git checkout master -- langchain/tests/unit_tests/storage/test_{filesystem,in_memory,lc_store}.py +git checkout master -- langchain/tests/unit_tests/storage/__init__.py + +# cp lint scripts +cp -r core/scripts community + +# cp test helpers +cp -r langchain/tests/integration_tests/examples community/tests +cp -r langchain/tests/integration_tests/examples community/tests/integration_tests +cp -r langchain/tests/unit_tests/examples community/tests/unit_tests +cp langchain/tests/unit_tests/conftest.py community/tests/unit_tests +cp community/tests/integration_tests/vectorstores/fake_embeddings.py langchain/tests/integration_tests/cache/ +cp langchain/tests/integration_tests/test_compile.py community/tests/integration_tests + +# cp manually changed files +cp -r ../.scripts/community_split/libs/* . + +# mv some tests to integrations +mv community/tests/{unit_tests,integration_tests}/document_loaders/test_telegram.py +mv community/tests/{unit_tests,integration_tests}/document_loaders/parsers/test_docai.py +mv community/tests/{unit_tests,integration_tests}/chat_message_histories/test_streamlit.py + +# fix some final tests +git grep -l 'integration_tests\.vectorstores\.fake_embeddings' langchain/tests | xargs sed -i '' 's/integration_tests\.vectorstores\.fake_embeddings/integration_tests.cache.fake_embeddings/g' +touch community/langchain_community/agent_toolkits/amadeus/__init__.py + +# format +cd core +make format +cd ../langchain +make format +cd ../experimental +make format +cd ../community +make format + +cd .. +sed -E -i '' '1 s/(.*)/\1\ \ \#\ noqa\:\ E501/g' langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py +sed -E -i '' 's/import\ importlib$/import importlib.util/g' experimental/langchain_experimental/prompts/load.py +git add . 
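Every bulk rewrite above uses the same `git grep -l PATTERN | xargs sed -i '' 's/OLD/NEW/g'` idiom (note the BSD/macOS `sed -i ''` form). As a rough, non-authoritative sketch, one of those rewrites could equivalently be done in Python; the pattern and directory below are illustrative only, not part of the migration script.

```python
# Illustrative Python equivalent of one `git grep | xargs sed` rewrite above.
import pathlib
import re

OLD = re.compile(r"from langchain\.pydantic_v1\b")
NEW = "from langchain_core.pydantic_v1"

for path in pathlib.Path("libs/community").rglob("*.py"):
    text = path.read_text()
    if OLD.search(text):  # only touch files that actually match, like `git grep -l`
        path.write_text(OLD.sub(NEW, text))
```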
\ No newline at end of file diff --git a/.scripts/community_split/update_imports.py b/.scripts/community_split/update_imports.py new file mode 100644 index 00000000000..e555c42d9cf --- /dev/null +++ b/.scripts/community_split/update_imports.py @@ -0,0 +1,85 @@ +import ast +import os +import sys +from pathlib import Path + + +class ImportTransformer(ast.NodeTransformer): + def __init__(self, public_items, module_name): + self.public_items = public_items + self.module_name = module_name + + def visit_Module(self, node): + imports = [ + ast.ImportFrom( + module=self.module_name, + names=[ast.alias(name=item, asname=None)], + level=0, + ) + for item in self.public_items + ] + all_assignment = ast.Assign( + targets=[ast.Name(id="__all__", ctx=ast.Store())], + value=ast.List( + elts=[ast.Str(s=item) for item in self.public_items], ctx=ast.Load() + ), + ) + node.body = imports + [all_assignment] + return node + + +def find_public_classes_and_methods(file_path): + with open(file_path, "r") as file: + node = ast.parse(file.read(), filename=file_path) + + public_items = [] + for item in node.body: + if isinstance(item, ast.ClassDef) or isinstance(item, ast.FunctionDef): + public_items.append(item.name) + if ( + isinstance(item, ast.Assign) + and hasattr(item.targets[0], "id") + and item.targets[0].id not in ("__all__", "logger") + ): + public_items.append(item.targets[0].id) + + return public_items or None + + +def process_file(file_path, module_name): + public_items = find_public_classes_and_methods(file_path) + if public_items is None: + return + + with open(file_path, "r") as file: + contents = file.read() + tree = ast.parse(contents, filename=file_path) + + tree = ImportTransformer(public_items, module_name).visit(tree) + tree = ast.fix_missing_locations(tree) + + with open(file_path, "w") as file: + file.write(ast.unparse(tree)) + + +def process_directory(directory_path, base_module_name): + if Path(directory_path).is_file(): + process_file(directory_path, base_module_name) + else: + for root, dirs, files in os.walk(directory_path): + for filename in files: + if filename.endswith(".py") and not filename.startswith("_"): + file_path = os.path.join(root, filename) + relative_path = os.path.relpath(file_path, directory_path) + module_name = f"{base_module_name}.{os.path.splitext(relative_path)[0].replace(os.sep, '.')}" + process_file(file_path, module_name) + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python script_name.py ") + sys.exit(1) + + directory_path = sys.argv[1] + base_module_name = sys.argv[2] + process_directory(directory_path, base_module_name) diff --git a/README.md b/README.md index 7e1a97d6f30..38dd3c40dc2 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,10 @@ This framework consists of several parts. - **[LangServe](https://github.com/langchain-ai/langserve)**: A library for deploying LangChain chains as a REST API. - **[LangSmith](https://smith.langchain.com)**: A developer platform that lets you debug, test, evaluate, and monitor chains built on any LLM framework and seamlessly integrates with LangChain. -**This repo contains the `langchain` ([here](libs/langchain)), `langchain-experimental` ([here](libs/experimental)), and `langchain-cli` ([here](libs/cli)) Python packages, as well as [LangChain Templates](templates).** +The LangChain libraries themselves are made up of several different packages. +- **[`langchain-core`](libs/core)**: Base abstractions and LangChain Expression Language. 
+- **[`langchain-community`](libs/community)**: Third party integrations. +- **[`langchain`](libs/langchain)**: Chains, agents, and retrieval strategies that make up an application's cognitive architecture. ![LangChain Stack](docs/static/img/langchain_stack.png) diff --git a/docs/api_reference/create_api_rst.py b/docs/api_reference/create_api_rst.py index 8eda2a46d5f..176bf444f2a 100644 --- a/docs/api_reference/create_api_rst.py +++ b/docs/api_reference/create_api_rst.py @@ -14,9 +14,10 @@ HERE = Path(__file__).parent PKG_DIR = ROOT_DIR / "libs" / "langchain" / "langchain" EXP_DIR = ROOT_DIR / "libs" / "experimental" / "langchain_experimental" CORE_DIR = ROOT_DIR / "libs" / "core" / "langchain_core" +COMMUNITY_DIR = ROOT_DIR / "libs" / "core" / "langchain_community" WRITE_FILE = HERE / "api_reference.rst" EXP_WRITE_FILE = HERE / "experimental_api_reference.rst" -CORE_WRITE_FILE = HERE / "core_api_reference.rst" +COMMUNITY_WRITE_FILE = HERE / "community_api_reference.rst" ClassKind = Literal["TypedDict", "Regular", "Pydantic", "enum"] @@ -302,6 +303,7 @@ package_namespace = { "langchain": "langchain", "experimental": "langchain_experimental", "core": "langchain_core", + "community": "langchain_community", } @@ -316,6 +318,7 @@ def _out_file_path(package_name: str = "langchain") -> Path: "langchain": "", "experimental": "experimental_", "core": "core_", + "community": "community_", } return HERE / f"{name_prefix[package_name]}api_reference.rst" @@ -326,6 +329,7 @@ def _doc_first_line(package_name: str = "langchain") -> str: "langchain": "", "experimental": "experimental", "core": "core", + "community": "community", } return f".. {prefix[package_name]}_api_reference:\n\n" @@ -335,6 +339,7 @@ def main() -> None: _build_rst_file(package_name="core") _build_rst_file(package_name="langchain") _build_rst_file(package_name="experimental") + _build_rst_file(package_name="community") if __name__ == "__main__": diff --git a/docs/api_reference/requirements.txt b/docs/api_reference/requirements.txt index 59acb690193..ecc1cd9565b 100644 --- a/docs/api_reference/requirements.txt +++ b/docs/api_reference/requirements.txt @@ -1,6 +1,7 @@ --e libs/langchain -e libs/experimental +-e libs/langchain -e libs/core +-e libs/community pydantic<2 autodoc_pydantic==1.8.0 myst_parser diff --git a/docs/api_reference/themes/scikit-learn-modern/nav.html b/docs/api_reference/themes/scikit-learn-modern/nav.html index 37c59b466d5..118a41c8a16 100644 --- a/docs/api_reference/themes/scikit-learn-modern/nav.html +++ b/docs/api_reference/themes/scikit-learn-modern/nav.html @@ -37,6 +37,9 @@ + diff --git a/docs/docs/get_started/introduction.mdx b/docs/docs/get_started/introduction.mdx index d3e42904ff7..2cef4cc5f65 100644 --- a/docs/docs/get_started/introduction.mdx +++ b/docs/docs/get_started/introduction.mdx @@ -29,6 +29,11 @@ The main value props of the LangChain packages are: Off-the-shelf chains make it easy to get started. Components make it easy to customize existing chains and build new ones. +The LangChain libraries themselves are made up of several different packages. +- **`langchain-core`**: Base abstractions and LangChain Expression Language. +- **`langchain-community`**: Third party integrations. +- **`langchain`**: Chains, agents, and retrieval strategies that make up an application's cognitive architecture. + ## Get started [Here’s](/docs/get_started/installation) how to install LangChain, set up your environment, and start building. 
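Concretely, the three-package layout described above means an application pulls each layer from a different distribution. The specific classes below are only familiar examples used to illustrate which package owns what; this diff does not prescribe them.

```python
# Illustrative only: where each kind of import lives after the split.
from langchain_core.runnables import RunnablePassthrough  # base abstractions / LCEL
from langchain_community.llms import OpenAI               # third-party integration
from langchain.chains import RetrievalQA                  # chains / cognitive architecture
```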
diff --git a/libs/community/Makefile b/libs/community/Makefile new file mode 100644 index 00000000000..fe1509fe8dd --- /dev/null +++ b/libs/community/Makefile @@ -0,0 +1,69 @@ +.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests + +# Default target executed when no arguments are given to make. +all: help + +# Define a variable for the test file path. +TEST_FILE ?= tests/unit_tests/ + +test: + poetry run pytest $(TEST_FILE) + +tests: + poetry run pytest $(TEST_FILE) + +test_watch: + poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests + +check_imports: langchain_community/**/*.py + for f in $^ ; do \ + python -c "from importlib.machinery import SourceFileLoader; SourceFileLoader('x', '$$f').load_module()" || exit 1; \ + done + +extended_tests: + poetry run pytest --only-extended tests/unit_tests + + +###################### +# LINTING AND FORMATTING +###################### + +# Define a variable for Python and notebook files. +PYTHON_FILES=. +MYPY_CACHE=.mypy_cache +lint format: PYTHON_FILES=. +lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/community --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') +lint_package: PYTHON_FILES=langchain_community +lint_tests: PYTHON_FILES=tests +lint_tests: MYPY_CACHE=.mypy_cache_test + +lint lint_diff lint_package lint_tests: + ./scripts/check_pydantic.sh . + ./scripts/check_imports.sh + poetry run ruff . + [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + +format format_diff: + poetry run ruff format $(PYTHON_FILES) + poetry run ruff --select I --fix $(PYTHON_FILES) + +spell_check: + poetry run codespell --toml pyproject.toml + +spell_fix: + poetry run codespell --toml pyproject.toml -w + +###################### +# HELP +###################### + +help: + @echo '----' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'test - run unit tests' + @echo 'tests - run unit tests' + @echo 'test TEST_FILE=<test_file> - run all tests in file' + @echo 'test_watch - run unit tests in watch mode' diff --git a/libs/community/README.md b/libs/community/README.md new file mode 100644 index 00000000000..d0f8c668bdc --- /dev/null +++ b/libs/community/README.md @@ -0,0 +1 @@ +# langchain-community \ No newline at end of file diff --git a/libs/community/langchain_community/__init__.py b/libs/community/langchain_community/__init__.py new file mode 100644 index 00000000000..ac7aeef6faf --- /dev/null +++ b/libs/community/langchain_community/__init__.py @@ -0,0 +1,9 @@ +"""Main entrypoint into package.""" +from importlib import metadata + +try: + __version__ = metadata.version(__package__) +except metadata.PackageNotFoundError: + # Case where package metadata is not available.
+ __version__ = "" +del metadata # optional, avoids polluting the results of dir(__package__) diff --git a/libs/langchain/tests/integration_tests/adapters/__init__.py b/libs/community/langchain_community/adapters/__init__.py similarity index 100% rename from libs/langchain/tests/integration_tests/adapters/__init__.py rename to libs/community/langchain_community/adapters/__init__.py diff --git a/libs/community/langchain_community/adapters/openai.py b/libs/community/langchain_community/adapters/openai.py new file mode 100644 index 00000000000..0af759ebf5b --- /dev/null +++ b/libs/community/langchain_community/adapters/openai.py @@ -0,0 +1,399 @@ +from __future__ import annotations + +import importlib +from typing import ( + Any, + AsyncIterator, + Dict, + Iterable, + List, + Mapping, + Sequence, + Union, + overload, +) + +from langchain_core.chat_sessions import ChatSession +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.pydantic_v1 import BaseModel +from typing_extensions import Literal + + +async def aenumerate( + iterable: AsyncIterator[Any], start: int = 0 +) -> AsyncIterator[tuple[int, Any]]: + """Async version of enumerate function.""" + i = start + async for x in iterable: + yield i, x + i += 1 + + +class IndexableBaseModel(BaseModel): + """Allows a BaseModel to return its fields by string variable indexing""" + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + +class Choice(IndexableBaseModel): + message: dict + + +class ChatCompletions(IndexableBaseModel): + choices: List[Choice] + + +class ChoiceChunk(IndexableBaseModel): + delta: dict + + +class ChatCompletionChunk(IndexableBaseModel): + choices: List[ChoiceChunk] + + +def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + """Convert a dictionary to a LangChain message. + + Args: + _dict: The dictionary. + + Returns: + The LangChain message. + """ + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + # Fix for azure + # Also OpenAI returns None for tool invocations + content = _dict.get("content", "") or "" + additional_kwargs: Dict = {} + if _dict.get("function_call"): + additional_kwargs["function_call"] = dict(_dict["function_call"]) + if _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = _dict["tool_calls"] + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + elif role == "tool": + return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def convert_message_to_dict(message: BaseMessage) -> dict: + """Convert a LangChain message to a dictionary. + + Args: + message: The LangChain message. + + Returns: + The dictionary. 
+ """ + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + # If function call only, content is None not empty string + if message_dict["content"] == "": + message_dict["content"] = None + if "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] + # If tool calls only, content is None not empty string + if message_dict["content"] == "": + message_dict["content"] = None + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + elif isinstance(message, ToolMessage): + message_dict = { + "role": "tool", + "content": message.content, + "tool_call_id": message.tool_call_id, + } + else: + raise TypeError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]: + """Convert dictionaries representing OpenAI messages to LangChain format. + + Args: + messages: List of dictionaries representing OpenAI messages + + Returns: + List of LangChain BaseMessage objects. + """ + return [convert_dict_to_message(m) for m in messages] + + +def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict: + _dict: Dict[str, Any] = {} + if isinstance(chunk, AIMessageChunk): + if i == 0: + # Only shows up in the first chunk + _dict["role"] = "assistant" + if "function_call" in chunk.additional_kwargs: + _dict["function_call"] = chunk.additional_kwargs["function_call"] + # If the first chunk is a function call, the content is not empty string, + # not missing, but None. + if i == 0: + _dict["content"] = None + else: + _dict["content"] = chunk.content + else: + raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}") + # This only happens at the end of streams, and OpenAI returns as empty dict + if _dict == {"content": ""}: + _dict = {} + return _dict + + +def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]: + _dict = _convert_message_chunk(chunk, i) + return {"choices": [{"delta": _dict}]} + + +class ChatCompletion: + """Chat completion.""" + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> dict: + ... + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> Iterable: + ... 
+ + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[dict, Iterable]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = model_config.invoke(converted_messages) + return {"choices": [{"message": convert_message_to_dict(result)}]} + else: + return ( + _convert_message_chunk_to_delta(c, i) + for i, c in enumerate(model_config.stream(converted_messages)) + ) + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> dict: + ... + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> AsyncIterator: + ... + + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[dict, AsyncIterator]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = await model_config.ainvoke(converted_messages) + return {"choices": [{"message": convert_message_to_dict(result)}]} + else: + return ( + _convert_message_chunk_to_delta(c, i) + async for i, c in aenumerate(model_config.astream(converted_messages)) + ) + + +def _has_assistant_message(session: ChatSession) -> bool: + """Check if chat session has an assistant message.""" + return any([isinstance(m, AIMessage) for m in session["messages"]]) + + +def convert_messages_for_finetuning( + sessions: Iterable[ChatSession], +) -> List[List[dict]]: + """Convert messages to a list of lists of dictionaries for fine-tuning. + + Args: + sessions: The chat sessions. + + Returns: + The list of lists of dictionaries. + """ + return [ + [convert_message_to_dict(s) for s in session["messages"]] + for session in sessions + if _has_assistant_message(session) + ] + + +class Completions: + """Completion.""" + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> ChatCompletions: + ... + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> Iterable: + ... 
+ + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[ChatCompletions, Iterable]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = model_config.invoke(converted_messages) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) + else: + return ( + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) + for i, c in enumerate(model_config.stream(converted_messages)) + ) + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> ChatCompletions: + ... + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> AsyncIterator: + ... + + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[ChatCompletions, AsyncIterator]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = await model_config.ainvoke(converted_messages) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) + else: + return ( + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) + async for i, c in aenumerate(model_config.astream(converted_messages)) + ) + + +class Chat: + def __init__(self) -> None: + self.completions = Completions() + + +chat = Chat() diff --git a/libs/community/langchain_community/agent_toolkits/__init__.py b/libs/community/langchain_community/agent_toolkits/__init__.py new file mode 100644 index 00000000000..9501eb5db14 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/__init__.py @@ -0,0 +1,78 @@ +"""Agent toolkits contain integrations with various resources and services. + +LangChain has a large ecosystem of integrations with various external resources +like local and remote file systems, APIs and databases. + +These integrations allow developers to create versatile applications that combine the +power of LLMs with the ability to access, interact with and manipulate external +resources. + +When developing an application, developers should inspect the capabilities and +permissions of the tools that underlie the given agent toolkit, and determine +whether permissions of the given toolkit are appropriate for the application. + +See [Security](https://python.langchain.com/docs/security) for more information. 
+""" +from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit +from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit +from langchain_community.agent_toolkits.azure_cognitive_services import ( + AzureCognitiveServicesToolkit, +) +from langchain_community.agent_toolkits.conversational_retrieval.openai_functions import ( # noqa: E501 + create_conversational_retrieval_agent, +) +from langchain_community.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) +from langchain_community.agent_toolkits.gmail.toolkit import GmailToolkit +from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit +from langchain_community.agent_toolkits.json.base import create_json_agent +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit +from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit +from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit +from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit +from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit +from langchain_community.agent_toolkits.openapi.base import create_openapi_agent +from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit +from langchain_community.agent_toolkits.playwright.toolkit import ( + PlayWrightBrowserToolkit, +) +from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent +from langchain_community.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.agent_toolkits.slack.toolkit import SlackToolkit +from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent +from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit +from langchain_community.agent_toolkits.sql.base import create_sql_agent +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit +from langchain_community.agent_toolkits.zapier.toolkit import ZapierToolkit + +__all__ = [ + "AINetworkToolkit", + "AmadeusToolkit", + "AzureCognitiveServicesToolkit", + "FileManagementToolkit", + "GmailToolkit", + "JiraToolkit", + "JsonToolkit", + "MultionToolkit", + "NasaToolkit", + "NLAToolkit", + "O365Toolkit", + "OpenAPIToolkit", + "PlayWrightBrowserToolkit", + "PowerBIToolkit", + "SlackToolkit", + "SteamToolkit", + "SQLDatabaseToolkit", + "SparkSQLToolkit", + "ZapierToolkit", + "create_json_agent", + "create_openapi_agent", + "create_pbi_agent", + "create_pbi_chat_agent", + "create_spark_sql_agent", + "create_sql_agent", + "create_conversational_retrieval_agent", +] diff --git a/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py b/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py new file mode 100644 index 00000000000..c4295f2ef48 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/ainetwork/__init__.py @@ -0,0 +1 @@ +"""AINetwork toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py b/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py new file mode 100644 index 00000000000..3e0c140e99c --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/ainetwork/toolkit.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Literal, Optional + +from 
langchain_core.pydantic_v1 import root_validator + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.ainetwork.app import AINAppOps +from langchain_community.tools.ainetwork.owner import AINOwnerOps +from langchain_community.tools.ainetwork.rule import AINRuleOps +from langchain_community.tools.ainetwork.transfer import AINTransfer +from langchain_community.tools.ainetwork.utils import authenticate +from langchain_community.tools.ainetwork.value import AINValueOps + +if TYPE_CHECKING: + from ain.ain import Ain + + +class AINetworkToolkit(BaseToolkit): + """Toolkit for interacting with AINetwork Blockchain. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by reading, creating, updating, deleting + data associated with this service. + + See https://python.langchain.com/docs/security for more information. + """ + + network: Optional[Literal["mainnet", "testnet"]] = "testnet" + interface: Optional[Ain] = None + + @root_validator(pre=True) + def set_interface(cls, values: dict) -> dict: + if not values.get("interface"): + values["interface"] = authenticate(network=values.get("network", "testnet")) + return values + + class Config: + """Pydantic config.""" + + validate_all = True + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + AINAppOps(), + AINOwnerOps(), + AINRuleOps(), + AINTransfer(), + AINValueOps(), + ] diff --git a/libs/langchain/tests/integration_tests/callbacks/__init__.py b/libs/community/langchain_community/agent_toolkits/amadeus/__init__.py similarity index 100% rename from libs/langchain/tests/integration_tests/callbacks/__init__.py rename to libs/community/langchain_community/agent_toolkits/amadeus/__init__.py diff --git a/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py new file mode 100644 index 00000000000..90bc5da6476 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport +from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch +from langchain_community.tools.amadeus.utils import authenticate + +if TYPE_CHECKING: + from amadeus import Client + + +class AmadeusToolkit(BaseToolkit): + """Toolkit for interacting with Amadeus which offers APIs for travel.""" + + client: Client = Field(default_factory=authenticate) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + AmadeusClosestAirport(), + AmadeusFlightSearch(), + ] diff --git a/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py b/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py new file mode 100644 index 00000000000..3c2de898ae4 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/azure_cognitive_services.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import sys +from typing import List + +from langchain_core.tools 
import BaseTool + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools.azure_cognitive_services import ( + AzureCogsFormRecognizerTool, + AzureCogsImageAnalysisTool, + AzureCogsSpeech2TextTool, + AzureCogsText2SpeechTool, + AzureCogsTextAnalyticsHealthTool, +) + + +class AzureCognitiveServicesToolkit(BaseToolkit): + """Toolkit for Azure Cognitive Services.""" + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + + tools: List[BaseTool] = [ + AzureCogsFormRecognizerTool(), + AzureCogsSpeech2TextTool(), + AzureCogsText2SpeechTool(), + AzureCogsTextAnalyticsHealthTool(), + ] + + # TODO: Remove check once azure-ai-vision supports MacOS. + if sys.platform.startswith("linux") or sys.platform.startswith("win"): + tools.append(AzureCogsImageAnalysisTool()) + return tools diff --git a/libs/community/langchain_community/agent_toolkits/base.py b/libs/community/langchain_community/agent_toolkits/base.py new file mode 100644 index 00000000000..0315184115e --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/base.py @@ -0,0 +1,15 @@ +"""Toolkits for agents.""" +from abc import ABC, abstractmethod +from typing import List + +from langchain_core.pydantic_v1 import BaseModel + +from langchain_community.tools import BaseTool + + +class BaseToolkit(BaseModel, ABC): + """Base Toolkit representing a collection of related tools.""" + + @abstractmethod + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" diff --git a/libs/langchain/tests/integration_tests/chat_models/__init__.py b/libs/community/langchain_community/agent_toolkits/clickup/__init__.py similarity index 100% rename from libs/langchain/tests/integration_tests/chat_models/__init__.py rename to libs/community/langchain_community/agent_toolkits/clickup/__init__.py diff --git a/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py b/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py new file mode 100644 index 00000000000..23a1c9a76c8 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/clickup/toolkit.py @@ -0,0 +1,108 @@ +from typing import Dict, List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.clickup.prompt import ( + CLICKUP_FOLDER_CREATE_PROMPT, + CLICKUP_GET_ALL_TEAMS_PROMPT, + CLICKUP_GET_FOLDERS_PROMPT, + CLICKUP_GET_LIST_PROMPT, + CLICKUP_GET_SPACES_PROMPT, + CLICKUP_GET_TASK_ATTRIBUTE_PROMPT, + CLICKUP_GET_TASK_PROMPT, + CLICKUP_LIST_CREATE_PROMPT, + CLICKUP_TASK_CREATE_PROMPT, + CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT, + CLICKUP_UPDATE_TASK_PROMPT, +) +from langchain_community.tools.clickup.tool import ClickupAction +from langchain_community.utilities.clickup import ClickupAPIWrapper + + +class ClickupToolkit(BaseToolkit): + """Clickup Toolkit. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by reading, creating, updating, deleting + data associated with this service. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + tools: List[BaseTool] = [] + + @classmethod + def from_clickup_api_wrapper( + cls, clickup_api_wrapper: ClickupAPIWrapper + ) -> "ClickupToolkit": + operations: List[Dict] = [ + { + "mode": "get_task", + "name": "Get task", + "description": CLICKUP_GET_TASK_PROMPT, + }, + { + "mode": "get_task_attribute", + "name": "Get task attribute", + "description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT, + }, + { + "mode": "get_teams", + "name": "Get Teams", + "description": CLICKUP_GET_ALL_TEAMS_PROMPT, + }, + { + "mode": "create_task", + "name": "Create Task", + "description": CLICKUP_TASK_CREATE_PROMPT, + }, + { + "mode": "create_list", + "name": "Create List", + "description": CLICKUP_LIST_CREATE_PROMPT, + }, + { + "mode": "create_folder", + "name": "Create Folder", + "description": CLICKUP_FOLDER_CREATE_PROMPT, + }, + { + "mode": "get_list", + "name": "Get all lists in the space", + "description": CLICKUP_GET_LIST_PROMPT, + }, + { + "mode": "get_folders", + "name": "Get all folders in the workspace", + "description": CLICKUP_GET_FOLDERS_PROMPT, + }, + { + "mode": "get_spaces", + "name": "Get all spaces in the workspace", + "description": CLICKUP_GET_SPACES_PROMPT, + }, + { + "mode": "update_task", + "name": "Update task", + "description": CLICKUP_UPDATE_TASK_PROMPT, + }, + { + "mode": "update_task_assignees", + "name": "Update task assignees", + "description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT, + }, + ] + tools = [ + ClickupAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=clickup_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py new file mode 100644 index 00000000000..405ccd39357 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.memory import BaseMemory +from langchain_core.messages import SystemMessage +from langchain_core.prompts.chat import MessagesPlaceholder +from langchain_core.tools import BaseTool + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def _get_default_system_message() -> SystemMessage: + return SystemMessage( + content=( + "Do your best to answer the questions. " + "Feel free to use any tools available to look up " + "relevant information, only if necessary" + ) + ) + + +def create_conversational_retrieval_agent( + llm: BaseLanguageModel, + tools: List[BaseTool], + remember_intermediate_steps: bool = True, + memory_key: str = "chat_history", + system_message: Optional[SystemMessage] = None, + verbose: bool = False, + max_token_limit: int = 2000, + **kwargs: Any, +) -> AgentExecutor: + """A convenience method for creating a conversational retrieval agent. + + Args: + llm: The language model to use, should be ChatOpenAI + tools: A list of tools the agent has access to + remember_intermediate_steps: Whether the agent should remember intermediate + steps or not. Intermediate steps refer to prior action/observation + pairs from previous questions. 
The benefit of remembering these is if + there is relevant information in there, the agent can use it to answer + follow up questions. The downside is it will take up more tokens. + memory_key: The name of the memory key in the prompt. + system_message: The system message to use. By default, a basic one will + be used. + verbose: Whether or not the final AgentExecutor should be verbose or not, + defaults to False. + max_token_limit: The max number of tokens to keep around in memory. + Defaults to 2000. + + Returns: + An agent executor initialized appropriately + """ + from langchain.agents.agent import AgentExecutor + from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( + AgentTokenBufferMemory, + ) + from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent + from langchain.memory.token_buffer import ConversationTokenBufferMemory + + if remember_intermediate_steps: + memory: BaseMemory = AgentTokenBufferMemory( + memory_key=memory_key, llm=llm, max_token_limit=max_token_limit + ) + else: + memory = ConversationTokenBufferMemory( + memory_key=memory_key, + return_messages=True, + output_key="output", + llm=llm, + max_token_limit=max_token_limit, + ) + + _system_message = system_message or _get_default_system_message() + prompt = OpenAIFunctionsAgent.create_prompt( + system_message=_system_message, + extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)], + ) + agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) + return AgentExecutor( + agent=agent, + tools=tools, + memory=memory, + verbose=verbose, + return_intermediate_steps=remember_intermediate_steps, + **kwargs, + ) diff --git a/libs/community/langchain_community/agent_toolkits/csv/__init__.py b/libs/community/langchain_community/agent_toolkits/csv/__init__.py new file mode 100644 index 00000000000..4b049802888 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/csv/__init__.py @@ -0,0 +1,26 @@ +from pathlib import Path +from typing import Any + +from langchain_core._api.path import as_import_path + + +def __getattr__(name: str) -> Any: + """Get attr name.""" + + if name == "create_csv_agent": + # Get directory of langchain package + HERE = Path(__file__).parents[3] + here = as_import_path(Path(__file__).parent, relative_to=HERE) + + old_path = "langchain." + here + "." + name + new_path = "langchain_experimental." + here + "." + name + raise ImportError( + "This agent has been moved to langchain experiment. " + "This agent relies on python REPL tool under the hood, so to use it " + "safely please sandbox the python REPL. " + "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md " + "and https://github.com/langchain-ai/langchain/discussions/11680" + "To keep using this code as is, install langchain experimental and " + f"update your import statement from:\n `{old_path}` to `{new_path}`." 
+ ) + raise AttributeError(f"{name} does not exist") diff --git a/libs/community/langchain_community/agent_toolkits/file_management/__init__.py b/libs/community/langchain_community/agent_toolkits/file_management/__init__.py new file mode 100644 index 00000000000..53ce9329f91 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/file_management/__init__.py @@ -0,0 +1,7 @@ +"""Local file management toolkit.""" + +from langchain_community.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) + +__all__ = ["FileManagementToolkit"] diff --git a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py new file mode 100644 index 00000000000..538569755b5 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import List, Optional + +from langchain_core.pydantic_v1 import root_validator + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.file_management.copy import CopyFileTool +from langchain_community.tools.file_management.delete import DeleteFileTool +from langchain_community.tools.file_management.file_search import FileSearchTool +from langchain_community.tools.file_management.list_dir import ListDirectoryTool +from langchain_community.tools.file_management.move import MoveFileTool +from langchain_community.tools.file_management.read import ReadFileTool +from langchain_community.tools.file_management.write import WriteFileTool + +_FILE_TOOLS = { + # "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined] + tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined] + for tool_cls in [ + CopyFileTool, + DeleteFileTool, + FileSearchTool, + MoveFileTool, + ReadFileTool, + WriteFileTool, + ListDirectoryTool, + ] +} + + +class FileManagementToolkit(BaseToolkit): + """Toolkit for interacting with local files. + + *Security Notice*: This toolkit provides methods to interact with local files. + If providing this toolkit to an agent on an LLM, ensure you scope + the agent's permissions to only include the necessary permissions + to perform the desired operations. + + By **default** the agent will have access to all files within + the root dir and will be able to Copy, Delete, Move, Read, Write + and List files in that directory. + + Consider the following: + - Limit access to particular directories using `root_dir`. + - Use filesystem permissions to restrict access and permissions to only + the files and directories required by the agent. + - Limit the tools available to the agent to only the file operations + necessary for the agent's intended use. + - Sandbox the agent by running it in a container. + + See https://python.langchain.com/docs/security for more information. + """ + + root_dir: Optional[str] = None + """If specified, all file operations are made relative to root_dir.""" + selected_tools: Optional[List[str]] = None + """If provided, only provide the selected tools. Defaults to all.""" + + @root_validator + def validate_tools(cls, values: dict) -> dict: + selected_tools = values.get("selected_tools") or [] + for tool_name in selected_tools: + if tool_name not in _FILE_TOOLS: + raise ValueError( + f"File Tool of name {tool_name} not supported." 
+ f" Permitted tools: {list(_FILE_TOOLS)}" + ) + return values + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + allowed_tools = self.selected_tools or _FILE_TOOLS.keys() + tools: List[BaseTool] = [] + for tool in allowed_tools: + tool_cls = _FILE_TOOLS[tool] + tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore + return tools + + +__all__ = ["FileManagementToolkit"] diff --git a/libs/community/langchain_community/agent_toolkits/github/__init__.py b/libs/community/langchain_community/agent_toolkits/github/__init__.py new file mode 100644 index 00000000000..bcd9368a52a --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/github/__init__.py @@ -0,0 +1 @@ +"""GitHub Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/github/toolkit.py b/libs/community/langchain_community/agent_toolkits/github/toolkit.py new file mode 100644 index 00000000000..7c85504e1c6 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/github/toolkit.py @@ -0,0 +1,287 @@ +"""GitHub Toolkit.""" +from typing import Dict, List + +from langchain_core.pydantic_v1 import BaseModel, Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.github.prompt import ( + COMMENT_ON_ISSUE_PROMPT, + CREATE_BRANCH_PROMPT, + CREATE_FILE_PROMPT, + CREATE_PULL_REQUEST_PROMPT, + CREATE_REVIEW_REQUEST_PROMPT, + DELETE_FILE_PROMPT, + GET_FILES_FROM_DIRECTORY_PROMPT, + GET_ISSUE_PROMPT, + GET_ISSUES_PROMPT, + GET_PR_PROMPT, + LIST_BRANCHES_IN_REPO_PROMPT, + LIST_PRS_PROMPT, + LIST_PULL_REQUEST_FILES, + OVERVIEW_EXISTING_FILES_BOT_BRANCH, + OVERVIEW_EXISTING_FILES_IN_MAIN, + READ_FILE_PROMPT, + SEARCH_CODE_PROMPT, + SEARCH_ISSUES_AND_PRS_PROMPT, + SET_ACTIVE_BRANCH_PROMPT, + UPDATE_FILE_PROMPT, +) +from langchain_community.tools.github.tool import GitHubAction +from langchain_community.utilities.github import GitHubAPIWrapper + + +class NoInput(BaseModel): + no_input: str = Field("", description="No input required, e.g. `` (empty string).") + + +class GetIssue(BaseModel): + issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`") + + +class CommentOnIssue(BaseModel): + input: str = Field(..., description="Follow the required formatting.") + + +class GetPR(BaseModel): + pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`") + + +class CreatePR(BaseModel): + formatted_pr: str = Field(..., description="Follow the required formatting.") + + +class CreateFile(BaseModel): + formatted_file: str = Field(..., description="Follow the required formatting.") + + +class ReadFile(BaseModel): + formatted_filepath: str = Field( + ..., + description=( + "The full file path of the file you would like to read where the " + "path must NOT start with a slash, e.g. `some_dir/my_file.py`." + ), + ) + + +class UpdateFile(BaseModel): + formatted_file_update: str = Field( + ..., description="Strictly follow the provided rules." + ) + + +class DeleteFile(BaseModel): + formatted_filepath: str = Field( + ..., + description=( + "The full file path of the file you would like to delete" + " where the path must NOT start with a slash, e.g." + " `some_dir/my_file.py`. Only input a string," + " not the param name." + ), + ) + + +class DirectoryPath(BaseModel): + input: str = Field( + "", + description=( + "The path of the directory, e.g. `some_dir/inner_dir`." + " Only input a string, do not include the parameter name." 
+ ), + ) + + +class BranchName(BaseModel): + branch_name: str = Field( + ..., description="The name of the branch, e.g. `my_branch`." + ) + + +class SearchCode(BaseModel): + search_query: str = Field( + ..., + description=( + "A keyword-focused natural language search" + "query for code, e.g. `MyFunctionName()`." + ), + ) + + +class CreateReviewRequest(BaseModel): + username: str = Field( + ..., + description="GitHub username of the user being requested, e.g. `my_username`.", + ) + + +class SearchIssuesAndPRs(BaseModel): + search_query: str = Field( + ..., + description="Natural language search query, e.g. `My issue title or topic`.", + ) + + +class GitHubToolkit(BaseToolkit): + """GitHub Toolkit. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by creating, deleting, or updating, + reading underlying data. + + For example, this toolkit can be used to create issues, pull requests, + and comments on GitHub. + + See [Security](https://python.langchain.com/docs/security) for more information. + """ + + tools: List[BaseTool] = [] + + @classmethod + def from_github_api_wrapper( + cls, github_api_wrapper: GitHubAPIWrapper + ) -> "GitHubToolkit": + operations: List[Dict] = [ + { + "mode": "get_issues", + "name": "Get Issues", + "description": GET_ISSUES_PROMPT, + "args_schema": NoInput, + }, + { + "mode": "get_issue", + "name": "Get Issue", + "description": GET_ISSUE_PROMPT, + "args_schema": GetIssue, + }, + { + "mode": "comment_on_issue", + "name": "Comment on Issue", + "description": COMMENT_ON_ISSUE_PROMPT, + "args_schema": CommentOnIssue, + }, + { + "mode": "list_open_pull_requests", + "name": "List open pull requests (PRs)", + "description": LIST_PRS_PROMPT, + "args_schema": NoInput, + }, + { + "mode": "get_pull_request", + "name": "Get Pull Request", + "description": GET_PR_PROMPT, + "args_schema": GetPR, + }, + { + "mode": "list_pull_request_files", + "name": "Overview of files included in PR", + "description": LIST_PULL_REQUEST_FILES, + "args_schema": GetPR, + }, + { + "mode": "create_pull_request", + "name": "Create Pull Request", + "description": CREATE_PULL_REQUEST_PROMPT, + "args_schema": CreatePR, + }, + { + "mode": "list_pull_request_files", + "name": "List Pull Requests' Files", + "description": LIST_PULL_REQUEST_FILES, + "args_schema": GetPR, + }, + { + "mode": "create_file", + "name": "Create File", + "description": CREATE_FILE_PROMPT, + "args_schema": CreateFile, + }, + { + "mode": "read_file", + "name": "Read File", + "description": READ_FILE_PROMPT, + "args_schema": ReadFile, + }, + { + "mode": "update_file", + "name": "Update File", + "description": UPDATE_FILE_PROMPT, + "args_schema": UpdateFile, + }, + { + "mode": "delete_file", + "name": "Delete File", + "description": DELETE_FILE_PROMPT, + "args_schema": DeleteFile, + }, + { + "mode": "list_files_in_main_branch", + "name": "Overview of existing files in Main branch", + "description": OVERVIEW_EXISTING_FILES_IN_MAIN, + "args_schema": NoInput, + }, + { + "mode": "list_files_in_bot_branch", + "name": "Overview of files in current working branch", + "description": OVERVIEW_EXISTING_FILES_BOT_BRANCH, + "args_schema": NoInput, + }, + { + "mode": "list_branches_in_repo", + "name": "List branches in this repository", + "description": LIST_BRANCHES_IN_REPO_PROMPT, + "args_schema": NoInput, + }, + { + "mode": "set_active_branch", + "name": "Set active branch", + "description": SET_ACTIVE_BRANCH_PROMPT, + "args_schema": BranchName, + }, + { + "mode": "create_branch", + "name": 
"Create a new branch", + "description": CREATE_BRANCH_PROMPT, + "args_schema": BranchName, + }, + { + "mode": "get_files_from_directory", + "name": "Get files from a directory", + "description": GET_FILES_FROM_DIRECTORY_PROMPT, + "args_schema": DirectoryPath, + }, + { + "mode": "search_issues_and_prs", + "name": "Search issues and pull requests", + "description": SEARCH_ISSUES_AND_PRS_PROMPT, + "args_schema": SearchIssuesAndPRs, + }, + { + "mode": "search_code", + "name": "Search code", + "description": SEARCH_CODE_PROMPT, + "args_schema": SearchCode, + }, + { + "mode": "create_review_request", + "name": "Create review request", + "description": CREATE_REVIEW_REQUEST_PROMPT, + "args_schema": CreateReviewRequest, + }, + ] + tools = [ + GitHubAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=github_api_wrapper, + args_schema=action.get("args_schema", None), + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py b/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py new file mode 100644 index 00000000000..7d3ca720636 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/gitlab/__init__.py @@ -0,0 +1 @@ +"""GitLab Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py b/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py new file mode 100644 index 00000000000..1a56b821f5b --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/gitlab/toolkit.py @@ -0,0 +1,94 @@ +"""GitHub Toolkit.""" +from typing import Dict, List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.gitlab.prompt import ( + COMMENT_ON_ISSUE_PROMPT, + CREATE_FILE_PROMPT, + CREATE_PULL_REQUEST_PROMPT, + DELETE_FILE_PROMPT, + GET_ISSUE_PROMPT, + GET_ISSUES_PROMPT, + READ_FILE_PROMPT, + UPDATE_FILE_PROMPT, +) +from langchain_community.tools.gitlab.tool import GitLabAction +from langchain_community.utilities.gitlab import GitLabAPIWrapper + + +class GitLabToolkit(BaseToolkit): + """GitLab Toolkit. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by creating, deleting, or updating, + reading underlying data. + + For example, this toolkit can be used to create issues, pull requests, + and comments on GitLab. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + tools: List[BaseTool] = [] + + @classmethod + def from_gitlab_api_wrapper( + cls, gitlab_api_wrapper: GitLabAPIWrapper + ) -> "GitLabToolkit": + operations: List[Dict] = [ + { + "mode": "get_issues", + "name": "Get Issues", + "description": GET_ISSUES_PROMPT, + }, + { + "mode": "get_issue", + "name": "Get Issue", + "description": GET_ISSUE_PROMPT, + }, + { + "mode": "comment_on_issue", + "name": "Comment on Issue", + "description": COMMENT_ON_ISSUE_PROMPT, + }, + { + "mode": "create_pull_request", + "name": "Create Pull Request", + "description": CREATE_PULL_REQUEST_PROMPT, + }, + { + "mode": "create_file", + "name": "Create File", + "description": CREATE_FILE_PROMPT, + }, + { + "mode": "read_file", + "name": "Read File", + "description": READ_FILE_PROMPT, + }, + { + "mode": "update_file", + "name": "Update File", + "description": UPDATE_FILE_PROMPT, + }, + { + "mode": "delete_file", + "name": "Delete File", + "description": DELETE_FILE_PROMPT, + }, + ] + tools = [ + GitLabAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=gitlab_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/gmail/__init__.py b/libs/community/langchain_community/agent_toolkits/gmail/__init__.py new file mode 100644 index 00000000000..02e7f81659f --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/gmail/__init__.py @@ -0,0 +1 @@ +"""Gmail toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py b/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py new file mode 100644 index 00000000000..1e4af3fd614 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/gmail/toolkit.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.gmail.create_draft import GmailCreateDraft +from langchain_community.tools.gmail.get_message import GmailGetMessage +from langchain_community.tools.gmail.get_thread import GmailGetThread +from langchain_community.tools.gmail.search import GmailSearch +from langchain_community.tools.gmail.send_message import GmailSendMessage +from langchain_community.tools.gmail.utils import build_resource_service + +if TYPE_CHECKING: + # This is for linting and IDE typehints + from googleapiclient.discovery import Resource +else: + try: + # We do this so pydantic can resolve the types when instantiating + from googleapiclient.discovery import Resource + except ImportError: + pass + + +SCOPES = ["https://mail.google.com/"] + + +class GmailToolkit(BaseToolkit): + """Toolkit for interacting with Gmail. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by reading, creating, updating, deleting + data associated with this service. + + For example, this toolkit can be used to send emails on behalf of the + associated account. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + api_resource: Resource = Field(default_factory=build_resource_service) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + GmailCreateDraft(api_resource=self.api_resource), + GmailSendMessage(api_resource=self.api_resource), + GmailSearch(api_resource=self.api_resource), + GmailGetMessage(api_resource=self.api_resource), + GmailGetThread(api_resource=self.api_resource), + ] diff --git a/libs/community/langchain_community/agent_toolkits/jira/__init__.py b/libs/community/langchain_community/agent_toolkits/jira/__init__.py new file mode 100644 index 00000000000..9f7c67558fa --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/jira/__init__.py @@ -0,0 +1 @@ +"""Jira Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/jira/toolkit.py b/libs/community/langchain_community/agent_toolkits/jira/toolkit.py new file mode 100644 index 00000000000..425c8d4c065 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/jira/toolkit.py @@ -0,0 +1,70 @@ +from typing import Dict, List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.jira.prompt import ( + JIRA_CATCH_ALL_PROMPT, + JIRA_CONFLUENCE_PAGE_CREATE_PROMPT, + JIRA_GET_ALL_PROJECTS_PROMPT, + JIRA_ISSUE_CREATE_PROMPT, + JIRA_JQL_PROMPT, +) +from langchain_community.tools.jira.tool import JiraAction +from langchain_community.utilities.jira import JiraAPIWrapper + + +class JiraToolkit(BaseToolkit): + """Jira Toolkit. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by creating, deleting, or updating, + reading underlying data. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + tools: List[BaseTool] = [] + + @classmethod + def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit": + operations: List[Dict] = [ + { + "mode": "jql", + "name": "JQL Query", + "description": JIRA_JQL_PROMPT, + }, + { + "mode": "get_projects", + "name": "Get Projects", + "description": JIRA_GET_ALL_PROJECTS_PROMPT, + }, + { + "mode": "create_issue", + "name": "Create Issue", + "description": JIRA_ISSUE_CREATE_PROMPT, + }, + { + "mode": "other", + "name": "Catch all Jira API call", + "description": JIRA_CATCH_ALL_PROMPT, + }, + { + "mode": "create_page", + "name": "Create confluence page", + "description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT, + }, + ] + tools = [ + JiraAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=jira_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/json/__init__.py b/libs/community/langchain_community/agent_toolkits/json/__init__.py new file mode 100644 index 00000000000..bfab0ec6f83 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/json/__init__.py @@ -0,0 +1 @@ +"""Json agent.""" diff --git a/libs/community/langchain_community/agent_toolkits/json/base.py b/libs/community/langchain_community/agent_toolkits/json/base.py new file mode 100644 index 00000000000..06830d90228 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/json/base.py @@ -0,0 +1,59 @@ +"""Json agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_json_agent( + llm: BaseLanguageModel, + toolkit: JsonToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = JSON_PREFIX, + suffix: str = JSON_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a json agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = toolkit.get_tools() + prompt_params = ( + {"format_instructions": format_instructions} + if format_instructions is not None + else {} + ) + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/json/prompt.py b/libs/community/langchain_community/agent_toolkits/json/prompt.py new file mode 100644 index 
00000000000..a3b7584aca2 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/json/prompt.py @@ -0,0 +1,25 @@ +# flake8: noqa + +JSON_PREFIX = """You are an agent designed to interact with JSON. +Your goal is to return a final answer by interacting with the JSON. +You have access to the following tools which help you learn more about the JSON you are interacting with. +Only use the below tools. Only use the information returned by the below tools to construct your final answer. +Do not make up any information that is not contained in the JSON. +Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python. +You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`. +If you have not seen a key in one of those responses, you cannot use it. +You should only add one key at a time to the path. You cannot add multiple keys at once. +If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again. + +If the question does not seem to be related to the JSON, just return "I don't know" as the answer. +Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON. + +Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly". +In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path. +Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it. +""" +JSON_SUFFIX = """Begin!" 
+ +Question: {input} +Thought: I should look at the keys that exist in data to see what I have access to +{agent_scratchpad}""" diff --git a/libs/community/langchain_community/agent_toolkits/json/toolkit.py b/libs/community/langchain_community/agent_toolkits/json/toolkit.py new file mode 100644 index 00000000000..8a4aa00a6ea --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/json/toolkit.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.json.tool import ( + JsonGetValueTool, + JsonListKeysTool, + JsonSpec, +) + + +class JsonToolkit(BaseToolkit): + """Toolkit for interacting with a JSON spec.""" + + spec: JsonSpec + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + JsonListKeysTool(spec=self.spec), + JsonGetValueTool(spec=self.spec), + ] diff --git a/libs/community/langchain_community/agent_toolkits/multion/__init__.py b/libs/community/langchain_community/agent_toolkits/multion/__init__.py new file mode 100644 index 00000000000..56c7215199b --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/multion/__init__.py @@ -0,0 +1 @@ +"""MultiOn Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/multion/toolkit.py b/libs/community/langchain_community/agent_toolkits/multion/toolkit.py new file mode 100644 index 00000000000..1d8f9029112 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/multion/toolkit.py @@ -0,0 +1,33 @@ +"""MultiOn agent.""" +from __future__ import annotations + +from typing import List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.multion.close_session import MultionCloseSession +from langchain_community.tools.multion.create_session import MultionCreateSession +from langchain_community.tools.multion.update_session import MultionUpdateSession + + +class MultionToolkit(BaseToolkit): + """Toolkit for interacting with the Browser Agent. + + **Security Note**: This toolkit contains tools that interact with the + user's browser via the multion API which grants an agent + access to the user's browser. + + Please review the documentation for the multion API to understand + the security implications of using this toolkit. + + See https://python.langchain.com/docs/security for more information. 
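A sketch of wiring the JSON toolkit and `create_json_agent` above together; the spec file name and the OpenAI LLM are illustrative placeholders (any `BaseLanguageModel` works):

```python
import yaml

from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec

# "openapi.yml" is a placeholder for any large JSON/YAML document.
with open("openapi.yml") as f:
    data = yaml.safe_load(f)

spec = JsonSpec(dict_=data, max_value_length=4000)
toolkit = JsonToolkit(spec=spec)

agent_executor = create_json_agent(
    llm=OpenAI(temperature=0), toolkit=toolkit, verbose=True
)
agent_executor.run(
    "What are the required parameters in the request body for the /completions endpoint?"
)
```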
+ """ + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()] diff --git a/libs/community/langchain_community/agent_toolkits/nasa/__init__.py b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py new file mode 100644 index 00000000000..a13c3ec706c --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py @@ -0,0 +1 @@ +"""NASA Toolkit""" diff --git a/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py new file mode 100644 index 00000000000..46edd98af3f --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py @@ -0,0 +1,57 @@ +from typing import Dict, List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.nasa.prompt import ( + NASA_CAPTIONS_PROMPT, + NASA_MANIFEST_PROMPT, + NASA_METADATA_PROMPT, + NASA_SEARCH_PROMPT, +) +from langchain_community.tools.nasa.tool import NasaAction +from langchain_community.utilities.nasa import NasaAPIWrapper + + +class NasaToolkit(BaseToolkit): + """Nasa Toolkit.""" + + tools: List[BaseTool] = [] + + @classmethod + def from_nasa_api_wrapper(cls, nasa_api_wrapper: NasaAPIWrapper) -> "NasaToolkit": + operations: List[Dict] = [ + { + "mode": "search_media", + "name": "Search NASA Image and Video Library media", + "description": NASA_SEARCH_PROMPT, + }, + { + "mode": "get_media_metadata_manifest", + "name": "Get NASA Image and Video Library media metadata manifest", + "description": NASA_MANIFEST_PROMPT, + }, + { + "mode": "get_media_metadata_location", + "name": "Get NASA Image and Video Library media metadata location", + "description": NASA_METADATA_PROMPT, + }, + { + "mode": "get_video_captions_location", + "name": "Get NASA Image and Video Library video captions location", + "description": NASA_CAPTIONS_PROMPT, + }, + ] + tools = [ + NasaAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=nasa_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/langchain/tests/integration_tests/document_loaders/parsers/__init__.py b/libs/community/langchain_community/agent_toolkits/nla/__init__.py similarity index 100% rename from libs/langchain/tests/integration_tests/document_loaders/parsers/__init__.py rename to libs/community/langchain_community/agent_toolkits/nla/__init__.py diff --git a/libs/community/langchain_community/agent_toolkits/nla/tool.py b/libs/community/langchain_community/agent_toolkits/nla/tool.py new file mode 100644 index 00000000000..b947c7df977 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/nla/tool.py @@ -0,0 +1,58 @@ +"""Tool for interacting with a single API with natural language definition.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools import Tool + +from langchain_community.tools.openapi.utils.api_models import APIOperation +from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec +from langchain_community.utilities.requests import Requests + +if 
TYPE_CHECKING: + from langchain.chains.api.openapi.chain import OpenAPIEndpointChain + + +class NLATool(Tool): + """Natural Language API Tool.""" + + @classmethod + def from_open_api_endpoint_chain( + cls, chain: OpenAPIEndpointChain, api_title: str + ) -> "NLATool": + """Convert an endpoint chain to an API endpoint tool.""" + expanded_name = ( + f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}' + ) + description = ( + f"I'm an AI from {api_title}. Instruct what you want," + " and I'll assist via an API with description:" + f" {chain.api_operation.description}" + ) + return cls(name=expanded_name, func=chain.run, description=description) + + @classmethod + def from_llm_and_method( + cls, + llm: BaseLanguageModel, + path: str, + method: str, + spec: OpenAPISpec, + requests: Optional[Requests] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + **kwargs: Any, + ) -> "NLATool": + """Instantiate the tool from the specified path and method.""" + api_operation = APIOperation.from_openapi_spec(spec, path, method) + chain = OpenAPIEndpointChain.from_api_operation( + api_operation, + llm, + requests=requests, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + **kwargs, + ) + return cls.from_open_api_endpoint_chain(chain, spec.info.title) diff --git a/libs/community/langchain_community/agent_toolkits/nla/toolkit.py b/libs/community/langchain_community/agent_toolkits/nla/toolkit.py new file mode 100644 index 00000000000..ae02aeb53bc --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/nla/toolkit.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Sequence + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.agent_toolkits.nla.tool import NLATool +from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec +from langchain_community.tools.plugin import AIPlugin +from langchain_community.utilities.requests import Requests + + +class NLAToolkit(BaseToolkit): + """Natural Language API Toolkit. + + *Security Note*: This toolkit creates tools that enable making calls + to an Open API compliant API. + + The tools created by this toolkit may be able to make GET, POST, + PATCH, PUT, DELETE requests to any of the exposed endpoints on + the API. + + Control access to who can use this toolkit. + + See https://python.langchain.com/docs/security for more information. + """ + + nla_tools: Sequence[NLATool] = Field(...) 
+ """List of API Endpoint Tools.""" + + def get_tools(self) -> List[BaseTool]: + """Get the tools for all the API operations.""" + return list(self.nla_tools) + + @staticmethod + def _get_http_operation_tools( + llm: BaseLanguageModel, + spec: OpenAPISpec, + requests: Optional[Requests] = None, + verbose: bool = False, + **kwargs: Any, + ) -> List[NLATool]: + """Get the tools for all the API operations.""" + if not spec.paths: + return [] + http_operation_tools = [] + for path in spec.paths: + for method in spec.get_methods_for_path(path): + endpoint_tool = NLATool.from_llm_and_method( + llm=llm, + path=path, + method=method, + spec=spec, + requests=requests, + verbose=verbose, + **kwargs, + ) + http_operation_tools.append(endpoint_tool) + return http_operation_tools + + @classmethod + def from_llm_and_spec( + cls, + llm: BaseLanguageModel, + spec: OpenAPISpec, + requests: Optional[Requests] = None, + verbose: bool = False, + **kwargs: Any, + ) -> NLAToolkit: + """Instantiate the toolkit by creating tools for each operation.""" + http_operation_tools = cls._get_http_operation_tools( + llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs + ) + return cls(nla_tools=http_operation_tools) + + @classmethod + def from_llm_and_url( + cls, + llm: BaseLanguageModel, + open_api_url: str, + requests: Optional[Requests] = None, + verbose: bool = False, + **kwargs: Any, + ) -> NLAToolkit: + """Instantiate the toolkit from an OpenAPI Spec URL""" + spec = OpenAPISpec.from_url(open_api_url) + return cls.from_llm_and_spec( + llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs + ) + + @classmethod + def from_llm_and_ai_plugin( + cls, + llm: BaseLanguageModel, + ai_plugin: AIPlugin, + requests: Optional[Requests] = None, + verbose: bool = False, + **kwargs: Any, + ) -> NLAToolkit: + """Instantiate the toolkit from an OpenAPI Spec URL""" + spec = OpenAPISpec.from_url(ai_plugin.api.url) + # TODO: Merge optional Auth information with the `requests` argument + return cls.from_llm_and_spec( + llm=llm, + spec=spec, + requests=requests, + verbose=verbose, + **kwargs, + ) + + @classmethod + def from_llm_and_ai_plugin_url( + cls, + llm: BaseLanguageModel, + ai_plugin_url: str, + requests: Optional[Requests] = None, + verbose: bool = False, + **kwargs: Any, + ) -> NLAToolkit: + """Instantiate the toolkit from an OpenAPI Spec URL""" + plugin = AIPlugin.from_url(ai_plugin_url) + return cls.from_llm_and_ai_plugin( + llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs + ) diff --git a/libs/community/langchain_community/agent_toolkits/office365/__init__.py b/libs/community/langchain_community/agent_toolkits/office365/__init__.py new file mode 100644 index 00000000000..acd0a87f955 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/office365/__init__.py @@ -0,0 +1 @@ +"""Office365 toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/office365/toolkit.py b/libs/community/langchain_community/agent_toolkits/office365/toolkit.py new file mode 100644 index 00000000000..990264ce1db --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/office365/toolkit.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.office365.create_draft_message import ( + O365CreateDraftMessage, +) +from 
langchain_community.tools.office365.events_search import O365SearchEvents +from langchain_community.tools.office365.messages_search import O365SearchEmails +from langchain_community.tools.office365.send_event import O365SendEvent +from langchain_community.tools.office365.send_message import O365SendMessage +from langchain_community.tools.office365.utils import authenticate + +if TYPE_CHECKING: + from O365 import Account + + +class O365Toolkit(BaseToolkit): + """Toolkit for interacting with Office 365. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by reading, creating, updating, deleting + data associated with this service. + + For example, this toolkit can be used search through emails and events, + send messages and event invites, and create draft messages. + + Please make sure that the permissions given by this toolkit + are appropriate for your use case. + + See https://python.langchain.com/docs/security for more information. + """ + + account: Account = Field(default_factory=authenticate) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + O365SearchEvents(), + O365CreateDraftMessage(), + O365SearchEmails(), + O365SendEvent(), + O365SendMessage(), + ] diff --git a/libs/community/langchain_community/agent_toolkits/openapi/__init__.py b/libs/community/langchain_community/agent_toolkits/openapi/__init__.py new file mode 100644 index 00000000000..5d06e271cd0 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/__init__.py @@ -0,0 +1 @@ +"""OpenAPI spec agent.""" diff --git a/libs/community/langchain_community/agent_toolkits/openapi/base.py b/libs/community/langchain_community/agent_toolkits/openapi/base.py new file mode 100644 index 00000000000..63a513811ea --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/base.py @@ -0,0 +1,83 @@ +"""OpenAPI spec agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.openapi.prompt import ( + OPENAPI_PREFIX, + OPENAPI_SUFFIX, +) +from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_openapi_agent( + llm: BaseLanguageModel, + toolkit: OpenAPIToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = OPENAPI_PREFIX, + suffix: str = OPENAPI_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + return_intermediate_steps: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct an OpenAPI agent from an LLM and tools. + + *Security Note*: When creating an OpenAPI agent, check the permissions + and capabilities of the underlying toolkit. 
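Stepping back to the `NLAToolkit` added earlier in this diff, a usage sketch (the spec URL and the LLM are illustrative placeholders; `from_llm_and_url` fetches and parses the OpenAPI spec and builds one `NLATool` per operation):

```python
from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit
from langchain_community.llms import OpenAI

llm = OpenAI(temperature=0)

# Placeholder URL; any OpenAPI-compliant spec should work here.
toolkit = NLAToolkit.from_llm_and_url(llm, "https://example.com/openapi.yaml")
print([tool.name for tool in toolkit.get_tools()])
```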
+ + For example, if the default implementation of OpenAPIToolkit + uses the RequestsToolkit which contains tools to make arbitrary + network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE), + + Control access to who can submit issue requests using this toolkit and + what network access it has. + + See https://python.langchain.com/docs/security for more information. + """ + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = toolkit.get_tools() + prompt_params = ( + {"format_instructions": format_instructions} + if format_instructions is not None + else {} + ) + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py new file mode 100644 index 00000000000..e95e011e395 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -0,0 +1,374 @@ +"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach.""" +import json +import re +from functools import partial +from typing import Any, Callable, Dict, List, Optional + +import yaml +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool, Tool + +from langchain_community.agent_toolkits.openapi.planner_prompt import ( + API_CONTROLLER_PROMPT, + API_CONTROLLER_TOOL_DESCRIPTION, + API_CONTROLLER_TOOL_NAME, + API_ORCHESTRATOR_PROMPT, + API_PLANNER_PROMPT, + API_PLANNER_TOOL_DESCRIPTION, + API_PLANNER_TOOL_NAME, + PARSING_DELETE_PROMPT, + PARSING_GET_PROMPT, + PARSING_PATCH_PROMPT, + PARSING_POST_PROMPT, + PARSING_PUT_PROMPT, + REQUESTS_DELETE_TOOL_DESCRIPTION, + REQUESTS_GET_TOOL_DESCRIPTION, + REQUESTS_PATCH_TOOL_DESCRIPTION, + REQUESTS_POST_TOOL_DESCRIPTION, + REQUESTS_PUT_TOOL_DESCRIPTION, +) +from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec +from langchain_community.llms import OpenAI +from langchain_community.tools.requests.tool import BaseRequestsTool +from langchain_community.utilities.requests import RequestsWrapper + +# +# Requests tools with LLM-instructed extraction of truncated responses. +# +# Of course, truncating so bluntly may lose a lot of valuable +# information in the response. +# However, the goal for now is to have only a single inference step. 
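A sketch of driving `create_openapi_agent` from `agent_toolkits/openapi/base.py` above with an `OpenAPIToolkit` (added later in this diff); the spec file name and LLM are placeholders:

```python
import yaml

from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec
from langchain_community.utilities.requests import TextRequestsWrapper

with open("openapi.yml") as f:  # placeholder spec file
    json_spec = JsonSpec(dict_=yaml.safe_load(f), max_value_length=4000)

llm = OpenAI(temperature=0)
toolkit = OpenAPIToolkit.from_llm(llm, json_spec, TextRequestsWrapper())
agent_executor = create_openapi_agent(llm=llm, toolkit=toolkit, verbose=True)
agent_executor.run("What endpoints does this API expose?")
```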
+MAX_RESPONSE_LENGTH = 5000 +"""Maximum length of the response to be returned.""" + + +def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any: + from langchain.chains.llm import LLMChain + + return LLMChain( + llm=OpenAI(), + prompt=prompt, + ) + + +def _get_default_llm_chain_factory( + prompt: BasePromptTemplate, +) -> Callable[[], Any]: + """Returns a default LLMChain factory.""" + return partial(_get_default_llm_chain, prompt) + + +class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): + """Requests GET tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_get" + """Tool name.""" + description = REQUESTS_GET_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + data_params = data.get("params") + response = self.requests_wrapper.get(data["url"], params=data_params) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): + """Requests POST tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_post" + """Tool name.""" + description = REQUESTS_POST_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.post(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): + """Requests PATCH tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_patch" + """Tool name.""" + description = REQUESTS_PATCH_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.patch(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class 
RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): + """Requests PUT tool with LLM-instructed extraction of truncated responses.""" + + name: str = "requests_put" + """Tool name.""" + description = REQUESTS_PUT_TOOL_DESCRIPTION + """Tool description.""" + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """Maximum length of the response to be returned.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT) + ) + """LLMChain used to extract the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.put(data["url"], data["data"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): + """A tool that sends a DELETE request and parses the response.""" + + name: str = "requests_delete" + """The name of the tool.""" + description = REQUESTS_DELETE_TOOL_DESCRIPTION + """The description of the tool.""" + + response_length: Optional[int] = MAX_RESPONSE_LENGTH + """The maximum length of the response.""" + llm_chain: Any = Field( + default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT) + ) + """The LLM chain used to parse the response.""" + + def _run(self, text: str) -> str: + from langchain.output_parsers.json import parse_json_markdown + + try: + data = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise e + response = self.requests_wrapper.delete(data["url"]) + response = response[: self.response_length] + return self.llm_chain.predict( + response=response, instructions=data["output_instructions"] + ).strip() + + async def _arun(self, text: str) -> str: + raise NotImplementedError() + + +# +# Orchestrator, planner, controller. 
+# +def _create_api_planner_tool( + api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel +) -> Tool: + from langchain.chains.llm import LLMChain + + endpoint_descriptions = [ + f"{name} {description}" for name, description, _ in api_spec.endpoints + ] + prompt = PromptTemplate( + template=API_PLANNER_PROMPT, + input_variables=["query"], + partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)}, + ) + chain = LLMChain(llm=llm, prompt=prompt) + tool = Tool( + name=API_PLANNER_TOOL_NAME, + description=API_PLANNER_TOOL_DESCRIPTION, + func=chain.run, + ) + return tool + + +def _create_api_controller_agent( + api_url: str, + api_docs: str, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, +) -> Any: + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) + post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) + tools: List[BaseTool] = [ + RequestsGetToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=get_llm_chain + ), + RequestsPostToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=post_llm_chain + ), + ] + prompt = PromptTemplate( + template=API_CONTROLLER_PROMPT, + input_variables=["input", "agent_scratchpad"], + partial_variables={ + "api_url": api_url, + "api_docs": api_docs, + "tool_names": ", ".join([tool.name for tool in tools]), + "tool_descriptions": "\n".join( + [f"{tool.name}: {tool.description}" for tool in tools] + ), + }, + ) + agent = ZeroShotAgent( + llm_chain=LLMChain(llm=llm, prompt=prompt), + allowed_tools=[tool.name for tool in tools], + ) + return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True) + + +def _create_api_controller_tool( + api_spec: ReducedOpenAPISpec, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, +) -> Tool: + """Expose controller as a tool. + + The tool is invoked with a plan from the planner, and dynamically + creates a controller agent with relevant documentation only to + constrain the context. + """ + + base_url = api_spec.servers[0]["url"] # TODO: do better. + + def _create_and_run_api_controller_agent(plan_str: str) -> str: + pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*" + matches = re.findall(pattern, plan_str) + endpoint_names = [ + "{method} {route}".format(method=method, route=route.split("?")[0]) + for method, route in matches + ] + docs_str = "" + for endpoint_name in endpoint_names: + found_match = False + for name, _, docs in api_spec.endpoints: + regex_name = re.compile(re.sub("\{.*?\}", ".*", name)) + if regex_name.match(endpoint_name): + found_match = True + docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n" + if not found_match: + raise ValueError(f"{endpoint_name} endpoint does not exist.") + + agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm) + return agent.run(plan_str) + + return Tool( + name=API_CONTROLLER_TOOL_NAME, + func=_create_and_run_api_controller_agent, + description=API_CONTROLLER_TOOL_DESCRIPTION, + ) + + +def create_openapi_agent( + api_spec: ReducedOpenAPISpec, + requests_wrapper: RequestsWrapper, + llm: BaseLanguageModel, + shared_memory: Optional[Any] = None, + callback_manager: Optional[BaseCallbackManager] = None, + verbose: bool = True, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> Any: + """Instantiate OpenAI API planner and controller for a given spec. 
+ + Inject credentials via requests_wrapper. + + We use a top-level "orchestrator" agent to invoke the planner and controller, + rather than a top-level planner + that invokes a controller with its plan. This is to keep the planner simple. + """ + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = [ + _create_api_planner_tool(api_spec, llm), + _create_api_controller_tool(api_spec, requests_wrapper, llm), + ] + prompt = PromptTemplate( + template=API_ORCHESTRATOR_PROMPT, + input_variables=["input", "agent_scratchpad"], + partial_variables={ + "tool_names": ", ".join([tool.name for tool in tools]), + "tool_descriptions": "\n".join( + [f"{tool.name}: {tool.description}" for tool in tools] + ), + }, + ) + agent = ZeroShotAgent( + llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory), + allowed_tools=[tool.name for tool in tools], + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py b/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py new file mode 100644 index 00000000000..ec99e823e06 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner_prompt.py @@ -0,0 +1,235 @@ +# flake8: noqa + +from langchain_core.prompts.prompt import PromptTemplate + + +API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API. + +You should: +1) evaluate whether the user query can be solved by the API documentated below. If no, say why. +2) if yes, generate a plan of API calls and say what they are doing step by step. +3) If the plan includes a DELETE call, you should always return an ask from the User for authorization first unless the User has specifically asked to delete something. + +You should only use API endpoints documented below ("Endpoints you can use:"). +You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should return a request authorization from the User first. +Some user queries can be resolved in a single API call, but some will require several API calls. +The plan will be passed to an API controller that can format it into web requests and return the responses. + +---- + +Here are some examples: + +Fake endpoints for examples: +GET /user to get information about the current user +GET /products/search search across products +POST /users/{{id}}/cart to add products to a user's cart +PATCH /users/{{id}}/cart to update a user's cart +PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart +DELETE /users/{{id}}/cart to delete a user's cart + +User query: tell me a joke +Plan: Sorry, this API's domain is shopping, not comedy. + +User query: I want to buy a couch +Plan: 1. GET /products with a query param to search for couches +2. GET /user to find the user's id +3. POST /users/{{id}}/cart to add a couch to the user's cart + +User query: I want to add a lamp to my cart +Plan: 1. GET /products with a query param to search for lamps +2. GET /user to find the user's id +3. PATCH /users/{{id}}/cart to add a lamp to the user's cart + +User query: I want to add a coupon to my cart +Plan: 1. GET /user to find the user's id +2. 
PUT /users/{{id}}/coupon to apply the coupon + +User query: I want to delete my cart +Plan: 1. GET /user to find the user's id +2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed. +3. DELETE /users/{{id}}/cart to delete the user's cart + +User query: I want to start a new cart +Plan: 1. GET /user to find the user's id +2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization. +3. Are you sure you want to delete your cart? +---- + +Here are endpoints you can use. Do not reference any of the endpoints above. + +{endpoints} + +---- + +User query: {query} +Plan:""" +API_PLANNER_TOOL_NAME = "api_planner" +API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller." + +# Execution. +API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response. +If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User. + + +Here is documentation on the API: +Base url: {api_url} +Endpoints: +{api_docs} + + +Here are tools to execute requests against the API: {tool_descriptions} + + +Starting below, you should follow this format: + +Plan: the plan of API calls to execute +Thought: you should always think about what to do +Action: the action to take, should be one of the tools [{tool_names}] +Action Input: the input to the action +Observation: the output of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.) +Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly. + + +Begin! + +Plan: {input} +Thought: +{agent_scratchpad} +""" +API_CONTROLLER_TOOL_NAME = "api_controller" +API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)." + +# Orchestrate planning + execution. +# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while +# keeping planning (and specifically the planning prompt) simple. +API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources. +Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls. +You should always plan your API calls first, and then execute the plan second. +If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something. +You should never return information without executing the api_controller tool. 
+ + +Here are the tools to plan and execute API requests: {tool_descriptions} + + +Starting below, you should follow this format: + +User query: the query a User wants help with related to the API +Thought: you should always think about what to do +Action: the action to take, should be one of the tools [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create +Final Answer: the final output from executing the plan + + +Example: +User query: can you add some trendy stuff to my shopping cart. +Thought: I should plan API calls first. +Action: api_planner +Action Input: I need to find the right API calls to add trendy items to the users shopping cart +Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids +2) GET /user to get user +3) POST /cart to post the trending items to the user's cart +Thought: I'm ready to execute the API calls. +Action: api_controller +Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids +2) GET /user to get user +3) POST /cart to post the trending items to the user's cart +... + +Begin! + +User query: {input} +Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller. +{agent_scratchpad}""" + +REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website. +Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions". +The value of "url" should be a string. +The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint. +If parameters are not needed, or not available, leave it empty. +The value of "output_instructions" should be instructions on what information to extract from the response, +for example the id(s) for a resource(s) that the GET request fetches. +""" + +PARSING_GET_PROMPT = PromptTemplate( + template="""Here is an API response:\n\n{response}\n\n==== +Your task is to extract some information according to these instructions: {instructions} +When working with API objects, you should usually use ids over names. +If the response indicates an error, you should instead output a summary of the error. + +Output:""", + input_variables=["response", "instructions"], +) + +REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website. +Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions". +The value of "url" should be a string. +The value of "data" should be a dictionary of key-value pairs you want to POST to the url. +The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates. +Always use double quotes for strings in the json string.""" + +PARSING_POST_PROMPT = PromptTemplate( + template="""Here is an API response:\n\n{response}\n\n==== +Your task is to extract some information according to these instructions: {instructions} +When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response. +If the response indicates an error, you should instead output a summary of the error. 
+ +Output:""", + input_variables=["response", "instructions"], +) + +REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website. +Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions". +The value of "url" should be a string. +The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url. +The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates. +Always use double quotes for strings in the json string.""" + +PARSING_PATCH_PROMPT = PromptTemplate( + template="""Here is an API response:\n\n{response}\n\n==== +Your task is to extract some information according to these instructions: {instructions} +When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response. +If the response indicates an error, you should instead output a summary of the error. + +Output:""", + input_variables=["response", "instructions"], +) + +REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website. +Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions". +The value of "url" should be a string. +The value of "data" should be a dictionary of key-value pairs you want to PUT to the url. +The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates. +Always use double quotes for strings in the json string.""" + +PARSING_PUT_PROMPT = PromptTemplate( + template="""Here is an API response:\n\n{response}\n\n==== +Your task is to extract some information according to these instructions: {instructions} +When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response. +If the response indicates an error, you should instead output a summary of the error. + +Output:""", + input_variables=["response", "instructions"], +) + +REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE. +Input to the tool should be a json string with 2 keys: "url", and "output_instructions". +The value of "url" should be a string. +The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates. +Always use double quotes for strings in the json string. +ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING.""" + +PARSING_DELETE_PROMPT = PromptTemplate( + template="""Here is an API response:\n\n{response}\n\n==== +Your task is to extract some information according to these instructions: {instructions} +When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response. +If the response indicates an error, you should instead output a summary of the error. 
+ +Output:""", + input_variables=["response", "instructions"], +) diff --git a/libs/community/langchain_community/agent_toolkits/openapi/prompt.py b/libs/community/langchain_community/agent_toolkits/openapi/prompt.py new file mode 100644 index 00000000000..0484f5bf5d9 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/prompt.py @@ -0,0 +1,29 @@ +# flake8: noqa + +OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec. + +If the question does not seem related to the API, return I don't know. Do not make up an answer. +Only use information provided by the tools to construct your response. + +First, find the base URL needed to make the request. + +Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question. + +Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters. + +Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed. + +Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters. +If you get a not found error, ensure that you are using a path that actually exists in the spec. +""" +OPENAPI_SUFFIX = """Begin! + +Question: {input} +Thought: I should explore the spec to find the base url for the API. +{agent_scratchpad}""" + +DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request. +Example inputs to this tool: + 'What are the required query parameters for a GET request to the /bar endpoint?` + 'What are the required parameters in the request body for a POST request to the /foo endpoint?' +Always give this tool a specific question.""" diff --git a/libs/community/langchain_community/agent_toolkits/openapi/spec.py b/libs/community/langchain_community/agent_toolkits/openapi/spec.py new file mode 100644 index 00000000000..29b529fc71e --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/spec.py @@ -0,0 +1,75 @@ +"""Quick and dirty representation for OpenAPI specs.""" + +from dataclasses import dataclass +from typing import List, Tuple + +from langchain_core.utils.json_schema import dereference_refs + + +@dataclass(frozen=True) +class ReducedOpenAPISpec: + """A reduced OpenAPI spec. + + This is a quick and dirty representation for OpenAPI specs. + + Attributes: + servers: The servers in the spec. + description: The description of the spec. + endpoints: The endpoints in the spec. + """ + + servers: List[dict] + description: str + endpoints: List[Tuple[str, str, dict]] + + +def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec: + """Simplify/distill/minify a spec somehow. + + I want a smaller target for retrieval and (more importantly) + I want smaller results from retrieval. + I was hoping https://openapi.tools/ would have some useful bits + to this end, but doesn't seem so. + """ + # 1. Consider only get, post, patch, put, delete endpoints. 
+ endpoints = [ + (f"{operation_name.upper()} {route}", docs.get("description"), docs) + for route, operation in spec["paths"].items() + for operation_name, docs in operation.items() + if operation_name in ["get", "post", "patch", "put", "delete"] + ] + + # 2. Replace any refs so that complete docs are retrieved. + # Note: probably want to do this post-retrieval, it blows up the size of the spec. + if dereference: + endpoints = [ + (name, description, dereference_refs(docs, full_schema=spec)) + for name, description, docs in endpoints + ] + + # 3. Strip docs down to required request args + happy path response. + def reduce_endpoint_docs(docs: dict) -> dict: + out = {} + if docs.get("description"): + out["description"] = docs.get("description") + if docs.get("parameters"): + out["parameters"] = [ + parameter + for parameter in docs.get("parameters", []) + if parameter.get("required") + ] + if "200" in docs["responses"]: + out["responses"] = docs["responses"]["200"] + if docs.get("requestBody"): + out["requestBody"] = docs.get("requestBody") + return out + + endpoints = [ + (name, description, reduce_endpoint_docs(docs)) + for name, description, docs in endpoints + ] + return ReducedOpenAPISpec( + servers=spec["servers"], + description=spec["info"].get("description", ""), + endpoints=endpoints, + ) diff --git a/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py new file mode 100644 index 00000000000..5b7f3fcbd52 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py @@ -0,0 +1,90 @@ +"""Requests toolkit.""" +from __future__ import annotations + +from typing import Any, List + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools import Tool + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.agent_toolkits.json.base import create_json_agent +from langchain_community.agent_toolkits.json.toolkit import JsonToolkit +from langchain_community.agent_toolkits.openapi.prompt import DESCRIPTION +from langchain_community.tools import BaseTool +from langchain_community.tools.json.tool import JsonSpec +from langchain_community.tools.requests.tool import ( + RequestsDeleteTool, + RequestsGetTool, + RequestsPatchTool, + RequestsPostTool, + RequestsPutTool, +) +from langchain_community.utilities.requests import TextRequestsWrapper + + +class RequestsToolkit(BaseToolkit): + """Toolkit for making REST requests. + + *Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT, + and DELETE requests to an API. + + Exercise care in who is allowed to use this toolkit. If exposing + to end users, consider that users will be able to make arbitrary + requests on behalf of the server hosting the code. For example, + users could ask the server to make a request to a private API + that is only accessible from the server. + + Control access to who can submit issue requests using this toolkit and + what network access it has. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + requests_wrapper: TextRequestsWrapper + + def get_tools(self) -> List[BaseTool]: + """Return a list of tools.""" + return [ + RequestsGetTool(requests_wrapper=self.requests_wrapper), + RequestsPostTool(requests_wrapper=self.requests_wrapper), + RequestsPatchTool(requests_wrapper=self.requests_wrapper), + RequestsPutTool(requests_wrapper=self.requests_wrapper), + RequestsDeleteTool(requests_wrapper=self.requests_wrapper), + ] + + +class OpenAPIToolkit(BaseToolkit): + """Toolkit for interacting with an OpenAPI API. + + *Security Note*: This toolkit contains tools that can read and modify + the state of a service; e.g., by creating, deleting, or updating, + reading underlying data. + + For example, this toolkit can be used to delete data exposed via + an OpenAPI compliant API. + """ + + json_agent: Any + requests_wrapper: TextRequestsWrapper + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + json_agent_tool = Tool( + name="json_explorer", + func=self.json_agent.run, + description=DESCRIPTION, + ) + request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper) + return [*request_toolkit.get_tools(), json_agent_tool] + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + json_spec: JsonSpec, + requests_wrapper: TextRequestsWrapper, + **kwargs: Any, + ) -> OpenAPIToolkit: + """Create json agent from llm, then initialize.""" + json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs) + return cls(json_agent=json_agent, requests_wrapper=requests_wrapper) diff --git a/libs/community/langchain_community/agent_toolkits/playwright/__init__.py b/libs/community/langchain_community/agent_toolkits/playwright/__init__.py new file mode 100644 index 00000000000..7fc7f6d9950 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/playwright/__init__.py @@ -0,0 +1,6 @@ +"""Playwright browser toolkit.""" +from langchain_community.agent_toolkits.playwright.toolkit import ( + PlayWrightBrowserToolkit, +) + +__all__ = ["PlayWrightBrowserToolkit"] diff --git a/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py new file mode 100644 index 00000000000..5a78e8c21a8 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py @@ -0,0 +1,110 @@ +"""Playwright web browser toolkit.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Type, cast + +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.tools import BaseTool + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools.playwright.base import ( + BaseBrowserTool, + lazy_import_playwright_browsers, +) +from langchain_community.tools.playwright.click import ClickTool +from langchain_community.tools.playwright.current_page import CurrentWebPageTool +from langchain_community.tools.playwright.extract_hyperlinks import ( + ExtractHyperlinksTool, +) +from langchain_community.tools.playwright.extract_text import ExtractTextTool +from langchain_community.tools.playwright.get_elements import GetElementsTool +from langchain_community.tools.playwright.navigate import NavigateTool +from langchain_community.tools.playwright.navigate_back import NavigateBackTool + +if TYPE_CHECKING: + from playwright.async_api import Browser as AsyncBrowser + from playwright.sync_api import Browser as SyncBrowser +else: + try: + # We do this so pydantic can resolve the 
types when instantiating
+        from playwright.async_api import Browser as AsyncBrowser
+        from playwright.sync_api import Browser as SyncBrowser
+    except ImportError:
+        pass
+
+
+class PlayWrightBrowserToolkit(BaseToolkit):
+    """Toolkit for PlayWright browser tools.
+
+    **Security Note**: This toolkit provides code to control a web-browser.
+
+        Be careful if exposing this toolkit to end-users. The tools in the toolkit
+        are capable of navigating to arbitrary webpages, clicking on arbitrary
+        elements, and extracting arbitrary text and hyperlinks from webpages.
+
+        Specifically, by default this toolkit allows navigating to:
+
+        - Any URL (including any internal network URLs)
+        - And local files
+
+        If exposing to end-users, consider limiting network access to the
+        server that hosts the agent; in addition, it is advised to create a
+        custom NavigationTool with an args_schema that limits the URLs that
+        can be navigated to (e.g., only allow navigating to URLs that
+        start with a particular prefix).
+
+        Remember to scope permissions to the minimal permissions necessary for
+        the application. If the default tool selection is not appropriate for
+        the application, consider creating a custom toolkit with the appropriate
+        tools.
+
+        See https://python.langchain.com/docs/security for more information.
+    """
+
+    sync_browser: Optional["SyncBrowser"] = None
+    async_browser: Optional["AsyncBrowser"] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_imports_and_browser_provided(cls, values: dict) -> dict:
+        """Check that the arguments are valid."""
+        lazy_import_playwright_browsers()
+        if values.get("async_browser") is None and values.get("sync_browser") is None:
+            raise ValueError("Either async_browser or sync_browser must be specified.")
+        return values
+
+    def get_tools(self) -> List[BaseTool]:
+        """Get the tools in the toolkit."""
+        tool_classes: List[Type[BaseBrowserTool]] = [
+            ClickTool,
+            NavigateTool,
+            NavigateBackTool,
+            ExtractTextTool,
+            ExtractHyperlinksTool,
+            GetElementsTool,
+            CurrentWebPageTool,
+        ]
+
+        tools = [
+            tool_cls.from_browser(
+                sync_browser=self.sync_browser, async_browser=self.async_browser
+            )
+            for tool_cls in tool_classes
+        ]
+        return cast(List[BaseTool], tools)
+
+    @classmethod
+    def from_browser(
+        cls,
+        sync_browser: Optional[SyncBrowser] = None,
+        async_browser: Optional[AsyncBrowser] = None,
+    ) -> PlayWrightBrowserToolkit:
+        """Instantiate the toolkit."""
+        # This is to raise a better error than the forward ref ones Pydantic would have
+        lazy_import_playwright_browsers()
+        return cls(sync_browser=sync_browser, async_browser=async_browser)
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py b/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py
new file mode 100644
index 00000000000..42a9b09ac7e
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/__init__.py
@@ -0,0 +1 @@
+"""Power BI agent."""
diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/base.py b/libs/community/langchain_community/agent_toolkits/powerbi/base.py
new file mode 100644
index 00000000000..06de2b97a7b
--- /dev/null
+++ b/libs/community/langchain_community/agent_toolkits/powerbi/base.py
@@ -0,0 +1,73 @@
+"""Power BI agent."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain_core.callbacks import BaseCallbackManager
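A minimal usage sketch for the `PlayWrightBrowserToolkit` defined above; this is illustrative only and assumes the optional `playwright` dependency is installed and its browsers are provisioned (`pip install playwright && playwright install`):

```python
from playwright.sync_api import sync_playwright

from langchain_community.agent_toolkits.playwright import PlayWrightBrowserToolkit

# Launch a synchronous Chromium browser; an async browser works the same way
# via the `async_browser` argument.
p = sync_playwright().start()
browser = p.chromium.launch(headless=True)

# Build the toolkit from the live browser and list the tools it exposes.
toolkit = PlayWrightBrowserToolkit.from_browser(sync_browser=browser)
for tool in toolkit.get_tools():
    print(tool.name)
```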
+from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.powerbi.prompt import ( + POWERBI_PREFIX, + POWERBI_SUFFIX, +) +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.agents import AgentExecutor + + +def create_pbi_agent( + llm: BaseLanguageModel, + toolkit: Optional[PowerBIToolkit] = None, + powerbi: Optional[PowerBIDataset] = None, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = POWERBI_PREFIX, + suffix: str = POWERBI_SUFFIX, + format_instructions: Optional[str] = None, + examples: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Power BI agent from an LLM and tools.""" + from langchain.agents import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + if toolkit is None: + if powerbi is None: + raise ValueError("Must provide either a toolkit or powerbi dataset") + toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples) + tools = toolkit.get_tools() + tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names + prompt_params = ( + {"format_instructions": format_instructions} + if format_instructions is not None + else {} + ) + agent = ZeroShotAgent( + llm_chain=LLMChain( + llm=llm, + prompt=ZeroShotAgent.create_prompt( + tools, + prefix=prefix.format(top_k=top_k).format(tables=tables), + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ), + callback_manager=callback_manager, # type: ignore + verbose=verbose, + ), + allowed_tools=[tool.name for tool in tools], + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py b/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py new file mode 100644 index 00000000000..acad61b4422 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/powerbi/chat_base.py @@ -0,0 +1,71 @@ +"""Power BI agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models.chat_models import BaseChatModel + +from langchain_community.agent_toolkits.powerbi.prompt import ( + POWERBI_CHAT_PREFIX, + POWERBI_CHAT_SUFFIX, +) +from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.agents import AgentExecutor + from langchain.agents.agent import AgentOutputParser + from langchain.memory.chat_memory import BaseChatMemory + + +def create_pbi_chat_agent( + llm: BaseChatModel, + toolkit: Optional[PowerBIToolkit] = None, + powerbi: Optional[PowerBIDataset] = None, + callback_manager: Optional[BaseCallbackManager] = None, + output_parser: Optional[AgentOutputParser] = None, + prefix: str = POWERBI_CHAT_PREFIX, + suffix: str = POWERBI_CHAT_SUFFIX, + examples: Optional[str] = None, + input_variables: Optional[List[str]] = None, + memory: Optional[BaseChatMemory] = None, + top_k: int = 10, + 
verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Power BI agent from a Chat LLM and tools. + + If you supply only a toolkit and no Power BI dataset, the same LLM is used for both. + """ + from langchain.agents import AgentExecutor + from langchain.agents.conversational_chat.base import ConversationalChatAgent + from langchain.memory import ConversationBufferMemory + + if toolkit is None: + if powerbi is None: + raise ValueError("Must provide either a toolkit or powerbi dataset") + toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples) + tools = toolkit.get_tools() + tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names + agent = ConversationalChatAgent.from_llm_and_tools( + llm=llm, + tools=tools, + system_message=prefix.format(top_k=top_k).format(tables=tables), + human_message=suffix, + input_variables=input_variables, + callback_manager=callback_manager, + output_parser=output_parser, + verbose=verbose, + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + memory=memory + or ConversationBufferMemory(memory_key="chat_history", return_messages=True), + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py b/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py new file mode 100644 index 00000000000..673a6bed29b --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/powerbi/prompt.py @@ -0,0 +1,38 @@ +# flake8: noqa +"""Prompts for PowerBI agent.""" + + +POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset. + +Agent has access to a tool that can write a query based on the question and then run those against PowerBI, Microsofts business intelligence tool. The questions from the users should be interpreted as related to the dataset that is available and not general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer. + +Given an input question, ask to run the questions against the dataset, then look at the results and return the answer, the answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. +""" + +POWERBI_SUFFIX = """Begin! + +Question: {input} +Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question. +{agent_scratchpad}""" + +POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset. + +Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values, if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. 
Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}. + +The answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000. +""" + +POWERBI_CHAT_SUFFIX = """TOOLS +------ +Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are: + +{{tools}} + +{format_instructions} + +USER'S INPUT +-------------------- +Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): + +{{{{input}}}} +""" diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py b/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py new file mode 100644 index 00000000000..1a89d77ad2a --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/powerbi/toolkit.py @@ -0,0 +1,108 @@ +"""Toolkit for interacting with a Power BI dataset.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Union + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.powerbi.prompt import ( + QUESTION_TO_QUERY_BASE, + SINGLE_QUESTION_TO_QUERY, + USER_INPUT, +) +from langchain_community.tools.powerbi.tool import ( + InfoPowerBITool, + ListPowerBITool, + QueryPowerBITool, +) +from langchain_community.utilities.powerbi import PowerBIDataset + +if TYPE_CHECKING: + from langchain.chains.llm import LLMChain + + +class PowerBIToolkit(BaseToolkit): + """Toolkit for interacting with Power BI dataset. + + *Security Note*: This toolkit interacts with an external service. + + Control access to who can use this toolkit. + + Make sure that the capabilities given by this toolkit to the calling + code are appropriately scoped to the application. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + powerbi: PowerBIDataset = Field(exclude=True) + llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True) + examples: Optional[str] = None + max_iterations: int = 5 + callback_manager: Optional[BaseCallbackManager] = None + output_token_limit: Optional[int] = None + tiktoken_model_name: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + QueryPowerBITool( + llm_chain=self._get_chain(), + powerbi=self.powerbi, + examples=self.examples, + max_iterations=self.max_iterations, + output_token_limit=self.output_token_limit, + tiktoken_model_name=self.tiktoken_model_name, + ), + InfoPowerBITool(powerbi=self.powerbi), + ListPowerBITool(powerbi=self.powerbi), + ] + + def _get_chain(self) -> LLMChain: + """Construct the chain based on the callback manager and model type.""" + from langchain.chains.llm import LLMChain + + if isinstance(self.llm, BaseLanguageModel): + return LLMChain( + llm=self.llm, + callback_manager=self.callback_manager + if self.callback_manager + else None, + prompt=PromptTemplate( + template=SINGLE_QUESTION_TO_QUERY, + input_variables=["tool_input", "tables", "schemas", "examples"], + ), + ) + + system_prompt = SystemMessagePromptTemplate( + prompt=PromptTemplate( + template=QUESTION_TO_QUERY_BASE, + input_variables=["tables", "schemas", "examples"], + ) + ) + human_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template=USER_INPUT, + input_variables=["tool_input"], + ) + ) + return LLMChain( + llm=self.llm, + callback_manager=self.callback_manager if self.callback_manager else None, + prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]), + ) diff --git a/libs/community/langchain_community/agent_toolkits/slack/__init__.py b/libs/community/langchain_community/agent_toolkits/slack/__init__.py new file mode 100644 index 00000000000..1ec5ae704ce --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/slack/__init__.py @@ -0,0 +1 @@ +"""Slack toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/slack/toolkit.py b/libs/community/langchain_community/agent_toolkits/slack/toolkit.py new file mode 100644 index 00000000000..bc0f09ff1db --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/slack/toolkit.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.slack.get_channel import SlackGetChannel +from langchain_community.tools.slack.get_message import SlackGetMessage +from langchain_community.tools.slack.schedule_message import SlackScheduleMessage +from langchain_community.tools.slack.send_message import SlackSendMessage +from langchain_community.tools.slack.utils import login + +if TYPE_CHECKING: + from slack_sdk import WebClient + + +class SlackToolkit(BaseToolkit): + """Toolkit for interacting with Slack.""" + + client: WebClient = Field(default_factory=login) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + SlackGetChannel(), + SlackGetMessage(), + SlackScheduleMessage(), + SlackSendMessage(), + ] diff --git 
a/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py b/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py new file mode 100644 index 00000000000..4308c079443 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/spark_sql/__init__.py @@ -0,0 +1 @@ +"""Spark SQL agent.""" diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/base.py b/libs/community/langchain_community/agent_toolkits/spark_sql/base.py new file mode 100644 index 00000000000..dca019b0425 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/spark_sql/base.py @@ -0,0 +1,70 @@ +"""Spark SQL agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from langchain_core.callbacks import BaseCallbackManager, Callbacks +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX +from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_spark_sql_agent( + llm: BaseLanguageModel, + toolkit: SparkSQLToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + prefix: str = SQL_PREFIX, + suffix: str = SQL_SUFFIX, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a Spark SQL agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = toolkit.get_tools() + prefix = prefix.format(top_k=top_k) + prompt_params = ( + {"format_instructions": format_instructions} + if format_instructions is not None + else {} + ) + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + callbacks=callbacks, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + callbacks=callbacks, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py b/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py new file mode 100644 index 00000000000..b499085d3fe --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/spark_sql/prompt.py @@ -0,0 +1,21 @@ +# flake8: noqa + +SQL_PREFIX = """You are an agent designed to interact with Spark SQL. +Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. 
+You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You have access to tools for interacting with the database. +Only use the below tools. Only use the information returned by the below tools to construct your final answer. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +If the question does not seem related to the database, just return "I don't know" as the answer. +""" + +SQL_SUFFIX = """Begin! + +Question: {input} +Thought: I should look at the tables in the database to see what I can query. +{agent_scratchpad}""" diff --git a/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py b/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py new file mode 100644 index 00000000000..fd0363132d0 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/spark_sql/toolkit.py @@ -0,0 +1,36 @@ +"""Toolkit for interacting with Spark SQL.""" +from typing import List + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.spark_sql.tool import ( + InfoSparkSQLTool, + ListSparkSQLTool, + QueryCheckerTool, + QuerySparkSQLTool, +) +from langchain_community.utilities.spark_sql import SparkSQL + + +class SparkSQLToolkit(BaseToolkit): + """Toolkit for interacting with Spark SQL.""" + + db: SparkSQL = Field(exclude=True) + llm: BaseLanguageModel = Field(exclude=True) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + QuerySparkSQLTool(db=self.db), + InfoSparkSQLTool(db=self.db), + ListSparkSQLTool(db=self.db), + QueryCheckerTool(db=self.db, llm=self.llm), + ] diff --git a/libs/community/langchain_community/agent_toolkits/sql/__init__.py b/libs/community/langchain_community/agent_toolkits/sql/__init__.py new file mode 100644 index 00000000000..74293a52391 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/sql/__init__.py @@ -0,0 +1 @@ +"""SQL agent.""" diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py new file mode 100644 index 00000000000..c2451f7c231 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -0,0 +1,108 @@ +"""SQL agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel +from langchain_core.messages import AIMessage, SystemMessage +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) + +from langchain_community.agent_toolkits.sql.prompt import ( + SQL_FUNCTIONS_SUFFIX, + SQL_PREFIX, + SQL_SUFFIX, +) +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.tools import BaseTool + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + 
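A hedged usage sketch for the Spark SQL toolkit and `create_spark_sql_agent` constructor added above. The `SparkSQL(spark_session=..., schema=...)` keyword arguments and the `ChatOpenAI` model choice are assumptions for illustration; a working Spark installation and an `OPENAI_API_KEY` are required:

```python
from pyspark.sql import SparkSession

from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities.spark_sql import SparkSQL

# Point the wrapper at an existing SparkSession and schema (assumed kwargs).
spark = SparkSession.builder.getOrCreate()
db = SparkSQL(spark_session=spark, schema="default")

llm = ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=db, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

agent_executor.run("List the tables available and describe one of them.")
```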
from langchain.agents.agent_types import AgentType + + +def create_sql_agent( + llm: BaseLanguageModel, + toolkit: SQLDatabaseToolkit, + agent_type: Optional[AgentType] = None, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = SQL_PREFIX, + suffix: Optional[str] = None, + format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + extra_tools: Sequence[BaseTool] = (), + **kwargs: Any, +) -> AgentExecutor: + """Construct an SQL agent from an LLM and tools.""" + from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent + from langchain.agents.agent_types import AgentType + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent + from langchain.chains.llm import LLMChain + + agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION + tools = toolkit.get_tools() + list(extra_tools) + prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt_params = ( + {"format_instructions": format_instructions} + if format_instructions is not None + else {} + ) + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix or SQL_SUFFIX, + input_variables=input_variables, + **prompt_params, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + + elif agent_type == AgentType.OPENAI_FUNCTIONS: + messages = [ + SystemMessage(content=prefix), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + input_variables = ["input", "agent_scratchpad"] + _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages) + + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/sql/prompt.py b/libs/community/langchain_community/agent_toolkits/sql/prompt.py new file mode 100644 index 00000000000..92464da4b9b --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/sql/prompt.py @@ -0,0 +1,23 @@ +# flake8: noqa + +SQL_PREFIX = """You are an agent designed to interact with a SQL database. +Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. +You can order the results by a relevant column to return the most interesting examples in the database. 
+Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You have access to tools for interacting with the database. +Only use the below tools. Only use the information returned by the below tools to construct your final answer. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +If the question does not seem related to the database, just return "I don't know" as the answer. +""" + +SQL_SUFFIX = """Begin! + +Question: {input} +Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. +{agent_scratchpad}""" + +SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.""" diff --git a/libs/community/langchain_community/agent_toolkits/sql/toolkit.py b/libs/community/langchain_community/agent_toolkits/sql/toolkit.py new file mode 100644 index 00000000000..382fcbb4856 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/sql/toolkit.py @@ -0,0 +1,71 @@ +"""Toolkit for interacting with an SQL database.""" +from typing import List + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import Field + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.sql_database.tool import ( + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLCheckerTool, + QuerySQLDataBaseTool, +) +from langchain_community.utilities.sql_database import SQLDatabase + + +class SQLDatabaseToolkit(BaseToolkit): + """Toolkit for interacting with SQL databases.""" + + db: SQLDatabase = Field(exclude=True) + llm: BaseLanguageModel = Field(exclude=True) + + @property + def dialect(self) -> str: + """Return string representation of SQL dialect to use.""" + return self.db.dialect + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + list_sql_database_tool = ListSQLDatabaseTool(db=self.db) + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "schema and sample rows for those tables. " + "Be sure that the tables actually exist by calling " + f"{list_sql_database_tool.name} first! " + "Example Input: table1, table2, table3" + ) + info_sql_database_tool = InfoSQLDatabaseTool( + db=self.db, description=info_sql_database_tool_description + ) + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + f"'xxxx' in 'field list', use {info_sql_database_tool.name} " + "to query the correct table fields." + ) + query_sql_database_tool = QuerySQLDataBaseTool( + db=self.db, description=query_sql_database_tool_description + ) + query_sql_checker_tool_description = ( + "Use this tool to double check if your query is correct before executing " + "it. Always use this tool before executing a query with " + f"{query_sql_database_tool.name}!" 
+ ) + query_sql_checker_tool = QuerySQLCheckerTool( + db=self.db, llm=self.llm, description=query_sql_checker_tool_description + ) + return [ + query_sql_database_tool, + info_sql_database_tool, + list_sql_database_tool, + query_sql_checker_tool, + ] diff --git a/libs/community/langchain_community/agent_toolkits/steam/__init__.py b/libs/community/langchain_community/agent_toolkits/steam/__init__.py new file mode 100644 index 00000000000..f9998108242 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/steam/__init__.py @@ -0,0 +1 @@ +"""Steam Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/steam/toolkit.py b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py new file mode 100644 index 00000000000..1fe57b032bf --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py @@ -0,0 +1,48 @@ +"""Steam Toolkit.""" +from typing import List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.steam.prompt import ( + STEAM_GET_GAMES_DETAILS, + STEAM_GET_RECOMMENDED_GAMES, +) +from langchain_community.tools.steam.tool import SteamWebAPIQueryRun +from langchain_community.utilities.steam import SteamWebAPIWrapper + + +class SteamToolkit(BaseToolkit): + """Steam Toolkit.""" + + tools: List[BaseTool] = [] + + @classmethod + def from_steam_api_wrapper( + cls, steam_api_wrapper: SteamWebAPIWrapper + ) -> "SteamToolkit": + operations: List[dict] = [ + { + "mode": "get_games_details", + "name": "Get Games Details", + "description": STEAM_GET_GAMES_DETAILS, + }, + { + "mode": "get_recommended_games", + "name": "Get Recommended Games", + "description": STEAM_GET_RECOMMENDED_GAMES, + }, + ] + tools = [ + SteamWebAPIQueryRun( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=steam_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/vectorstore/base.py b/libs/community/langchain_community/agent_toolkits/vectorstore/base.py new file mode 100644 index 00000000000..7354dfb0df2 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/vectorstore/base.py @@ -0,0 +1,106 @@ +"""VectorStore agent.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.language_models import BaseLanguageModel + +from langchain_community.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX +from langchain_community.agent_toolkits.vectorstore.toolkit import ( + VectorStoreRouterToolkit, + VectorStoreToolkit, +) + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + + +def create_vectorstore_agent( + llm: BaseLanguageModel, + toolkit: VectorStoreToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = PREFIX, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a VectorStore agent from an LLM and tools. + + Args: + llm (BaseLanguageModel): LLM that will be used by the agent + toolkit (VectorStoreToolkit): Set of tools for the agent + callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. 
] + prefix (str, optional): The prefix prompt for the agent. If not provided uses default PREFIX. + verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ] + agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ] + **kwargs: Additional named parameters to pass to the ZeroShotAgent. + + Returns: + AgentExecutor: Returns a callable AgentExecutor object. Either you can call it or use run method with the query to get the response + """ # noqa: E501 + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = toolkit.get_tools() + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) + + +def create_vectorstore_router_agent( + llm: BaseLanguageModel, + toolkit: VectorStoreRouterToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = ROUTER_PREFIX, + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> AgentExecutor: + """Construct a VectorStore router agent from an LLM and tools. + + Args: + llm (BaseLanguageModel): LLM that will be used by the agent + toolkit (VectorStoreRouterToolkit): Set of tools for the agent which have routing capability with multiple vector stores + callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. ] + prefix (str, optional): The prefix prompt for the router agent. If not provided uses default ROUTER_PREFIX. + verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ] + agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ] + **kwargs: Additional named parameters to pass to the ZeroShotAgent. + + Returns: + AgentExecutor: Returns a callable AgentExecutor object. Either you can call it or use run method with the query to get the response. 
+ """ # noqa: E501 + from langchain.agents.agent import AgentExecutor + from langchain.agents.mrkl.base import ZeroShotAgent + from langchain.chains.llm import LLMChain + + tools = toolkit.get_tools() + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py b/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py new file mode 100644 index 00000000000..fd8fc13ba0d --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/xorbits/__init__.py @@ -0,0 +1,26 @@ +from pathlib import Path +from typing import Any + +from langchain_core._api.path import as_import_path + + +def __getattr__(name: str) -> Any: + """Get attr name.""" + + if name == "create_xorbits_agent": + # Get directory of langchain package + HERE = Path(__file__).parents[3] + here = as_import_path(Path(__file__).parent, relative_to=HERE) + + old_path = "langchain." + here + "." + name + new_path = "langchain_experimental." + here + "." + name + raise ImportError( + "This agent has been moved to langchain experiment. " + "This agent relies on python REPL tool under the hood, so to use it " + "safely please sandbox the python REPL. " + "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md " + "and https://github.com/langchain-ai/langchain/discussions/11680" + "To keep using this code as is, install langchain experimental and " + f"update your import statement from:\n `{old_path}` to `{new_path}`." 
+ ) + raise AttributeError(f"{name} does not exist") diff --git a/libs/community/langchain_community/agent_toolkits/zapier/__init__.py b/libs/community/langchain_community/agent_toolkits/zapier/__init__.py new file mode 100644 index 00000000000..faef4a32531 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/zapier/__init__.py @@ -0,0 +1 @@ +"""Zapier Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py b/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py new file mode 100644 index 00000000000..f5f335efdc8 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/zapier/toolkit.py @@ -0,0 +1,60 @@ +"""[DEPRECATED] Zapier Toolkit.""" +from typing import List + +from langchain_core._api import warn_deprecated + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.zapier.tool import ZapierNLARunAction +from langchain_community.utilities.zapier import ZapierNLAWrapper + + +class ZapierToolkit(BaseToolkit): + """Zapier Toolkit.""" + + tools: List[BaseTool] = [] + + @classmethod + def from_zapier_nla_wrapper( + cls, zapier_nla_wrapper: ZapierNLAWrapper + ) -> "ZapierToolkit": + """Create a toolkit from a ZapierNLAWrapper.""" + actions = zapier_nla_wrapper.list() + tools = [ + ZapierNLARunAction( + action_id=action["id"], + zapier_description=action["description"], + params_schema=action["params"], + api_wrapper=zapier_nla_wrapper, + ) + for action in actions + ] + return cls(tools=tools) + + @classmethod + async def async_from_zapier_nla_wrapper( + cls, zapier_nla_wrapper: ZapierNLAWrapper + ) -> "ZapierToolkit": + """Create a toolkit from a ZapierNLAWrapper.""" + actions = await zapier_nla_wrapper.alist() + tools = [ + ZapierNLARunAction( + action_id=action["id"], + zapier_description=action["description"], + params_schema=action["params"], + api_wrapper=zapier_nla_wrapper, + ) + for action in actions + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + warn_deprecated( + since="0.0.319", + message=( + "This tool will be deprecated on 2023-11-17. See " + "https://nla.zapier.com/sunset/ for details" + ), + ) + return self.tools diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py new file mode 100644 index 00000000000..a0070c2e59a --- /dev/null +++ b/libs/community/langchain_community/cache.py @@ -0,0 +1,1556 @@ +""" +.. warning:: + Beta Feature! + +**Cache** provides an optional caching layer for LLMs. + +Cache is useful for two reasons: + +- It can save you money by reducing the number of API calls you make to the LLM + provider if you're often requesting the same completion multiple times. +- It can speed up your application by reducing the number of API calls you make + to the LLM provider. + +Cache directly competes with Memory. See documentation for Pros and Cons. + +**Class hierarchy:** + +.. 
code-block:: + + BaseCache --> Cache # Examples: InMemoryCache, RedisCache, GPTCache +""" +from __future__ import annotations + +import hashlib +import inspect +import json +import logging +import uuid +import warnings +from datetime import timedelta +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +from sqlalchemy import Column, Integer, Row, String, create_engine, select +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import Session + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + +from langchain_core.caches import RETURN_VAL_TYPE, BaseCache +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.llms import LLM, get_prompts +from langchain_core.load.dump import dumps +from langchain_core.load.load import loads +from langchain_core.outputs import ChatGeneration, Generation +from langchain_core.utils import get_from_env + +from langchain_community.vectorstores.redis import Redis as RedisVectorstore + +logger = logging.getLogger(__file__) + +if TYPE_CHECKING: + import momento + from cassandra.cluster import Session as CassandraSession + + +def _hash(_input: str) -> str: + """Use a deterministic hashing approach.""" + return hashlib.md5(_input.encode()).hexdigest() + + +def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str: + """Dump generations to json. + + Args: + generations (RETURN_VAL_TYPE): A list of language model generations. + + Returns: + str: Json representing a list of generations. + + Warning: would not work well with arbitrary subclasses of `Generation` + """ + return json.dumps([generation.dict() for generation in generations]) + + +def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE: + """Load generations from json. + + Args: + generations_json (str): A string of json representing a list of generations. + + Raises: + ValueError: Could not decode json string to list of generations. + + Returns: + RETURN_VAL_TYPE: A list of generations. + + Warning: would not work well with arbitrary subclasses of `Generation` + """ + try: + results = json.loads(generations_json) + return [Generation(**generation_dict) for generation_dict in results] + except json.JSONDecodeError: + raise ValueError( + f"Could not decode json to list of generations: {generations_json}" + ) + + +def _dumps_generations(generations: RETURN_VAL_TYPE) -> str: + """ + Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation` + + Args: + generations (RETURN_VAL_TYPE): A list of language model generations. + + Returns: + str: a single string representing a list of generations. + + This function (+ its counterpart `_loads_generations`) rely on + the dumps/loads pair with Reviver, so are able to deal + with all subclasses of Generation. + + Each item in the list can be `dumps`ed to a string, + then we make the whole list of strings into a json-dumped. + """ + return json.dumps([dumps(_item) for _item in generations]) + + +def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]: + """ + Deserialization of a string into a generic RETURN_VAL_TYPE + (i.e. a sequence of `Generation`). + + See `_dumps_generations`, the inverse of this function. + + Args: + generations_str (str): A string representing a list of generations. 
+ + Compatible with the legacy cache-blob format + Does not raise exceptions for malformed entries, just logs a warning + and returns none: the caller should be prepared for such a cache miss. + + Returns: + RETURN_VAL_TYPE: A list of generations. + """ + try: + generations = [loads(_item_str) for _item_str in json.loads(generations_str)] + return generations + except (json.JSONDecodeError, TypeError): + # deferring the (soft) handling to after the legacy-format attempt + pass + + try: + gen_dicts = json.loads(generations_str) + # not relying on `_load_generations_from_json` (which could disappear): + generations = [Generation(**generation_dict) for generation_dict in gen_dicts] + logger.warning( + f"Legacy 'Generation' cached blob encountered: '{generations_str}'" + ) + return generations + except (json.JSONDecodeError, TypeError): + logger.warning( + f"Malformed/unparsable cached blob encountered: '{generations_str}'" + ) + return None + + +class InMemoryCache(BaseCache): + """Cache that stores things in memory.""" + + def __init__(self) -> None: + """Initialize with empty cache.""" + self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + return self._cache.get((prompt, llm_string), None) + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + self._cache[(prompt, llm_string)] = return_val + + def clear(self, **kwargs: Any) -> None: + """Clear cache.""" + self._cache = {} + + +Base = declarative_base() + + +class FullLLMCache(Base): # type: ignore + """SQLite table for full LLM Cache (all generations).""" + + __tablename__ = "full_llm_cache" + prompt = Column(String, primary_key=True) + llm = Column(String, primary_key=True) + idx = Column(Integer, primary_key=True) + response = Column(String) + + +class SQLAlchemyCache(BaseCache): + """Cache that uses SQAlchemy as a backend.""" + + def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache): + """Initialize by creating all tables.""" + self.engine = engine + self.cache_schema = cache_schema + self.cache_schema.metadata.create_all(self.engine) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + stmt = ( + select(self.cache_schema.response) + .where(self.cache_schema.prompt == prompt) # type: ignore + .where(self.cache_schema.llm == llm_string) + .order_by(self.cache_schema.idx) + ) + with Session(self.engine) as session: + rows = session.execute(stmt).fetchall() + if rows: + try: + return [loads(row[0]) for row in rows] + except Exception: + logger.warning( + "Retrieving a cache value that could not be deserialized " + "properly. This is likely due to the cache being in an " + "older format. Please recreate your cache to avoid this " + "error." + ) + # In a previous life we stored the raw text directly + # in the table, so assume it's in that format. 
+ return [Generation(text=row[0]) for row in rows] + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update based on prompt and llm_string.""" + items = [ + self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i) + for i, gen in enumerate(return_val) + ] + with Session(self.engine) as session, session.begin(): + for item in items: + session.merge(item) + + def clear(self, **kwargs: Any) -> None: + """Clear cache.""" + with Session(self.engine) as session: + session.query(self.cache_schema).delete() + session.commit() + + +class SQLiteCache(SQLAlchemyCache): + """Cache that uses SQLite as a backend.""" + + def __init__(self, database_path: str = ".langchain.db"): + """Initialize by creating the engine and all tables.""" + engine = create_engine(f"sqlite:///{database_path}") + super().__init__(engine) + + +class UpstashRedisCache(BaseCache): + """Cache that uses Upstash Redis as a backend.""" + + def __init__(self, redis_: Any, *, ttl: Optional[int] = None): + """ + Initialize an instance of UpstashRedisCache. + + This method initializes an object with Upstash Redis caching capabilities. + It takes a `redis_` parameter, which should be an instance of an Upstash Redis + client class, allowing the object to interact with Upstash Redis + server for caching purposes. + + Parameters: + redis_: An instance of Upstash Redis client class + (e.g., Redis) used for caching. + This allows the object to communicate with + Redis server for caching operations on. + ttl (int, optional): Time-to-live (TTL) for cached items in seconds. + If provided, it sets the time duration for how long cached + items will remain valid. If not provided, cached items will not + have an automatic expiration. + """ + try: + from upstash_redis import Redis + except ImportError: + raise ValueError( + "Could not import upstash_redis python package. " + "Please install it with `pip install upstash_redis`." + ) + if not isinstance(redis_, Redis): + raise ValueError("Please pass in Upstash Redis object.") + self.redis = redis_ + self.ttl = ttl + + def _key(self, prompt: str, llm_string: str) -> str: + """Compute key from prompt and llm_string""" + return _hash(prompt + llm_string) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + generations = [] + # Read from a HASH + results = self.redis.hgetall(self._key(prompt, llm_string)) + if results: + for _, text in results.items(): + generations.append(Generation(text=text)) + return generations if generations else None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + for gen in return_val: + if not isinstance(gen, Generation): + raise ValueError( + "UpstashRedisCache supports caching of normal LLM generations, " + f"got {type(gen)}" + ) + if isinstance(gen, ChatGeneration): + warnings.warn( + "NOTE: Generation has not been cached. UpstashRedisCache does not" + " support caching ChatModel outputs." + ) + return + # Write to a HASH + key = self._key(prompt, llm_string) + + mapping = { + str(idx): generation.text for idx, generation in enumerate(return_val) + } + self.redis.hset(key=key, values=mapping) + + if self.ttl is not None: + self.redis.expire(key, self.ttl) + + def clear(self, **kwargs: Any) -> None: + """ + Clear cache. If `asynchronous` is True, flush asynchronously. + This flushes the *whole* db. 
+ """ + asynchronous = kwargs.get("asynchronous", False) + if asynchronous: + asynchronous = "ASYNC" + else: + asynchronous = "SYNC" + self.redis.flushdb(flush_type=asynchronous) + + +class RedisCache(BaseCache): + """Cache that uses Redis as a backend.""" + + def __init__(self, redis_: Any, *, ttl: Optional[int] = None): + """ + Initialize an instance of RedisCache. + + This method initializes an object with Redis caching capabilities. + It takes a `redis_` parameter, which should be an instance of a Redis + client class, allowing the object to interact with a Redis + server for caching purposes. + + Parameters: + redis_ (Any): An instance of a Redis client class + (e.g., redis.Redis) used for caching. + This allows the object to communicate with a + Redis server for caching operations. + ttl (int, optional): Time-to-live (TTL) for cached items in seconds. + If provided, it sets the time duration for how long cached + items will remain valid. If not provided, cached items will not + have an automatic expiration. + """ + try: + from redis import Redis + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + if not isinstance(redis_, Redis): + raise ValueError("Please pass in Redis object.") + self.redis = redis_ + self.ttl = ttl + + def _key(self, prompt: str, llm_string: str) -> str: + """Compute key from prompt and llm_string""" + return _hash(prompt + llm_string) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + generations = [] + # Read from a Redis HASH + results = self.redis.hgetall(self._key(prompt, llm_string)) + if results: + for _, text in results.items(): + try: + generations.append(loads(text)) + except Exception: + logger.warning( + "Retrieving a cache value that could not be deserialized " + "properly. This is likely due to the cache being in an " + "older format. Please recreate your cache to avoid this " + "error." + ) + # In a previous life we stored the raw text directly + # in the table, so assume it's in that format. + generations.append(Generation(text=text)) + return generations if generations else None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + for gen in return_val: + if not isinstance(gen, Generation): + raise ValueError( + "RedisCache only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) + # Write to a Redis HASH + key = self._key(prompt, llm_string) + + with self.redis.pipeline() as pipe: + pipe.hset( + key, + mapping={ + str(idx): dumps(generation) + for idx, generation in enumerate(return_val) + }, + ) + if self.ttl is not None: + pipe.expire(key, self.ttl) + + pipe.execute() + + def clear(self, **kwargs: Any) -> None: + """Clear cache. 
If `asynchronous` is True, flush asynchronously."""
+        asynchronous = kwargs.get("asynchronous", False)
+        self.redis.flushdb(asynchronous=asynchronous, **kwargs)
+
+
+class RedisSemanticCache(BaseCache):
+    """Cache that uses Redis as a vector-store backend."""
+
+    # TODO - implement a TTL policy in Redis
+
+    DEFAULT_SCHEMA = {
+        "content_key": "prompt",
+        "text": [
+            {"name": "prompt"},
+        ],
+        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
+    }
+
+    def __init__(
+        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
+    ):
+        """Initialize the semantic cache with a Redis URL, an embedding provider
+        and a score threshold.
+
+        Args:
+            redis_url (str): URL to connect to Redis.
+            embedding (Embedding): Embedding provider for semantic encoding and search.
+            score_threshold (float, 0.2): distance threshold used when searching
+                for semantically similar prompts; only sufficiently close matches
+                are treated as cache hits.
+
+        Example:
+
+        .. code-block:: python
+
+            from langchain_community.globals import set_llm_cache
+
+            from langchain_community.cache import RedisSemanticCache
+            from langchain_community.embeddings import OpenAIEmbeddings
+
+            set_llm_cache(RedisSemanticCache(
+                redis_url="redis://localhost:6379",
+                embedding=OpenAIEmbeddings()
+            ))
+
+        """
+        self._cache_dict: Dict[str, RedisVectorstore] = {}
+        self.redis_url = redis_url
+        self.embedding = embedding
+        self.score_threshold = score_threshold
+
+    def _index_name(self, llm_string: str) -> str:
+        hashed_index = _hash(llm_string)
+        return f"cache:{hashed_index}"
+
+    def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
+        index_name = self._index_name(llm_string)
+
+        # return vectorstore client for the specific llm string
+        if index_name in self._cache_dict:
+            return self._cache_dict[index_name]
+
+        # create new vectorstore client for the specific llm string
+        try:
+            self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
+                embedding=self.embedding,
+                index_name=index_name,
+                redis_url=self.redis_url,
+                schema=cast(Dict, self.DEFAULT_SCHEMA),
+            )
+        except ValueError:
+            redis = RedisVectorstore(
+                embedding=self.embedding,
+                index_name=index_name,
+                redis_url=self.redis_url,
+                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
+            )
+            _embedding = self.embedding.embed_query(text="test")
+            redis._create_index_if_not_exist(dim=len(_embedding))
+            self._cache_dict[index_name] = redis
+
+        return self._cache_dict[index_name]
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear semantic cache for a given llm_string."""
+        index_name = self._index_name(kwargs["llm_string"])
+        if index_name in self._cache_dict:
+            self._cache_dict[index_name].drop_index(
+                index_name=index_name, delete_documents=True, redis_url=self.redis_url
+            )
+            del self._cache_dict[index_name]
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        llm_cache = self._get_llm_cache(llm_string)
+        generations: List = []
+        # Perform a semantic similarity search against the stored prompts
+        results = llm_cache.similarity_search(
+            query=prompt,
+            k=1,
+            distance_threshold=self.score_threshold,
+        )
+        if results:
+            for document in results:
+                try:
+                    generations.extend(loads(document.metadata["return_val"]))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.extend(
+                        _load_generations_from_json(document.metadata["return_val"])
+                    )
+        return generations if generations else None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update cache based on prompt and llm_string."""
+        for gen in return_val:
+            if not isinstance(gen, Generation):
+                raise ValueError(
+                    "RedisSemanticCache only supports caching of "
+                    f"normal LLM generations, got {type(gen)}"
+                )
+        llm_cache = self._get_llm_cache(llm_string)
+
+        metadata = {
+            "llm_string": llm_string,
+            "prompt": prompt,
+            "return_val": dumps([g for g in return_val]),
+        }
+        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
+
+
+class GPTCache(BaseCache):
+    """Cache that uses GPTCache as a backend."""
+
+    def __init__(
+        self,
+        init_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = None,
+    ):
+        """Initialize by passing in init function (default: `None`).
+
+        Args:
+            init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
+            (default: `None`)
+
+        Example:
+        .. code-block:: python
+
+            # Initialize GPTCache with a custom init function
+            import gptcache
+            from gptcache.processor.pre import get_prompt
+            from gptcache.manager.factory import manager_factory
+            from langchain_community.globals import set_llm_cache
+
+            # Avoid multiple caches using the same file,
+            # causing different llm model caches to affect each other
+
+            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
+                cache_obj.init(
+                    pre_embedding_func=get_prompt,
+                    data_manager=manager_factory(
+                        manager="map",
+                        data_dir=f"map_cache_{llm}"
+                    ),
+                )
+
+            set_llm_cache(GPTCache(init_gptcache))
+
+        """
+        try:
+            import gptcache  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "Could not import gptcache python package. "
+                "Please install it with `pip install gptcache`."
+            )
+
+        self.init_gptcache_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = init_func
+        self.gptcache_dict: Dict[str, Any] = {}
+
+    def _new_gptcache(self, llm_string: str) -> Any:
+        """New gptcache object"""
+        from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt
+
+        _gptcache = Cache()
+        if self.init_gptcache_func is not None:
+            sig = inspect.signature(self.init_gptcache_func)
+            if len(sig.parameters) == 2:
+                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
+            else:
+                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
+        else:
+            _gptcache.init(
+                pre_embedding_func=get_prompt,
+                data_manager=get_data_manager(data_path=llm_string),
+            )
+
+        self.gptcache_dict[llm_string] = _gptcache
+        return _gptcache
+
+    def _get_gptcache(self, llm_string: str) -> Any:
+        """Get a cache object.
+
+        When the corresponding llm model cache does not exist, it will be created."""
+        _gptcache = self.gptcache_dict.get(llm_string, None)
+        if not _gptcache:
+            _gptcache = self._new_gptcache(llm_string)
+        return _gptcache
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up the cache data.
+        First, retrieve the corresponding cache object using the `llm_string` parameter,
+        and then retrieve the data from the cache based on the `prompt`.
+ """ + from gptcache.adapter.api import get + + _gptcache = self._get_gptcache(llm_string) + + res = get(prompt, cache_obj=_gptcache) + if res: + return [ + Generation(**generation_dict) for generation_dict in json.loads(res) + ] + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache. + First, retrieve the corresponding cache object using the `llm_string` parameter, + and then store the `prompt` and `return_val` in the cache object. + """ + for gen in return_val: + if not isinstance(gen, Generation): + raise ValueError( + "GPTCache only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) + from gptcache.adapter.api import put + + _gptcache = self._get_gptcache(llm_string) + handled_data = json.dumps([generation.dict() for generation in return_val]) + put(prompt, handled_data, cache_obj=_gptcache) + return None + + def clear(self, **kwargs: Any) -> None: + """Clear cache.""" + from gptcache import Cache + + for gptcache_instance in self.gptcache_dict.values(): + gptcache_instance = cast(Cache, gptcache_instance) + gptcache_instance.flush() + + self.gptcache_dict.clear() + + +def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None: + """Create cache if it doesn't exist. + + Raises: + SdkException: Momento service or network error + Exception: Unexpected response + """ + from momento.responses import CreateCache + + create_cache_response = cache_client.create_cache(cache_name) + if isinstance(create_cache_response, CreateCache.Success) or isinstance( + create_cache_response, CreateCache.CacheAlreadyExists + ): + return None + elif isinstance(create_cache_response, CreateCache.Error): + raise create_cache_response.inner_exception + else: + raise Exception(f"Unexpected response cache creation: {create_cache_response}") + + +def _validate_ttl(ttl: Optional[timedelta]) -> None: + if ttl is not None and ttl <= timedelta(seconds=0): + raise ValueError(f"ttl must be positive but was {ttl}.") + + +class MomentoCache(BaseCache): + """Cache that uses Momento as a backend. See https://gomomento.com/""" + + def __init__( + self, + cache_client: momento.CacheClient, + cache_name: str, + *, + ttl: Optional[timedelta] = None, + ensure_cache_exists: bool = True, + ): + """Instantiate a prompt cache using Momento as a backend. + + Note: to instantiate the cache client passed to MomentoCache, + you must have a Momento account. See https://gomomento.com/. + + Args: + cache_client (CacheClient): The Momento cache client. + cache_name (str): The name of the cache to use to store the data. + ttl (Optional[timedelta], optional): The time to live for the cache items. + Defaults to None, ie use the client default TTL. + ensure_cache_exists (bool, optional): Create the cache if it doesn't + exist. Defaults to True. + + Raises: + ImportError: Momento python package is not installed. + TypeError: cache_client is not of type momento.CacheClientObject + ValueError: ttl is non-null and non-negative + """ + try: + from momento import CacheClient + except ImportError: + raise ImportError( + "Could not import momento python package. " + "Please install it with `pip install momento`." 
+ ) + if not isinstance(cache_client, CacheClient): + raise TypeError("cache_client must be a momento.CacheClient object.") + _validate_ttl(ttl) + if ensure_cache_exists: + _ensure_cache_exists(cache_client, cache_name) + + self.cache_client = cache_client + self.cache_name = cache_name + self.ttl = ttl + + @classmethod + def from_client_params( + cls, + cache_name: str, + ttl: timedelta, + *, + configuration: Optional[momento.config.Configuration] = None, + api_key: Optional[str] = None, + auth_token: Optional[str] = None, # for backwards compatibility + **kwargs: Any, + ) -> MomentoCache: + """Construct cache from CacheClient parameters.""" + try: + from momento import CacheClient, Configurations, CredentialProvider + except ImportError: + raise ImportError( + "Could not import momento python package. " + "Please install it with `pip install momento`." + ) + if configuration is None: + configuration = Configurations.Laptop.v1() + + # Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility + try: + api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN") + except ValueError: + api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY") + credentials = CredentialProvider.from_string(api_key) + cache_client = CacheClient(configuration, credentials, default_ttl=ttl) + return cls(cache_client, cache_name, ttl=ttl, **kwargs) + + def __key(self, prompt: str, llm_string: str) -> str: + """Compute cache key from prompt and associated model and settings. + + Args: + prompt (str): The prompt run through the language model. + llm_string (str): The language model version and settings. + + Returns: + str: The cache key. + """ + return _hash(prompt + llm_string) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Lookup llm generations in cache by prompt and associated model and settings. + + Args: + prompt (str): The prompt run through the language model. + llm_string (str): The language model version and settings. + + Raises: + SdkException: Momento service or network error + + Returns: + Optional[RETURN_VAL_TYPE]: A list of language model generations. + """ + from momento.responses import CacheGet + + generations: RETURN_VAL_TYPE = [] + + get_response = self.cache_client.get( + self.cache_name, self.__key(prompt, llm_string) + ) + if isinstance(get_response, CacheGet.Hit): + value = get_response.value_string + generations = _load_generations_from_json(value) + elif isinstance(get_response, CacheGet.Miss): + pass + elif isinstance(get_response, CacheGet.Error): + raise get_response.inner_exception + return generations if generations else None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Store llm generations in cache. + + Args: + prompt (str): The prompt run through the language model. + llm_string (str): The language model string. + return_val (RETURN_VAL_TYPE): A list of language model generations. 
+ + Raises: + SdkException: Momento service or network error + Exception: Unexpected response + """ + for gen in return_val: + if not isinstance(gen, Generation): + raise ValueError( + "Momento only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) + key = self.__key(prompt, llm_string) + value = _dump_generations_to_json(return_val) + set_response = self.cache_client.set(self.cache_name, key, value, self.ttl) + from momento.responses import CacheSet + + if isinstance(set_response, CacheSet.Success): + pass + elif isinstance(set_response, CacheSet.Error): + raise set_response.inner_exception + else: + raise Exception(f"Unexpected response: {set_response}") + + def clear(self, **kwargs: Any) -> None: + """Clear the cache. + + Raises: + SdkException: Momento service or network error + """ + from momento.responses import CacheFlush + + flush_response = self.cache_client.flush_cache(self.cache_name) + if isinstance(flush_response, CacheFlush.Success): + pass + elif isinstance(flush_response, CacheFlush.Error): + raise flush_response.inner_exception + + +CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache" +CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None + + +class CassandraCache(BaseCache): + """ + Cache that uses Cassandra / Astra DB as a backend. + + It uses a single Cassandra table. + The lookup keys (which get to form the primary key) are: + - prompt, a string + - llm_string, a deterministic str representation of the model parameters. + (needed to prevent collisions same-prompt-different-model collisions) + """ + + def __init__( + self, + session: Optional[CassandraSession] = None, + keyspace: Optional[str] = None, + table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME, + ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS, + skip_provisioning: bool = False, + ): + """ + Initialize with a ready session and a keyspace name. + Args: + session (cassandra.cluster.Session): an open Cassandra session + keyspace (str): the keyspace to use for storing the cache + table_name (str): name of the Cassandra table to use as cache + ttl_seconds (optional int): time-to-live for cache entries + (default: None, i.e. forever) + """ + try: + from cassio.table import ElasticCassandraTable + except (ImportError, ModuleNotFoundError): + raise ValueError( + "Could not import cassio python package. " + "Please install it with `pip install cassio`." 
+ ) + + self.session = session + self.keyspace = keyspace + self.table_name = table_name + self.ttl_seconds = ttl_seconds + + self.kv_cache = ElasticCassandraTable( + session=self.session, + keyspace=self.keyspace, + table=self.table_name, + keys=["llm_string", "prompt"], + primary_key_type=["TEXT", "TEXT"], + ttl_seconds=self.ttl_seconds, + skip_provisioning=skip_provisioning, + ) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + item = self.kv_cache.get( + llm_string=_hash(llm_string), + prompt=_hash(prompt), + ) + if item is not None: + generations = _loads_generations(item["body_blob"]) + # this protects against malformed cached items: + if generations is not None: + return generations + else: + return None + else: + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + blob = _dumps_generations(return_val) + self.kv_cache.put( + llm_string=_hash(llm_string), + prompt=_hash(prompt), + body_blob=blob, + ) + + def delete_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> None: + """ + A wrapper around `delete` with the LLM being passed. + In case the llm(prompt) calls have a `stop` param, you should pass it here + """ + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.delete(prompt, llm_string=llm_string) + + def delete(self, prompt: str, llm_string: str) -> None: + """Evict from cache if there's an entry.""" + return self.kv_cache.delete( + llm_string=_hash(llm_string), + prompt=_hash(prompt), + ) + + def clear(self, **kwargs: Any) -> None: + """Clear cache. This is for all LLMs at once.""" + self.kv_cache.clear() + + +CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot" +CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85 +CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache" +CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None +CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 + + +class CassandraSemanticCache(BaseCache): + """ + Cache that uses Cassandra as a vector-store backend for semantic + (i.e. similarity-based) lookup. + + It uses a single (vector) Cassandra table and stores, in principle, + cached values from several LLMs, so the LLM's llm_string is part + of the rows' primary keys. + + The similarity is based on one of several distance metrics (default: "dot"). + If choosing another metric, the default threshold is to be re-tuned accordingly. + """ + + def __init__( + self, + session: Optional[CassandraSession], + keyspace: Optional[str], + embedding: Embeddings, + table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME, + distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC, + score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD, + ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS, + skip_provisioning: bool = False, + ): + """ + Initialize the cache with all relevant parameters. + Args: + session (cassandra.cluster.Session): an open Cassandra session + keyspace (str): the keyspace to use for storing the cache + embedding (Embedding): Embedding provider for semantic + encoding and search. 
+ table_name (str): name of the Cassandra (vector) table + to use as cache + distance_metric (str, 'dot'): which measure to adopt for + similarity searches + score_threshold (optional float): numeric value to use as + cutoff for the similarity searches + ttl_seconds (optional int): time-to-live for cache entries + (default: None, i.e. forever) + The default score threshold is tuned to the default metric. + Tune it carefully yourself if switching to another distance metric. + """ + try: + from cassio.table import MetadataVectorCassandraTable + except (ImportError, ModuleNotFoundError): + raise ValueError( + "Could not import cassio python package. " + "Please install it with `pip install cassio`." + ) + self.session = session + self.keyspace = keyspace + self.embedding = embedding + self.table_name = table_name + self.distance_metric = distance_metric + self.score_threshold = score_threshold + self.ttl_seconds = ttl_seconds + + # The contract for this class has separate lookup and update: + # in order to spare some embedding calculations we cache them between + # the two calls. + # Note: each instance of this class has its own `_get_embedding` with + # its own lru. + @lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) + def _cache_embedding(text: str) -> List[float]: + return self.embedding.embed_query(text=text) + + self._get_embedding = _cache_embedding + self.embedding_dimension = self._get_embedding_dimension() + + self.table = MetadataVectorCassandraTable( + session=self.session, + keyspace=self.keyspace, + table=self.table_name, + primary_key_type=["TEXT"], + vector_dimension=self.embedding_dimension, + ttl_seconds=self.ttl_seconds, + metadata_indexing=("allow", {"_llm_string_hash"}), + skip_provisioning=skip_provisioning, + ) + + def _get_embedding_dimension(self) -> int: + return len(self._get_embedding(text="This is a sample sentence.")) + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + embedding_vector = self._get_embedding(text=prompt) + llm_string_hash = _hash(llm_string) + body = _dumps_generations(return_val) + metadata = { + "_prompt": prompt, + "_llm_string_hash": llm_string_hash, + } + row_id = f"{_hash(prompt)}-{llm_string_hash}" + # + self.table.put( + body_blob=body, + vector=embedding_vector, + row_id=row_id, + metadata=metadata, + ) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + hit_with_id = self.lookup_with_id(prompt, llm_string) + if hit_with_id is not None: + return hit_with_id[1] + else: + return None + + def lookup_with_id( + self, prompt: str, llm_string: str + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + """ + Look up based on prompt and llm_string. 
+ If there are hits, return (document_id, cached_entry) + """ + prompt_embedding: List[float] = self._get_embedding(text=prompt) + hits = list( + self.table.metric_ann_search( + vector=prompt_embedding, + metadata={"_llm_string_hash": _hash(llm_string)}, + n=1, + metric=self.distance_metric, + metric_threshold=self.score_threshold, + ) + ) + if hits: + hit = hits[0] + generations = _loads_generations(hit["body_blob"]) + if generations is not None: + # this protects against malformed cached items: + return ( + hit["row_id"], + generations, + ) + else: + return None + else: + return None + + def lookup_with_id_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.lookup_with_id(prompt, llm_string=llm_string) + + def delete_by_document_id(self, document_id: str) -> None: + """ + Given this is a "similarity search" cache, an invalidation pattern + that makes sense is first a lookup to get an ID, and then deleting + with that ID. This is for the second step. + """ + self.table.delete(row_id=document_id) + + def clear(self, **kwargs: Any) -> None: + """Clear the *whole* semantic cache.""" + self.table.clear() + + +class FullMd5LLMCache(Base): # type: ignore + """SQLite table for full LLM Cache (all generations).""" + + __tablename__ = "full_md5_llm_cache" + id = Column(String, primary_key=True) + prompt_md5 = Column(String, index=True) + llm = Column(String, index=True) + idx = Column(Integer, index=True) + prompt = Column(String) + response = Column(String) + + +class SQLAlchemyMd5Cache(BaseCache): + """Cache that uses SQAlchemy as a backend.""" + + def __init__( + self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache + ): + """Initialize by creating all tables.""" + self.engine = engine + self.cache_schema = cache_schema + self.cache_schema.metadata.create_all(self.engine) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + rows = self._search_rows(prompt, llm_string) + if rows: + return [loads(row[0]) for row in rows] + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update based on prompt and llm_string.""" + self._delete_previous(prompt, llm_string) + prompt_md5 = self.get_md5(prompt) + items = [ + self.cache_schema( + id=str(uuid.uuid1()), + prompt=prompt, + prompt_md5=prompt_md5, + llm=llm_string, + response=dumps(gen), + idx=i, + ) + for i, gen in enumerate(return_val) + ] + with Session(self.engine) as session, session.begin(): + for item in items: + session.merge(item) + + def _delete_previous(self, prompt: str, llm_string: str) -> None: + stmt = ( + select(self.cache_schema.response) + .where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore + .where(self.cache_schema.llm == llm_string) + .where(self.cache_schema.prompt == prompt) + .order_by(self.cache_schema.idx) + ) + with Session(self.engine) as session, session.begin(): + rows = session.execute(stmt).fetchall() + for item in rows: + session.delete(item) + + def _search_rows(self, prompt: str, llm_string: str) -> List[Row]: + prompt_pd5 = self.get_md5(prompt) + stmt = ( + select(self.cache_schema.response) + .where(self.cache_schema.prompt_md5 == prompt_pd5) # type: ignore + .where(self.cache_schema.llm == llm_string) + .where(self.cache_schema.prompt == prompt) + .order_by(self.cache_schema.idx) + 
) + with Session(self.engine) as session: + return session.execute(stmt).fetchall() + + def clear(self, **kwargs: Any) -> None: + """Clear cache.""" + with Session(self.engine) as session: + session.execute(self.cache_schema.delete()) + + @staticmethod + def get_md5(input_string: str) -> str: + return hashlib.md5(input_string.encode()).hexdigest() + + +ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache" + + +class AstraDBCache(BaseCache): + """ + Cache that uses Astra DB as a backend. + + It uses a single collection as a kv store + The lookup keys, combined in the _id of the documents, are: + - prompt, a string + - llm_string, a deterministic str representation of the model parameters. + (needed to prevent same-prompt-different-model collisions) + """ + + def __init__( + self, + *, + collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + namespace: Optional[str] = None, + ): + """ + Create an AstraDB cache using a collection for storage. + + Args (only keyword-arguments accepted): + collection_name (str): name of the Astra DB collection to create/use. + token (Optional[str]): API token for Astra DB usage. + api_endpoint (Optional[str]): full URL to the API endpoint, + such as "https://-us-east1.apps.astra.datastax.com". + astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + namespace (Optional[str]): namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + """ + try: + from astrapy.db import ( + AstraDB as LibAstraDB, + ) + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." 
+ ) + + self.collection_name = collection_name + self.token = token + self.api_endpoint = api_endpoint + self.namespace = namespace + + if astra_db_client is not None: + self.astra_db = astra_db_client + else: + self.astra_db = LibAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + self.collection = self.astra_db.create_collection( + collection_name=self.collection_name, + ) + + @staticmethod + def _make_id(prompt: str, llm_string: str) -> str: + return f"{_hash(prompt)}#{_hash(llm_string)}" + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + item = self.collection.find_one( + filter={ + "_id": doc_id, + }, + projection={ + "body_blob": 1, + }, + )["data"]["document"] + if item is not None: + generations = _loads_generations(item["body_blob"]) + # this protects against malformed cached items: + if generations is not None: + return generations + else: + return None + else: + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + blob = _dumps_generations(return_val) + self.collection.upsert( + { + "_id": doc_id, + "body_blob": blob, + }, + ) + + def delete_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> None: + """ + A wrapper around `delete` with the LLM being passed. + In case the llm(prompt) calls have a `stop` param, you should pass it here + """ + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.delete(prompt, llm_string=llm_string) + + def delete(self, prompt: str, llm_string: str) -> None: + """Evict from cache if there's an entry.""" + doc_id = self._make_id(prompt, llm_string) + return self.collection.delete_one(doc_id) + + def clear(self, **kwargs: Any) -> None: + """Clear cache. This is for all LLMs at once.""" + self.astra_db.truncate_collection(self.collection_name) + + +ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 +ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache" +ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 + + +class AstraDBSemanticCache(BaseCache): + """ + Cache that uses Astra DB as a vector-store backend for semantic + (i.e. similarity-based) lookup. + + It uses a single (vector) collection and can store + cached values from several LLMs, so the LLM's 'llm_string' is stored + in the document metadata. + + You can choose the preferred similarity (or use the API default) -- + remember the threshold might require metric-dependend tuning. + """ + + def __init__( + self, + *, + collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + namespace: Optional[str] = None, + embedding: Embeddings, + metric: Optional[str] = None, + similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, + ): + """ + Initialize the cache with all relevant parameters. + Args: + + collection_name (str): name of the Astra DB collection to create/use. + token (Optional[str]): API token for Astra DB usage. + api_endpoint (Optional[str]): full URL to the API endpoint, + such as "https://-us-east1.apps.astra.datastax.com". 
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + namespace (Optional[str]): namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + embedding (Embedding): Embedding provider for semantic + encoding and search. + metric: the function to use for evaluating similarity of text embeddings. + Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product') + similarity_threshold (float, optional): the minimum similarity + for accepting a (semantic-search) match. + + The default score threshold is tuned to the default metric. + Tune it carefully yourself if switching to another distance metric. + """ + try: + from astrapy.db import ( + AstraDB as LibAstraDB, + ) + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." + ) + + self.embedding = embedding + self.metric = metric + self.similarity_threshold = similarity_threshold + + # The contract for this class has separate lookup and update: + # in order to spare some embedding calculations we cache them between + # the two calls. + # Note: each instance of this class has its own `_get_embedding` with + # its own lru. + @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) + def _cache_embedding(text: str) -> List[float]: + return self.embedding.embed_query(text=text) + + self._get_embedding = _cache_embedding + self.embedding_dimension = self._get_embedding_dimension() + + self.collection_name = collection_name + self.token = token + self.api_endpoint = api_endpoint + self.namespace = namespace + + if astra_db_client is not None: + self.astra_db = astra_db_client + else: + self.astra_db = LibAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + self.collection = self.astra_db.create_collection( + collection_name=self.collection_name, + dimension=self.embedding_dimension, + metric=self.metric, + ) + + def _get_embedding_dimension(self) -> int: + return len(self._get_embedding(text="This is a sample sentence.")) + + @staticmethod + def _make_id(prompt: str, llm_string: str) -> str: + return f"{_hash(prompt)}#{_hash(llm_string)}" + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + llm_string_hash = _hash(llm_string) + embedding_vector = self._get_embedding(text=prompt) + body = _dumps_generations(return_val) + # + self.collection.upsert( + { + "_id": doc_id, + "body_blob": body, + "llm_string_hash": llm_string_hash, + "$vector": embedding_vector, + } + ) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + hit_with_id = self.lookup_with_id(prompt, llm_string) + if hit_with_id is not None: + return hit_with_id[1] + else: + return None + + def lookup_with_id( + self, prompt: str, llm_string: str + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + """ + Look up based on prompt and llm_string. 
+ If there are hits, return (document_id, cached_entry) for the top hit + """ + prompt_embedding: List[float] = self._get_embedding(text=prompt) + llm_string_hash = _hash(llm_string) + + hit = self.collection.vector_find_one( + vector=prompt_embedding, + filter={ + "llm_string_hash": llm_string_hash, + }, + fields=["body_blob", "_id"], + include_similarity=True, + ) + + if hit is None or hit["$similarity"] < self.similarity_threshold: + return None + else: + generations = _loads_generations(hit["body_blob"]) + if generations is not None: + # this protects against malformed cached items: + return (hit["_id"], generations) + else: + return None + + def lookup_with_id_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.lookup_with_id(prompt, llm_string=llm_string) + + def delete_by_document_id(self, document_id: str) -> None: + """ + Given this is a "similarity search" cache, an invalidation pattern + that makes sense is first a lookup to get an ID, and then deleting + with that ID. This is for the second step. + """ + self.collection.delete_one(document_id) + + def clear(self, **kwargs: Any) -> None: + """Clear the *whole* semantic cache.""" + self.astra_db.truncate_collection(self.collection_name) diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py new file mode 100644 index 00000000000..6016e8304d7 --- /dev/null +++ b/libs/community/langchain_community/callbacks/__init__.py @@ -0,0 +1,66 @@ +"""**Callback handlers** allow listening to events in LangChain. + +**Class hierarchy:** + +.. code-block:: + + BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler +""" + +from langchain_community.callbacks.aim_callback import AimCallbackHandler +from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler +from langchain_community.callbacks.arize_callback import ArizeCallbackHandler +from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler +from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler +from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler +from langchain_community.callbacks.context_callback import ContextCallbackHandler +from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler +from langchain_community.callbacks.human import HumanApprovalCallbackHandler +from langchain_community.callbacks.infino_callback import InfinoCallbackHandler +from langchain_community.callbacks.labelstudio_callback import ( + LabelStudioCallbackHandler, +) +from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler +from langchain_community.callbacks.manager import ( + get_openai_callback, + wandb_tracing_enabled, +) +from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler +from langchain_community.callbacks.openai_info import OpenAICallbackHandler +from langchain_community.callbacks.promptlayer_callback import ( + PromptLayerCallbackHandler, +) +from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler +from langchain_community.callbacks.streamlit import ( + LLMThoughtLabeler, + StreamlitCallbackHandler, +) +from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler +from langchain_community.callbacks.wandb_callback import WandbCallbackHandler 
+from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler + +__all__ = [ + "AimCallbackHandler", + "ArgillaCallbackHandler", + "ArizeCallbackHandler", + "PromptLayerCallbackHandler", + "ArthurCallbackHandler", + "ClearMLCallbackHandler", + "CometCallbackHandler", + "ContextCallbackHandler", + "HumanApprovalCallbackHandler", + "InfinoCallbackHandler", + "MlflowCallbackHandler", + "LLMonitorCallbackHandler", + "OpenAICallbackHandler", + "LLMThoughtLabeler", + "StreamlitCallbackHandler", + "WandbCallbackHandler", + "WhyLabsCallbackHandler", + "get_openai_callback", + "wandb_tracing_enabled", + "FlyteCallbackHandler", + "SageMakerCallbackHandler", + "LabelStudioCallbackHandler", + "TrubricsCallbackHandler", +] diff --git a/libs/community/langchain_community/callbacks/aim_callback.py b/libs/community/langchain_community/callbacks/aim_callback.py new file mode 100644 index 00000000000..da194b34b3d --- /dev/null +++ b/libs/community/langchain_community/callbacks/aim_callback.py @@ -0,0 +1,430 @@ +from copy import deepcopy +from typing import Any, Dict, List, Optional + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + + +def import_aim() -> Any: + """Import the aim python package and raise an error if it is not installed.""" + try: + import aim + except ImportError: + raise ImportError( + "To use the Aim callback manager you need to have the" + " `aim` python package installed." + "Please install it with `pip install aim`" + ) + return aim + + +class BaseMetadataCallbackHandler: + """This class handles the metadata and associated function states for callbacks. + + Attributes: + step (int): The current step. + starts (int): The number of times the start method has been called. + ends (int): The number of times the end method has been called. + errors (int): The number of times the error method has been called. + text_ctr (int): The number of times the text method has been called. + ignore_llm_ (bool): Whether to ignore llm callbacks. + ignore_chain_ (bool): Whether to ignore chain callbacks. + ignore_agent_ (bool): Whether to ignore agent callbacks. + ignore_retriever_ (bool): Whether to ignore retriever callbacks. + always_verbose_ (bool): Whether to always be verbose. + chain_starts (int): The number of times the chain start method has been called. + chain_ends (int): The number of times the chain end method has been called. + llm_starts (int): The number of times the llm start method has been called. + llm_ends (int): The number of times the llm end method has been called. + llm_streams (int): The number of times the text method has been called. + tool_starts (int): The number of times the tool start method has been called. + tool_ends (int): The number of times the tool end method has been called. + agent_ends (int): The number of times the agent end method has been called. 
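+
+    All of these counters are merged into the context of every event that a
+    concrete handler tracks, via ``get_custom_callback_meta()``. For
+    illustration only (the values below are hypothetical):
+
+    .. code-block:: python
+
+        handler.get_custom_callback_meta()
+        # {"step": 3, "starts": 1, "ends": 1, "errors": 0, "text_ctr": 0,
+        #  "chain_starts": 0, "chain_ends": 0, "llm_starts": 1, "llm_ends": 1,
+        #  "llm_streams": 1, "tool_starts": 0, "tool_ends": 0, "agent_ends": 0}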
+ """ + + def __init__(self) -> None: + self.step = 0 + + self.starts = 0 + self.ends = 0 + self.errors = 0 + self.text_ctr = 0 + + self.ignore_llm_ = False + self.ignore_chain_ = False + self.ignore_agent_ = False + self.ignore_retriever_ = False + self.always_verbose_ = False + + self.chain_starts = 0 + self.chain_ends = 0 + + self.llm_starts = 0 + self.llm_ends = 0 + self.llm_streams = 0 + + self.tool_starts = 0 + self.tool_ends = 0 + + self.agent_ends = 0 + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return self.always_verbose_ + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return self.ignore_retriever_ + + def get_custom_callback_meta(self) -> Dict[str, Any]: + return { + "step": self.step, + "starts": self.starts, + "ends": self.ends, + "errors": self.errors, + "text_ctr": self.text_ctr, + "chain_starts": self.chain_starts, + "chain_ends": self.chain_ends, + "llm_starts": self.llm_starts, + "llm_ends": self.llm_ends, + "llm_streams": self.llm_streams, + "tool_starts": self.tool_starts, + "tool_ends": self.tool_ends, + "agent_ends": self.agent_ends, + } + + def reset_callback_meta(self) -> None: + """Reset the callback metadata.""" + self.step = 0 + + self.starts = 0 + self.ends = 0 + self.errors = 0 + self.text_ctr = 0 + + self.ignore_llm_ = False + self.ignore_chain_ = False + self.ignore_agent_ = False + self.always_verbose_ = False + + self.chain_starts = 0 + self.chain_ends = 0 + + self.llm_starts = 0 + self.llm_ends = 0 + self.llm_streams = 0 + + self.tool_starts = 0 + self.tool_ends = 0 + + self.agent_ends = 0 + + return None + + +class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): + """Callback Handler that logs to Aim. + + Parameters: + repo (:obj:`str`, optional): Aim repository path or Repo object to which + Run object is bound. If skipped, default Repo is used. + experiment_name (:obj:`str`, optional): Sets Run's `experiment` property. + 'default' if not specified. Can be used later to query runs/sequences. + system_tracking_interval (:obj:`int`, optional): Sets the tracking interval + in seconds for system usage metrics (CPU, Memory, etc.). Set to `None` + to disable system metrics tracking. + log_system_params (:obj:`bool`, optional): Enable/Disable logging of system + params such as installed packages, git info, environment variables, etc. + + This handler will utilize the associated callback method called and formats + the input of each callback function with metadata regarding the state of LLM run + and then logs the response to Aim. 
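+
+    Example (a minimal usage sketch; the repo path, experiment name and the
+    OpenAI API key below are placeholders):
+
+    .. code-block:: python
+
+        from langchain_community.callbacks import AimCallbackHandler
+        from langchain_community.llms import OpenAI
+
+        aim_callback = AimCallbackHandler(repo=".", experiment_name="llm_testing")
+        llm = OpenAI(
+            temperature=0,
+            callbacks=[aim_callback],
+            openai_api_key="API_KEY_HERE",
+        )
+        llm.generate(["Tell me a fact about the Moon."])
+        aim_callback.flush_tracker(langchain_asset=llm, reset=False, finish=True)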
+ """ + + def __init__( + self, + repo: Optional[str] = None, + experiment_name: Optional[str] = None, + system_tracking_interval: Optional[int] = 10, + log_system_params: bool = True, + ) -> None: + """Initialize callback handler.""" + + super().__init__() + + aim = import_aim() + self.repo = repo + self.experiment_name = experiment_name + self.system_tracking_interval = system_tracking_interval + self.log_system_params = log_system_params + self._run = aim.Run( + repo=self.repo, + experiment=self.experiment_name, + system_tracking_interval=self.system_tracking_interval, + log_system_params=self.log_system_params, + ) + self._run_hash = self._run.hash + self.action_records: list = [] + + def setup(self, **kwargs: Any) -> None: + aim = import_aim() + + if not self._run: + if self._run_hash: + self._run = aim.Run( + self._run_hash, + repo=self.repo, + system_tracking_interval=self.system_tracking_interval, + ) + else: + self._run = aim.Run( + repo=self.repo, + experiment=self.experiment_name, + system_tracking_interval=self.system_tracking_interval, + log_system_params=self.log_system_params, + ) + self._run_hash = self._run.hash + + if kwargs: + for key, value in kwargs.items(): + self._run.set(key, value, strict=False) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + aim = import_aim() + + self.step += 1 + self.llm_starts += 1 + self.starts += 1 + + resp = {"action": "on_llm_start"} + resp.update(self.get_custom_callback_meta()) + + prompts_res = deepcopy(prompts) + + self._run.track( + [aim.Text(prompt) for prompt in prompts_res], + name="on_llm_start", + context=resp, + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + aim = import_aim() + self.step += 1 + self.llm_ends += 1 + self.ends += 1 + + resp = {"action": "on_llm_end"} + resp.update(self.get_custom_callback_meta()) + + response_res = deepcopy(response) + + generated = [ + aim.Text(generation.text) + for generations in response_res.generations + for generation in generations + ] + self._run.track( + generated, + name="on_llm_end", + context=resp, + ) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.step += 1 + self.llm_streams += 1 + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + self.step += 1 + self.errors += 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + aim = import_aim() + self.step += 1 + self.chain_starts += 1 + self.starts += 1 + + resp = {"action": "on_chain_start"} + resp.update(self.get_custom_callback_meta()) + + inputs_res = deepcopy(inputs) + + self._run.track( + aim.Text(inputs_res["input"]), name="on_chain_start", context=resp + ) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + aim = import_aim() + self.step += 1 + self.chain_ends += 1 + self.ends += 1 + + resp = {"action": "on_chain_end"} + resp.update(self.get_custom_callback_meta()) + + outputs_res = deepcopy(outputs) + + self._run.track( + aim.Text(outputs_res["output"]), name="on_chain_end", context=resp + ) + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + self.step += 1 + self.errors += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, 
**kwargs: Any + ) -> None: + """Run when tool starts running.""" + aim = import_aim() + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp = {"action": "on_tool_start"} + resp.update(self.get_custom_callback_meta()) + + self._run.track(aim.Text(input_str), name="on_tool_start", context=resp) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + aim = import_aim() + self.step += 1 + self.tool_ends += 1 + self.ends += 1 + + resp = {"action": "on_tool_end"} + resp.update(self.get_custom_callback_meta()) + + self._run.track(aim.Text(output), name="on_tool_end", context=resp) + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + self.step += 1 + self.errors += 1 + + def on_text(self, text: str, **kwargs: Any) -> None: + """ + Run when agent is ending. + """ + self.step += 1 + self.text_ctr += 1 + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + aim = import_aim() + self.step += 1 + self.agent_ends += 1 + self.ends += 1 + + resp = {"action": "on_agent_finish"} + resp.update(self.get_custom_callback_meta()) + + finish_res = deepcopy(finish) + + text = "OUTPUT:\n{}\n\nLOG:\n{}".format( + finish_res.return_values["output"], finish_res.log + ) + self._run.track(aim.Text(text), name="on_agent_finish", context=resp) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + aim = import_aim() + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp = { + "action": "on_agent_action", + "tool": action.tool, + } + resp.update(self.get_custom_callback_meta()) + + action_res = deepcopy(action) + + text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format( + action_res.tool_input, action_res.log + ) + self._run.track(aim.Text(text), name="on_agent_action", context=resp) + + def flush_tracker( + self, + repo: Optional[str] = None, + experiment_name: Optional[str] = None, + system_tracking_interval: Optional[int] = 10, + log_system_params: bool = True, + langchain_asset: Any = None, + reset: bool = True, + finish: bool = False, + ) -> None: + """Flush the tracker and reset the session. + + Args: + repo (:obj:`str`, optional): Aim repository path or Repo object to which + Run object is bound. If skipped, default Repo is used. + experiment_name (:obj:`str`, optional): Sets Run's `experiment` property. + 'default' if not specified. Can be used later to query runs/sequences. + system_tracking_interval (:obj:`int`, optional): Sets the tracking interval + in seconds for system usage metrics (CPU, Memory, etc.). Set to `None` + to disable system metrics tracking. + log_system_params (:obj:`bool`, optional): Enable/Disable logging of system + params such as installed packages, git info, environment variables, etc. + langchain_asset: The langchain asset to save. + reset: Whether to reset the session. + finish: Whether to finish the run. 
+ + Returns: + None + """ + + if langchain_asset: + try: + for key, value in langchain_asset.dict().items(): + self._run.set(key, value, strict=False) + except Exception: + pass + + if finish or reset: + self._run.close() + self.reset_callback_meta() + if reset: + self.__init__( # type: ignore + repo=repo if repo else self.repo, + experiment_name=experiment_name + if experiment_name + else self.experiment_name, + system_tracking_interval=system_tracking_interval + if system_tracking_interval + else self.system_tracking_interval, + log_system_params=log_system_params + if log_system_params + else self.log_system_params, + ) diff --git a/libs/community/langchain_community/callbacks/argilla_callback.py b/libs/community/langchain_community/callbacks/argilla_callback.py new file mode 100644 index 00000000000..b280e0ce1a6 --- /dev/null +++ b/libs/community/langchain_community/callbacks/argilla_callback.py @@ -0,0 +1,352 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult +from packaging.version import parse + + +class ArgillaCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs into Argilla. + + Args: + dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must + exist in advance. If you need help on how to create a `FeedbackDataset` in + Argilla, please visit + https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html. + workspace_name: name of the workspace in Argilla where the specified + `FeedbackDataset` lives in. Defaults to `None`, which means that the + default workspace will be used. + api_url: URL of the Argilla Server that we want to use, and where the + `FeedbackDataset` lives in. Defaults to `None`, which means that either + `ARGILLA_API_URL` environment variable or the default will be used. + api_key: API Key to connect to the Argilla Server. Defaults to `None`, which + means that either `ARGILLA_API_KEY` environment variable or the default + will be used. + + Raises: + ImportError: if the `argilla` package is not installed. + ConnectionError: if the connection to Argilla fails. + FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails. + + Examples: + >>> from langchain_community.llms import OpenAI + >>> from langchain_community.callbacks import ArgillaCallbackHandler + >>> argilla_callback = ArgillaCallbackHandler( + ... dataset_name="my-dataset", + ... workspace_name="my-workspace", + ... api_url="http://localhost:6900", + ... api_key="argilla.apikey", + ... ) + >>> llm = OpenAI( + ... temperature=0, + ... callbacks=[argilla_callback], + ... verbose=True, + ... openai_api_key="API_KEY_HERE", + ... ) + >>> llm.generate([ + ... "What is the best NLP-annotation tool out there? (no bias at all)", + ... ]) + "Argilla, no doubt about it." + """ + + REPO_URL: str = "https://github.com/argilla-io/argilla" + ISSUES_URL: str = f"{REPO_URL}/issues" + BLOG_URL: str = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501 + + DEFAULT_API_URL: str = "http://localhost:6900" + + def __init__( + self, + dataset_name: str, + workspace_name: Optional[str] = None, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: + """Initializes the `ArgillaCallbackHandler`. + + Args: + dataset_name: name of the `FeedbackDataset` in Argilla. 
Note that it must
+                exist in advance. If you need help on how to create a `FeedbackDataset`
+                in Argilla, please visit
+                https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
+            workspace_name: name of the workspace in Argilla where the specified
+                `FeedbackDataset` lives in. Defaults to `None`, which means that the
+                default workspace will be used.
+            api_url: URL of the Argilla Server that we want to use, and where the
+                `FeedbackDataset` lives in. Defaults to `None`, which means that either
+                `ARGILLA_API_URL` environment variable or the default will be used.
+            api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
+                means that either `ARGILLA_API_KEY` environment variable or the default
+                will be used.
+
+        Raises:
+            ImportError: if the `argilla` package is not installed.
+            ConnectionError: if the connection to Argilla fails.
+            FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
+        """
+
+        super().__init__()
+
+        # Import Argilla (not via `import_argilla` to keep hints in IDEs)
+        try:
+            import argilla as rg  # noqa: F401
+
+            self.ARGILLA_VERSION = rg.__version__
+        except ImportError:
+            raise ImportError(
+                "To use the Argilla callback manager you need to have the `argilla` "
+                "Python package installed. Please install it with `pip install argilla`"
+            )
+
+        # Check whether the Argilla version is compatible
+        if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
+            raise ImportError(
+                f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
+                "`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
+                "upgrade `argilla` with `pip install --upgrade argilla`."
+            )
+
+        # Show a warning message if Argilla will assume the default values will be used
+        if api_url is None and os.getenv("ARGILLA_API_URL") is None:
+            warnings.warn(
+                (
+                    "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
+                    f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
+                    " default API URL in Argilla Quickstart."
+                ),
+            )
+            api_url = self.DEFAULT_API_URL
+
+        if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
+            self.DEFAULT_API_KEY = (
+                "admin.apikey"
+                if parse(self.ARGILLA_VERSION) < parse("1.11.0")
+                else "owner.apikey"
+            )
+
+            warnings.warn(
+                (
+                    "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
+                    f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
+                    " default API key in Argilla Quickstart."
+                ),
+            )
+            api_key = self.DEFAULT_API_KEY
+
+        # Connect to Argilla with the provided credentials, if applicable
+        try:
+            rg.init(api_key=api_key, api_url=api_url)
+        except Exception as e:
+            raise ConnectionError(
+                f"Could not connect to Argilla with exception: '{e}'.\n"
+                "Please check your `api_key` and `api_url`, and make sure that "
+                "the Argilla server is up and running. If the problem persists "
+                f"please report it to {self.ISSUES_URL} as an `integration` issue."
+ ) from e + + # Set the Argilla variables + self.dataset_name = dataset_name + self.workspace_name = workspace_name or rg.get_workspace() + + # Retrieve the `FeedbackDataset` from Argilla (without existing records) + try: + extra_args = {} + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + warnings.warn( + f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or" + " higher is recommended.", + UserWarning, + ) + extra_args = {"with_records": False} + self.dataset = rg.FeedbackDataset.from_argilla( + name=self.dataset_name, + workspace=self.workspace_name, + **extra_args, + ) + except Exception as e: + raise FileNotFoundError( + f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`." + f"\nPlease check that the dataset with name={self.dataset_name} in the" + f" workspace={self.workspace_name} exists in advance. If you need help" + " on how to create a `langchain`-compatible `FeedbackDataset` in" + f" Argilla, please visit {self.BLOG_URL}. If the problem persists" + f" please report it to {self.ISSUES_URL} as an `integration` issue." + ) from e + + supported_fields = ["prompt", "response"] + if supported_fields != [field.name for field in self.dataset.fields]: + raise ValueError( + f"`FeedbackDataset` with name={self.dataset_name} in the workspace=" + f"{self.workspace_name} had fields that are not supported yet for the" + f"`langchain` integration. Supported fields are: {supported_fields}," + f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501 + " For more information on how to create a `langchain`-compatible" + f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}." + ) + + self.prompts: Dict[str, List[str]] = {} + + warnings.warn( + ( + "The `ArgillaCallbackHandler` is currently in beta and is subject to" + " change based on updates to `langchain`. Please report any issues to" + f" {self.ISSUES_URL} as an `integration` issue." + ), + ) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Save the prompts in memory when an LLM starts.""" + self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts}) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing when a new token is generated.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Log records to Argilla when an LLM ends.""" + # Do nothing if there's a parent_run_id, since we will log the records when + # the chain ends + if kwargs["parent_run_id"]: + return + + # Creates the records and adds them to the `FeedbackDataset` + prompts = self.prompts[str(kwargs["run_id"])] + for prompt, generations in zip(prompts, response.generations): + self.dataset.add_records( + records=[ + { + "fields": { + "prompt": prompt, + "response": generation.text.strip(), + }, + } + for generation in generations + ] + ) + + # Pop current run from `self.runs` + self.prompts.pop(str(kwargs["run_id"])) + + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + # Push the records to Argilla + self.dataset.push_to_argilla() + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM outputs an error.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """If the key `input` is in `inputs`, then save it in `self.prompts` using + either the `parent_run_id` or the `run_id` as the key. 
This is done so that + we don't log the same input prompt twice, once when the LLM starts and once + when the chain starts. + """ + if "input" in inputs: + self.prompts.update( + { + str(kwargs["parent_run_id"] or kwargs["run_id"]): ( + inputs["input"] + if isinstance(inputs["input"], list) + else [inputs["input"]] + ) + } + ) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """If either the `parent_run_id` or the `run_id` is in `self.prompts`, then + log the outputs to Argilla, and pop the run from `self.prompts`. The behavior + differs if the output is a list or not. + """ + if not any( + key in self.prompts + for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])] + ): + return + prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get( + str(kwargs["run_id"]) + ) + for chain_output_key, chain_output_val in outputs.items(): + if isinstance(chain_output_val, list): + # Creates the records and adds them to the `FeedbackDataset` + self.dataset.add_records( + records=[ + { + "fields": { + "prompt": prompt, + "response": output["text"].strip(), + }, + } + for prompt, output in zip( + prompts, # type: ignore + chain_output_val, + ) + ] + ) + else: + # Creates the records and adds them to the `FeedbackDataset` + self.dataset.add_records( + records=[ + { + "fields": { + "prompt": " ".join(prompts), # type: ignore + "response": chain_output_val.strip(), + }, + } + ] + ) + + # Pop current run from `self.runs` + if str(kwargs["parent_run_id"]) in self.prompts: + self.prompts.pop(str(kwargs["parent_run_id"])) + if str(kwargs["run_id"]) in self.prompts: + self.prompts.pop(str(kwargs["run_id"])) + + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + # Push the records to Argilla + self.dataset.push_to_argilla() + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM chain outputs an error.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing when tool starts.""" + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Do nothing when agent takes a specific action.""" + pass + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Do nothing when tool ends.""" + pass + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when tool outputs an error.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Do nothing""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Do nothing""" + pass diff --git a/libs/community/langchain_community/callbacks/arize_callback.py b/libs/community/langchain_community/callbacks/arize_callback.py new file mode 100644 index 00000000000..44212b61917 --- /dev/null +++ b/libs/community/langchain_community/callbacks/arize_callback.py @@ -0,0 +1,213 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +from langchain_community.callbacks.utils import import_pandas + + +class ArizeCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs to Arize.""" + + def __init__( + self, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + SPACE_KEY: Optional[str] = None, 
+ API_KEY: Optional[str] = None, + ) -> None: + """Initialize callback handler.""" + + super().__init__() + self.model_id = model_id + self.model_version = model_version + self.space_key = SPACE_KEY + self.api_key = API_KEY + self.prompt_records: List[str] = [] + self.response_records: List[str] = [] + self.prediction_ids: List[str] = [] + self.pred_timestamps: List[int] = [] + self.response_embeddings: List[float] = [] + self.prompt_embeddings: List[float] = [] + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_tokens = 0 + self.step = 0 + + from arize.pandas.embeddings import EmbeddingGenerator, UseCases + from arize.pandas.logger import Client + + self.generator = EmbeddingGenerator.from_use_case( + use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION, + model_name="distilbert-base-uncased", + tokenizer_max_length=512, + batch_size=256, + ) + self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY) + if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY": + raise ValueError("❌ CHANGE SPACE AND API KEYS") + else: + print("βœ… Arize client setup done! Now you can start using Arize!") + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + for prompt in prompts: + self.prompt_records.append(prompt.replace("\n", "")) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + pd = import_pandas() + from arize.utils.types import ( + EmbeddingColumnNames, + Environments, + ModelTypes, + Schema, + ) + + # Safe check if 'llm_output' and 'token_usage' exist + if response.llm_output and "token_usage" in response.llm_output: + self.prompt_tokens = response.llm_output["token_usage"].get( + "prompt_tokens", 0 + ) + self.total_tokens = response.llm_output["token_usage"].get( + "total_tokens", 0 + ) + self.completion_tokens = response.llm_output["token_usage"].get( + "completion_tokens", 0 + ) + else: + self.prompt_tokens = ( + self.total_tokens + ) = self.completion_tokens = 0 # assign default value + + for generations in response.generations: + for generation in generations: + prompt = self.prompt_records[self.step] + self.step = self.step + 1 + prompt_embedding = pd.Series( + self.generator.generate_embeddings( + text_col=pd.Series(prompt.replace("\n", " ")) + ).reset_index(drop=True) + ) + + # Assigning text to response_text instead of response + response_text = generation.text.replace("\n", " ") + response_embedding = pd.Series( + self.generator.generate_embeddings( + text_col=pd.Series(generation.text.replace("\n", " ")) + ).reset_index(drop=True) + ) + pred_timestamp = datetime.now().timestamp() + + # Define the columns and data + columns = [ + "prediction_ts", + "response", + "prompt", + "response_vector", + "prompt_vector", + "prompt_token", + "completion_token", + "total_token", + ] + data = [ + [ + pred_timestamp, + response_text, + prompt, + response_embedding[0], + prompt_embedding[0], + self.prompt_tokens, + self.total_tokens, + self.completion_tokens, + ] + ] + + # Create the DataFrame + df = pd.DataFrame(data, columns=columns) + + # Declare prompt and response columns + prompt_columns = EmbeddingColumnNames( + vector_column_name="prompt_vector", data_column_name="prompt" + ) + + response_columns = EmbeddingColumnNames( + vector_column_name="response_vector", data_column_name="response" + ) + + schema = Schema( + timestamp_column_name="prediction_ts", + tag_column_names=[ + "prompt_token", + "completion_token", + 
"total_token", + ], + prompt_column_names=prompt_columns, + response_column_names=response_columns, + ) + + response_from_arize = self.arize_client.log( + dataframe=df, + schema=schema, + model_id=self.model_id, + model_version=self.model_version, + model_type=ModelTypes.GENERATIVE_LLM, + environment=Environments.PRODUCTION, + ) + if response_from_arize.status_code == 200: + print("βœ… Successfully logged data to Arize!") + else: + print(f'❌ Logging failed "{response_from_arize.text}"') + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Do nothing.""" + pass + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + pass + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + pass diff --git a/libs/community/langchain_community/callbacks/arthur_callback.py b/libs/community/langchain_community/callbacks/arthur_callback.py new file mode 100644 index 00000000000..a5fce582ed1 --- /dev/null +++ b/libs/community/langchain_community/callbacks/arthur_callback.py @@ -0,0 +1,296 @@ +"""ArthurAI's Callback Handler.""" +from __future__ import annotations + +import os +import uuid +from collections import defaultdict +from datetime import datetime +from time import time +from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional + +import numpy as np +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +if TYPE_CHECKING: + import arthurai + from arthurai.core.models import ArthurModel + +PROMPT_TOKENS = "prompt_tokens" +COMPLETION_TOKENS = "completion_tokens" +TOKEN_USAGE = "token_usage" +FINISH_REASON = "finish_reason" +DURATION = "duration" + + +def _lazy_load_arthur() -> arthurai: + """Lazy load Arthur.""" + try: + import arthurai + except ImportError as e: + raise ImportError( + "To use the ArthurCallbackHandler you need the" + " `arthurai` package. Please install it with" + " `pip install arthurai`.", + e, + ) + + return arthurai + + +class ArthurCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs to Arthur platform. + + Arthur helps enterprise teams optimize model operations + and performance at scale. The Arthur API tracks model + performance, explainability, and fairness across tabular, + NLP, and CV models. Our API is model- and platform-agnostic, + and continuously scales with complex and dynamic enterprise needs. 
+ To learn more about Arthur, visit our website at + https://www.arthur.ai/ or read the Arthur docs at + https://docs.arthur.ai/ + """ + + def __init__( + self, + arthur_model: ArthurModel, + ) -> None: + """Initialize callback handler.""" + super().__init__() + arthurai = _lazy_load_arthur() + Stage = arthurai.common.constants.Stage + ValueType = arthurai.common.constants.ValueType + self.arthur_model = arthur_model + # save the attributes of this model to be used when preparing + # inferences to log to Arthur in on_llm_end() + self.attr_names = set([a.name for a in self.arthur_model.get_attributes()]) + self.input_attr = [ + x + for x in self.arthur_model.get_attributes() + if x.stage == Stage.ModelPipelineInput + and x.value_type == ValueType.Unstructured_Text + ][0].name + self.output_attr = [ + x + for x in self.arthur_model.get_attributes() + if x.stage == Stage.PredictedValue + and x.value_type == ValueType.Unstructured_Text + ][0].name + self.token_likelihood_attr = None + if ( + len( + [ + x + for x in self.arthur_model.get_attributes() + if x.value_type == ValueType.TokenLikelihoods + ] + ) + > 0 + ): + self.token_likelihood_attr = [ + x + for x in self.arthur_model.get_attributes() + if x.value_type == ValueType.TokenLikelihoods + ][0].name + + self.run_map: DefaultDict[str, Any] = defaultdict(dict) + + @classmethod + def from_credentials( + cls, + model_id: str, + arthur_url: Optional[str] = "https://app.arthur.ai", + arthur_login: Optional[str] = None, + arthur_password: Optional[str] = None, + ) -> ArthurCallbackHandler: + """Initialize callback handler from Arthur credentials. + + Args: + model_id (str): The ID of the arthur model to log to. + arthur_url (str, optional): The URL of the Arthur instance to log to. + Defaults to "https://app.arthur.ai". + arthur_login (str, optional): The login to use to connect to Arthur. + Defaults to None. + arthur_password (str, optional): The password to use to connect to + Arthur. Defaults to None. + + Returns: + ArthurCallbackHandler: The initialized callback handler. + """ + arthurai = _lazy_load_arthur() + ArthurAI = arthurai.ArthurAI + ResponseClientError = arthurai.common.exceptions.ResponseClientError + + # connect to Arthur + if arthur_login is None: + try: + arthur_api_key = os.environ["ARTHUR_API_KEY"] + except KeyError: + raise ValueError( + "No Arthur authentication provided. Either give" + " a login to the ArthurCallbackHandler" + " or set an ARTHUR_API_KEY as an environment variable." + ) + arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key) + else: + if arthur_password is None: + arthur = ArthurAI(url=arthur_url, login=arthur_login) + else: + arthur = ArthurAI( + url=arthur_url, login=arthur_login, password=arthur_password + ) + # get model from Arthur by the provided model ID + try: + arthur_model = arthur.get_model(model_id) + except ResponseClientError: + raise ValueError( + f"Was unable to retrieve model with id {model_id} from Arthur." + " Make sure the ID corresponds to a model that is currently" + " registered with your Arthur account." 
+ ) + return cls(arthur_model) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """On LLM start, save the input prompts""" + run_id = kwargs["run_id"] + self.run_map[run_id]["input_texts"] = prompts + self.run_map[run_id]["start_time"] = time() + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """On LLM end, send data to Arthur.""" + try: + import pytz # type: ignore[import] + except ImportError as e: + raise ImportError( + "Could not import pytz. Please install it with 'pip install pytz'." + ) from e + + run_id = kwargs["run_id"] + + # get the run params from this run ID, + # or raise an error if this run ID has no corresponding metadata in self.run_map + try: + run_map_data = self.run_map[run_id] + except KeyError as e: + raise KeyError( + "This function has been called with a run_id" + " that was never registered in on_llm_start()." + " Restart and try running the LLM again" + ) from e + + # mark the duration time between on_llm_start() and on_llm_end() + time_from_start_to_end = time() - run_map_data["start_time"] + + # create inferences to log to Arthur + inferences = [] + for i, generations in enumerate(response.generations): + for generation in generations: + inference = { + "partner_inference_id": str(uuid.uuid4()), + "inference_timestamp": datetime.now(tz=pytz.UTC), + self.input_attr: run_map_data["input_texts"][i], + self.output_attr: generation.text, + } + + if generation.generation_info is not None: + # add finish reason to the inference + # if generation info contains a finish reason and + # if the ArthurModel was registered to monitor finish_reason + if ( + FINISH_REASON in generation.generation_info + and FINISH_REASON in self.attr_names + ): + inference[FINISH_REASON] = generation.generation_info[ + FINISH_REASON + ] + + # add token likelihoods data to the inference if the ArthurModel + # was registered to monitor token likelihoods + logprobs_data = generation.generation_info["logprobs"] + if ( + logprobs_data is not None + and self.token_likelihood_attr is not None + ): + logprobs = logprobs_data["top_logprobs"] + likelihoods = [ + {k: np.exp(v) for k, v in logprobs[i].items()} + for i in range(len(logprobs)) + ] + inference[self.token_likelihood_attr] = likelihoods + + # add token usage counts to the inference if the + # ArthurModel was registered to monitor token usage + if ( + isinstance(response.llm_output, dict) + and TOKEN_USAGE in response.llm_output + ): + token_usage = response.llm_output[TOKEN_USAGE] + if ( + PROMPT_TOKENS in token_usage + and PROMPT_TOKENS in self.attr_names + ): + inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS] + if ( + COMPLETION_TOKENS in token_usage + and COMPLETION_TOKENS in self.attr_names + ): + inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS] + + # add inference duration to the inference if the ArthurModel + # was registered to monitor inference duration + if DURATION in self.attr_names: + inference[DURATION] = time_from_start_to_end + + inferences.append(inference) + + # send inferences to arthur + self.arthur_model.send_inferences(inferences) + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """On chain start, do nothing.""" + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """On chain end, do nothing.""" + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM outputs an error.""" + + def 
on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """On new token, pass.""" + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM chain outputs an error.""" + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing when tool starts.""" + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Do nothing when agent takes a specific action.""" + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Do nothing when tool ends.""" + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when tool outputs an error.""" + + def on_text(self, text: str, **kwargs: Any) -> None: + """Do nothing""" + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Do nothing""" diff --git a/libs/community/langchain_community/callbacks/clearml_callback.py b/libs/community/langchain_community/callbacks/clearml_callback.py new file mode 100644 index 00000000000..71f30fccdd8 --- /dev/null +++ b/libs/community/langchain_community/callbacks/clearml_callback.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +from langchain_community.callbacks.utils import ( + BaseMetadataCallbackHandler, + flatten_dict, + hash_string, + import_pandas, + import_spacy, + import_textstat, + load_json, +) + +if TYPE_CHECKING: + import pandas as pd + + +def import_clearml() -> Any: + """Import the clearml python package and raise an error if it is not installed.""" + try: + import clearml # noqa: F401 + except ImportError: + raise ImportError( + "To use the clearml callback manager you need to have the `clearml` python " + "package installed. Please install it with `pip install clearml`" + ) + return clearml + + +class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): + """Callback Handler that logs to ClearML. + + Parameters: + job_type (str): The type of clearml task such as "inference", "testing" or "qc" + project_name (str): The clearml project name + tags (list): Tags to add to the task + task_name (str): Name of the clearml task + visualize (bool): Whether to visualize the run. + complexity_metrics (bool): Whether to log complexity metrics + stream_logs (bool): Whether to stream callback actions to ClearML + + This handler will utilize the associated callback method and formats + the input of each callback function with metadata regarding the state of LLM run, + and adds the response to the list of records for both the {method}_records and + action. It then logs the response to the ClearML console. 
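+
+    Example:
+        A minimal sketch of wiring the handler into an LLM, assuming the
+        `clearml` package (and the spaCy `en_core_web_sm` model used for the
+        text analysis) is installed and configured; the project and task names
+        are placeholders, and the imports follow the pattern used in the other
+        handler docstrings.
+
+        >>> from langchain_community.llms import OpenAI
+        >>> from langchain_community.callbacks import ClearMLCallbackHandler
+        >>> clearml_callback = ClearMLCallbackHandler(
+        ...     task_type="inference",
+        ...     project_name="langchain_callback_demo",
+        ...     task_name="llm",
+        ...     complexity_metrics=True,
+        ...     stream_logs=True,
+        ... )
+        >>> llm = OpenAI(temperature=0, callbacks=[clearml_callback])
+        >>> llm.generate(["Tell me a joke"])
+        >>> clearml_callback.flush_tracker(
+        ...     langchain_asset=llm, name="joke", finish=True
+        ... )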
+ """ + + def __init__( + self, + task_type: Optional[str] = "inference", + project_name: Optional[str] = "langchain_callback_demo", + tags: Optional[Sequence] = None, + task_name: Optional[str] = None, + visualize: bool = False, + complexity_metrics: bool = False, + stream_logs: bool = False, + ) -> None: + """Initialize callback handler.""" + + clearml = import_clearml() + spacy = import_spacy() + super().__init__() + + self.task_type = task_type + self.project_name = project_name + self.tags = tags + self.task_name = task_name + self.visualize = visualize + self.complexity_metrics = complexity_metrics + self.stream_logs = stream_logs + + self.temp_dir = tempfile.TemporaryDirectory() + + # Check if ClearML task already exists (e.g. in pipeline) + if clearml.Task.current_task(): + self.task = clearml.Task.current_task() + else: + self.task = clearml.Task.init( # type: ignore + task_type=self.task_type, + project_name=self.project_name, + tags=self.tags, + task_name=self.task_name, + output_uri=True, + ) + self.logger = self.task.get_logger() + warning = ( + "The clearml callback is currently in beta and is subject to change " + "based on updates to `langchain`. Please report any issues to " + "https://github.com/allegroai/clearml/issues with the tag `langchain`." + ) + self.logger.report_text(warning, level=30, print_console=True) + self.callback_columns: list = [] + self.action_records: list = [] + self.complexity_metrics = complexity_metrics + self.visualize = visualize + self.nlp = spacy.load("en_core_web_sm") + + def _init_resp(self) -> Dict: + return {k: None for k in self.callback_columns} + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + self.step += 1 + self.llm_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update({"action": "on_llm_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + for prompt in prompts: + prompt_resp = deepcopy(resp) + prompt_resp["prompts"] = prompt + self.on_llm_start_records.append(prompt_resp) + self.action_records.append(prompt_resp) + if self.stream_logs: + self.logger.report_text(prompt_resp) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.step += 1 + self.llm_streams += 1 + + resp = self._init_resp() + resp.update({"action": "on_llm_new_token", "token": token}) + resp.update(self.get_custom_callback_meta()) + + self.on_llm_token_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.step += 1 + self.llm_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update({"action": "on_llm_end"}) + resp.update(flatten_dict(response.llm_output or {})) + resp.update(self.get_custom_callback_meta()) + + for generations in response.generations: + for generation in generations: + generation_resp = deepcopy(resp) + generation_resp.update(flatten_dict(generation.dict())) + generation_resp.update(self.analyze_text(generation.text)) + self.on_llm_end_records.append(generation_resp) + self.action_records.append(generation_resp) + if self.stream_logs: + self.logger.report_text(generation_resp) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + self.step += 1 + self.errors += 1 + + def on_chain_start( + self, serialized: 
Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + self.step += 1 + self.chain_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update({"action": "on_chain_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + chain_input = inputs.get("input", inputs.get("human_input")) + + if isinstance(chain_input, str): + input_resp = deepcopy(resp) + input_resp["input"] = chain_input + self.on_chain_start_records.append(input_resp) + self.action_records.append(input_resp) + if self.stream_logs: + self.logger.report_text(input_resp) + elif isinstance(chain_input, list): + for inp in chain_input: + input_resp = deepcopy(resp) + input_resp.update(inp) + self.on_chain_start_records.append(input_resp) + self.action_records.append(input_resp) + if self.stream_logs: + self.logger.report_text(input_resp) + else: + raise ValueError("Unexpected data format provided!") + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + self.step += 1 + self.chain_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update( + { + "action": "on_chain_end", + "outputs": outputs.get("output", outputs.get("text")), + } + ) + resp.update(self.get_custom_callback_meta()) + + self.on_chain_end_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + self.step += 1 + self.errors += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update({"action": "on_tool_start", "input_str": input_str}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + self.on_tool_start_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + self.step += 1 + self.tool_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update({"action": "on_tool_end", "output": output}) + resp.update(self.get_custom_callback_meta()) + + self.on_tool_end_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + self.step += 1 + self.errors += 1 + + def on_text(self, text: str, **kwargs: Any) -> None: + """ + Run when agent is ending. 
+ """ + self.step += 1 + self.text_ctr += 1 + + resp = self._init_resp() + resp.update({"action": "on_text", "text": text}) + resp.update(self.get_custom_callback_meta()) + + self.on_text_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.step += 1 + self.agent_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update( + { + "action": "on_agent_finish", + "output": finish.return_values["output"], + "log": finish.log, + } + ) + resp.update(self.get_custom_callback_meta()) + + self.on_agent_finish_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update( + { + "action": "on_agent_action", + "tool": action.tool, + "tool_input": action.tool_input, + "log": action.log, + } + ) + resp.update(self.get_custom_callback_meta()) + self.on_agent_action_records.append(resp) + self.action_records.append(resp) + if self.stream_logs: + self.logger.report_text(resp) + + def analyze_text(self, text: str) -> dict: + """Analyze text using textstat and spacy. + + Parameters: + text (str): The text to analyze. + + Returns: + (dict): A dictionary containing the complexity metrics. + """ + resp = {} + textstat = import_textstat() + spacy = import_spacy() + if self.complexity_metrics: + text_complexity_metrics = { + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "smog_index": textstat.smog_index(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + "automated_readability_index": textstat.automated_readability_index( + text + ), + "dale_chall_readability_score": textstat.dale_chall_readability_score( + text + ), + "difficult_words": textstat.difficult_words(text), + "linsear_write_formula": textstat.linsear_write_formula(text), + "gunning_fog": textstat.gunning_fog(text), + "text_standard": textstat.text_standard(text), + "fernandez_huerta": textstat.fernandez_huerta(text), + "szigriszt_pazos": textstat.szigriszt_pazos(text), + "gutierrez_polini": textstat.gutierrez_polini(text), + "crawford": textstat.crawford(text), + "gulpease_index": textstat.gulpease_index(text), + "osman": textstat.osman(text), + } + resp.update(text_complexity_metrics) + + if self.visualize and self.nlp and self.temp_dir.name is not None: + doc = self.nlp(text) + + dep_out = spacy.displacy.render( # type: ignore + doc, style="dep", jupyter=False, page=True + ) + dep_output_path = Path( + self.temp_dir.name, hash_string(f"dep-{text}") + ".html" + ) + dep_output_path.open("w", encoding="utf-8").write(dep_out) + + ent_out = spacy.displacy.render( # type: ignore + doc, style="ent", jupyter=False, page=True + ) + ent_output_path = Path( + self.temp_dir.name, hash_string(f"ent-{text}") + ".html" + ) + ent_output_path.open("w", encoding="utf-8").write(ent_out) + + self.logger.report_media( + "Dependencies Plot", text, local_path=dep_output_path + ) + self.logger.report_media("Entities Plot", text, local_path=ent_output_path) + + return resp + + @staticmethod + def _build_llm_df( + base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping + ) -> pd.DataFrame: + base_df_fields = [field for field in 
base_df_fields if field in base_df] + rename_map = { + map_entry_k: map_entry_v + for map_entry_k, map_entry_v in rename_map.items() + if map_entry_k in base_df_fields + } + llm_df = base_df[base_df_fields].dropna(axis=1) + if rename_map: + llm_df = llm_df.rename(rename_map, axis=1) + return llm_df + + def _create_session_analysis_df(self) -> Any: + """Create a dataframe with all the information from the session.""" + pd = import_pandas() + on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records) + + llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df( + base_df=on_llm_end_records_df, + base_df_fields=["step", "prompts"] + + (["name"] if "name" in on_llm_end_records_df else ["id"]), + rename_map={"step": "prompt_step"}, + ) + complexity_metrics_columns = [] + visualizations_columns: List = [] + + if self.complexity_metrics: + complexity_metrics_columns = [ + "flesch_reading_ease", + "flesch_kincaid_grade", + "smog_index", + "coleman_liau_index", + "automated_readability_index", + "dale_chall_readability_score", + "difficult_words", + "linsear_write_formula", + "gunning_fog", + "text_standard", + "fernandez_huerta", + "szigriszt_pazos", + "gutierrez_polini", + "crawford", + "gulpease_index", + "osman", + ] + + llm_outputs_df = ClearMLCallbackHandler._build_llm_df( + on_llm_end_records_df, + [ + "step", + "text", + "token_usage_total_tokens", + "token_usage_prompt_tokens", + "token_usage_completion_tokens", + ] + + complexity_metrics_columns + + visualizations_columns, + {"step": "output_step", "text": "output"}, + ) + session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1) + return session_analysis_df + + def flush_tracker( + self, + name: Optional[str] = None, + langchain_asset: Any = None, + finish: bool = False, + ) -> None: + """Flush the tracker and setup the session. + + Everything after this will be a new table. + + Args: + name: Name of the performed session so far so it is identifiable + langchain_asset: The langchain asset to save. + finish: Whether to finish the run. 
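+
+        Example:
+            A sketch of a typical call; `clearml_callback` and `llm` are
+            placeholder names for an existing handler and the asset it traced.
+
+            >>> clearml_callback.flush_tracker(
+            ...     name="run-1", langchain_asset=llm, finish=False
+            ... )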
+ + Returns: + None + """ + pd = import_pandas() + clearml = import_clearml() + + # Log the action records + self.logger.report_table( + "Action Records", name, table_plot=pd.DataFrame(self.action_records) + ) + + # Session analysis + session_analysis_df = self._create_session_analysis_df() + self.logger.report_table( + "Session Analysis", name, table_plot=session_analysis_df + ) + + if self.stream_logs: + self.logger.report_text( + { + "action_records": pd.DataFrame(self.action_records), + "session_analysis": session_analysis_df, + } + ) + + if langchain_asset: + langchain_asset_path = Path(self.temp_dir.name, "model.json") + try: + langchain_asset.save(langchain_asset_path) + # Create output model and connect it to the task + output_model = clearml.OutputModel( + task=self.task, config_text=load_json(langchain_asset_path) + ) + output_model.update_weights( + weights_filename=str(langchain_asset_path), + auto_delete_file=False, + target_filename=name, + ) + except ValueError: + langchain_asset.save_agent(langchain_asset_path) + output_model = clearml.OutputModel( + task=self.task, config_text=load_json(langchain_asset_path) + ) + output_model.update_weights( + weights_filename=str(langchain_asset_path), + auto_delete_file=False, + target_filename=name, + ) + except NotImplementedError as e: + print("Could not save model.") + print(repr(e)) + pass + + # Cleanup after adding everything to ClearML + self.task.flush(wait_for_uploads=True) + self.temp_dir.cleanup() + self.temp_dir = tempfile.TemporaryDirectory() + self.reset_callback_meta() + + if finish: + self.task.close() diff --git a/libs/community/langchain_community/callbacks/comet_ml_callback.py b/libs/community/langchain_community/callbacks/comet_ml_callback.py new file mode 100644 index 00000000000..5493c947ae8 --- /dev/null +++ b/libs/community/langchain_community/callbacks/comet_ml_callback.py @@ -0,0 +1,645 @@ +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import Generation, LLMResult + +import langchain_community +from langchain_community.callbacks.utils import ( + BaseMetadataCallbackHandler, + flatten_dict, + import_pandas, + import_spacy, + import_textstat, +) + +LANGCHAIN_MODEL_NAME = "langchain-model" + + +def import_comet_ml() -> Any: + """Import comet_ml and raise an error if it is not installed.""" + try: + import comet_ml # noqa: F401 + except ImportError: + raise ImportError( + "To use the comet_ml callback manager you need to have the " + "`comet_ml` python package installed. 
Please install it with" + " `pip install comet_ml`" + ) + return comet_ml + + +def _get_experiment( + workspace: Optional[str] = None, project_name: Optional[str] = None +) -> Any: + comet_ml = import_comet_ml() + + experiment = comet_ml.Experiment( # type: ignore + workspace=workspace, + project_name=project_name, + ) + + return experiment + + +def _fetch_text_complexity_metrics(text: str) -> dict: + textstat = import_textstat() + text_complexity_metrics = { + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "smog_index": textstat.smog_index(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + "automated_readability_index": textstat.automated_readability_index(text), + "dale_chall_readability_score": textstat.dale_chall_readability_score(text), + "difficult_words": textstat.difficult_words(text), + "linsear_write_formula": textstat.linsear_write_formula(text), + "gunning_fog": textstat.gunning_fog(text), + "text_standard": textstat.text_standard(text), + "fernandez_huerta": textstat.fernandez_huerta(text), + "szigriszt_pazos": textstat.szigriszt_pazos(text), + "gutierrez_polini": textstat.gutierrez_polini(text), + "crawford": textstat.crawford(text), + "gulpease_index": textstat.gulpease_index(text), + "osman": textstat.osman(text), + } + return text_complexity_metrics + + +def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict: + pd = import_pandas() + metrics_df = pd.DataFrame(metrics) + metrics_summary = metrics_df.describe() + + return metrics_summary.to_dict() + + +class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): + """Callback Handler that logs to Comet. + + Parameters: + job_type (str): The type of comet_ml task such as "inference", + "testing" or "qc" + project_name (str): The comet_ml project name + tags (list): Tags to add to the task + task_name (str): Name of the comet_ml task + visualize (bool): Whether to visualize the run. + complexity_metrics (bool): Whether to log complexity metrics + stream_logs (bool): Whether to stream callback actions to Comet + + This handler will utilize the associated callback method and formats + the input of each callback function with metadata regarding the state of LLM run, + and adds the response to the list of records for both the {method}_records and + action. It then logs the response to Comet. 
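+
+    Example:
+        A minimal sketch of attaching the handler to an LLM, assuming the
+        `comet_ml` package is installed and configured (plus `textstat` when
+        complexity metrics are enabled); the project name, tags, and prompt are
+        placeholders, and the imports follow the pattern used in the other
+        handler docstrings.
+
+        >>> from langchain_community.llms import OpenAI
+        >>> from langchain_community.callbacks import CometCallbackHandler
+        >>> comet_callback = CometCallbackHandler(
+        ...     project_name="comet-langchain-demo",
+        ...     complexity_metrics=True,
+        ...     stream_logs=True,
+        ...     tags=["llm"],
+        ... )
+        >>> llm = OpenAI(temperature=0.9, callbacks=[comet_callback])
+        >>> llm.generate(["Tell me a joke"])
+        >>> comet_callback.flush_tracker(llm, finish=True)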
+ """ + + def __init__( + self, + task_type: Optional[str] = "inference", + workspace: Optional[str] = None, + project_name: Optional[str] = None, + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + stream_logs: bool = True, + ) -> None: + """Initialize callback handler.""" + + self.comet_ml = import_comet_ml() + super().__init__() + + self.task_type = task_type + self.workspace = workspace + self.project_name = project_name + self.tags = tags + self.visualizations = visualizations + self.complexity_metrics = complexity_metrics + self.custom_metrics = custom_metrics + self.stream_logs = stream_logs + self.temp_dir = tempfile.TemporaryDirectory() + + self.experiment = _get_experiment(workspace, project_name) + self.experiment.log_other("Created from", "langchain") + if tags: + self.experiment.add_tags(tags) + self.name = name + if self.name: + self.experiment.set_name(self.name) + + warning = ( + "The comet_ml callback is currently in beta and is subject to change " + "based on updates to `langchain`. Please report any issues to " + "https://github.com/comet-ml/issue-tracking/issues with the tag " + "`langchain`." + ) + self.comet_ml.LOGGER.warning(warning) + + self.callback_columns: list = [] + self.action_records: list = [] + self.complexity_metrics = complexity_metrics + if self.visualizations: + spacy = import_spacy() + self.nlp = spacy.load("en_core_web_sm") + else: + self.nlp = None + + def _init_resp(self) -> Dict: + return {k: None for k in self.callback_columns} + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + self.step += 1 + self.llm_starts += 1 + self.starts += 1 + + metadata = self._init_resp() + metadata.update({"action": "on_llm_start"}) + metadata.update(flatten_dict(serialized)) + metadata.update(self.get_custom_callback_meta()) + + for prompt in prompts: + prompt_resp = deepcopy(metadata) + prompt_resp["prompts"] = prompt + self.on_llm_start_records.append(prompt_resp) + self.action_records.append(prompt_resp) + + if self.stream_logs: + self._log_stream(prompt, metadata, self.step) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.step += 1 + self.llm_streams += 1 + + resp = self._init_resp() + resp.update({"action": "on_llm_new_token", "token": token}) + resp.update(self.get_custom_callback_meta()) + + self.action_records.append(resp) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.step += 1 + self.llm_ends += 1 + self.ends += 1 + + metadata = self._init_resp() + metadata.update({"action": "on_llm_end"}) + metadata.update(flatten_dict(response.llm_output or {})) + metadata.update(self.get_custom_callback_meta()) + + output_complexity_metrics = [] + output_custom_metrics = [] + + for prompt_idx, generations in enumerate(response.generations): + for gen_idx, generation in enumerate(generations): + text = generation.text + + generation_resp = deepcopy(metadata) + generation_resp.update(flatten_dict(generation.dict())) + + complexity_metrics = self._get_complexity_metrics(text) + if complexity_metrics: + output_complexity_metrics.append(complexity_metrics) + generation_resp.update(complexity_metrics) + + custom_metrics = self._get_custom_metrics( + generation, prompt_idx, gen_idx + ) + if custom_metrics: + 
output_custom_metrics.append(custom_metrics) + generation_resp.update(custom_metrics) + + if self.stream_logs: + self._log_stream(text, metadata, self.step) + + self.action_records.append(generation_resp) + self.on_llm_end_records.append(generation_resp) + + self._log_text_metrics(output_complexity_metrics, step=self.step) + self._log_text_metrics(output_custom_metrics, step=self.step) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + self.step += 1 + self.errors += 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + self.step += 1 + self.chain_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update({"action": "on_chain_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + for chain_input_key, chain_input_val in inputs.items(): + if isinstance(chain_input_val, str): + input_resp = deepcopy(resp) + if self.stream_logs: + self._log_stream(chain_input_val, resp, self.step) + input_resp.update({chain_input_key: chain_input_val}) + self.action_records.append(input_resp) + + else: + self.comet_ml.LOGGER.warning( + f"Unexpected data format provided! " + f"Input Value for {chain_input_key} will not be logged" + ) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + self.step += 1 + self.chain_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update({"action": "on_chain_end"}) + resp.update(self.get_custom_callback_meta()) + + for chain_output_key, chain_output_val in outputs.items(): + if isinstance(chain_output_val, str): + output_resp = deepcopy(resp) + if self.stream_logs: + self._log_stream(chain_output_val, resp, self.step) + output_resp.update({chain_output_key: chain_output_val}) + self.action_records.append(output_resp) + else: + self.comet_ml.LOGGER.warning( + f"Unexpected data format provided! " + f"Output Value for {chain_output_key} will not be logged" + ) + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + self.step += 1 + self.errors += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp = self._init_resp() + resp.update({"action": "on_tool_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(input_str, resp, self.step) + + resp.update({"input_str": input_str}) + self.action_records.append(resp) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + self.step += 1 + self.tool_ends += 1 + self.ends += 1 + + resp = self._init_resp() + resp.update({"action": "on_tool_end"}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(output, resp, self.step) + + resp.update({"output": output}) + self.action_records.append(resp) + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + self.step += 1 + self.errors += 1 + + def on_text(self, text: str, **kwargs: Any) -> None: + """ + Run when agent is ending. 
+ """ + self.step += 1 + self.text_ctr += 1 + + resp = self._init_resp() + resp.update({"action": "on_text"}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(text, resp, self.step) + + resp.update({"text": text}) + self.action_records.append(resp) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.step += 1 + self.agent_ends += 1 + self.ends += 1 + + resp = self._init_resp() + output = finish.return_values["output"] + log = finish.log + + resp.update({"action": "on_agent_finish", "log": log}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(output, resp, self.step) + + resp.update({"output": output}) + self.action_records.append(resp) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + tool = action.tool + tool_input = str(action.tool_input) + log = action.log + + resp = self._init_resp() + resp.update({"action": "on_agent_action", "log": log, "tool": tool}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(tool_input, resp, self.step) + + resp.update({"tool_input": tool_input}) + self.action_records.append(resp) + + def _get_complexity_metrics(self, text: str) -> dict: + """Compute text complexity metrics using textstat. + + Parameters: + text (str): The text to analyze. + + Returns: + (dict): A dictionary containing the complexity metrics. + """ + resp = {} + if self.complexity_metrics: + text_complexity_metrics = _fetch_text_complexity_metrics(text) + resp.update(text_complexity_metrics) + + return resp + + def _get_custom_metrics( + self, generation: Generation, prompt_idx: int, gen_idx: int + ) -> dict: + """Compute Custom Metrics for an LLM Generated Output + + Args: + generation (LLMResult): Output generation from an LLM + prompt_idx (int): List index of the input prompt + gen_idx (int): List index of the generated output + + Returns: + dict: A dictionary containing the custom metrics. + """ + + resp = {} + if self.custom_metrics: + custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx) + resp.update(custom_metrics) + + return resp + + def flush_tracker( + self, + langchain_asset: Any = None, + task_type: Optional[str] = "inference", + workspace: Optional[str] = None, + project_name: Optional[str] = "comet-langchain-demo", + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + finish: bool = False, + reset: bool = False, + ) -> None: + """Flush the tracker and setup the session. + + Everything after this will be a new table. + + Args: + name: Name of the performed session so far so it is identifiable + langchain_asset: The langchain asset to save. + finish: Whether to finish the run. 
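+
+        Example:
+            A sketch of a typical call; `comet_callback` and `llm` are
+            placeholder names for an existing handler and the traced asset.
+
+            >>> comet_callback.flush_tracker(
+            ...     llm, name="run-1", finish=False, reset=True
+            ... )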
+ + Returns: + None + """ + self._log_session(langchain_asset) + + if langchain_asset: + try: + self._log_model(langchain_asset) + except Exception: + self.comet_ml.LOGGER.error( + "Failed to export agent or LLM to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) + + if finish: + self.experiment.end() + + if reset: + self._reset( + task_type, + workspace, + project_name, + tags, + name, + visualizations, + complexity_metrics, + custom_metrics, + ) + + def _log_stream(self, prompt: str, metadata: dict, step: int) -> None: + self.experiment.log_text(prompt, metadata=metadata, step=step) + + def _log_model(self, langchain_asset: Any) -> None: + model_parameters = self._get_llm_parameters(langchain_asset) + self.experiment.log_parameters(model_parameters, prefix="model") + + langchain_asset_path = Path(self.temp_dir.name, "model.json") + model_name = self.name if self.name else LANGCHAIN_MODEL_NAME + + try: + if hasattr(langchain_asset, "save"): + langchain_asset.save(langchain_asset_path) + self.experiment.log_model(model_name, str(langchain_asset_path)) + except (ValueError, AttributeError, NotImplementedError) as e: + if hasattr(langchain_asset, "save_agent"): + langchain_asset.save_agent(langchain_asset_path) + self.experiment.log_model(model_name, str(langchain_asset_path)) + else: + self.comet_ml.LOGGER.error( + f"{e}" + " Could not save Langchain Asset " + f"for {langchain_asset.__class__.__name__}" + ) + + def _log_session(self, langchain_asset: Optional[Any] = None) -> None: + try: + llm_session_df = self._create_session_analysis_dataframe(langchain_asset) + # Log the cleaned dataframe as a table + self.experiment.log_table("langchain-llm-session.csv", llm_session_df) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log session data to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) + + try: + metadata = {"langchain_version": str(langchain_community.__version__)} + # Log the langchain low-level records as a JSON file directly + self.experiment.log_asset_data( + self.action_records, "langchain-action_records.json", metadata=metadata + ) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log session data to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) + + try: + self._log_visualizations(llm_session_df) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log visualizations to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) + + def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None: + if not metrics: + return + + metrics_summary = _summarize_metrics_for_generated_outputs(metrics) + for key, value in metrics_summary.items(): + self.experiment.log_metrics(value, prefix=key, step=step) + + def _log_visualizations(self, session_df: Any) -> None: + if not (self.visualizations and self.nlp): + return + + spacy = import_spacy() + + prompts = session_df["prompts"].tolist() + outputs = session_df["text"].tolist() + + for idx, (prompt, output) in enumerate(zip(prompts, outputs)): + doc = self.nlp(output) + sentence_spans = list(doc.sents) + + for visualization in self.visualizations: + try: + html = spacy.displacy.render( + sentence_spans, + style=visualization, + options={"compact": True}, + jupyter=False, + page=True, + ) + self.experiment.log_asset_data( + html, + name=f"langchain-viz-{visualization}-{idx}.html", + metadata={"prompt": prompt}, + step=idx, + ) + except Exception as e: + self.comet_ml.LOGGER.warning( + e, exc_info=True, extra={"show_traceback": True} + ) + + 
return + + def _reset( + self, + task_type: Optional[str] = None, + workspace: Optional[str] = None, + project_name: Optional[str] = None, + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + ) -> None: + _task_type = task_type if task_type else self.task_type + _workspace = workspace if workspace else self.workspace + _project_name = project_name if project_name else self.project_name + _tags = tags if tags else self.tags + _name = name if name else self.name + _visualizations = visualizations if visualizations else self.visualizations + _complexity_metrics = ( + complexity_metrics if complexity_metrics else self.complexity_metrics + ) + _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics + + self.__init__( # type: ignore + task_type=_task_type, + workspace=_workspace, + project_name=_project_name, + tags=_tags, + name=_name, + visualizations=_visualizations, + complexity_metrics=_complexity_metrics, + custom_metrics=_custom_metrics, + ) + + self.reset_callback_meta() + self.temp_dir = tempfile.TemporaryDirectory() + + def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict: + pd = import_pandas() + + llm_parameters = self._get_llm_parameters(langchain_asset) + num_generations_per_prompt = llm_parameters.get("n", 1) + + llm_start_records_df = pd.DataFrame(self.on_llm_start_records) + # Repeat each input row based on the number of outputs generated per prompt + llm_start_records_df = llm_start_records_df.loc[ + llm_start_records_df.index.repeat(num_generations_per_prompt) + ].reset_index(drop=True) + llm_end_records_df = pd.DataFrame(self.on_llm_end_records) + + llm_session_df = pd.merge( + llm_start_records_df, + llm_end_records_df, + left_index=True, + right_index=True, + suffixes=["_llm_start", "_llm_end"], + ) + + return llm_session_df + + def _get_llm_parameters(self, langchain_asset: Any = None) -> dict: + if not langchain_asset: + return {} + try: + if hasattr(langchain_asset, "agent"): + llm_parameters = langchain_asset.agent.llm_chain.llm.dict() + elif hasattr(langchain_asset, "llm_chain"): + llm_parameters = langchain_asset.llm_chain.llm.dict() + elif hasattr(langchain_asset, "llm"): + llm_parameters = langchain_asset.llm.dict() + else: + llm_parameters = langchain_asset.dict() + except Exception: + return {} + + return llm_parameters diff --git a/libs/community/langchain_community/callbacks/confident_callback.py b/libs/community/langchain_community/callbacks/confident_callback.py new file mode 100644 index 00000000000..d9432e5d567 --- /dev/null +++ b/libs/community/langchain_community/callbacks/confident_callback.py @@ -0,0 +1,183 @@ +# flake8: noqa +import os +import warnings +from typing import Any, Dict, List, Optional, Union + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.outputs import LLMResult + + +class DeepEvalCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs into deepeval. + + Args: + implementation_name: name of the `implementation` in deepeval + metrics: A list of metrics + + Raises: + ImportError: if the `deepeval` package is not installed. 
+ + Examples: + >>> from langchain_community.llms import OpenAI + >>> from langchain_community.callbacks import DeepEvalCallbackHandler + >>> from deepeval.metrics import AnswerRelevancy + >>> metric = AnswerRelevancy(minimum_score=0.3) + >>> deepeval_callback = DeepEvalCallbackHandler( + ... implementation_name="exampleImplementation", + ... metrics=[metric], + ... ) + >>> llm = OpenAI( + ... temperature=0, + ... callbacks=[deepeval_callback], + ... verbose=True, + ... openai_api_key="API_KEY_HERE", + ... ) + >>> llm.generate([ + ... "What is the best evaluation tool out there? (no bias at all)", + ... ]) + "Deepeval, no doubt about it." + """ + + REPO_URL: str = "https://github.com/confident-ai/deepeval" + ISSUES_URL: str = f"{REPO_URL}/issues" + BLOG_URL: str = "https://docs.confident-ai.com" # noqa: E501 + + def __init__( + self, + metrics: List[Any], + implementation_name: Optional[str] = None, + ) -> None: + """Initializes the `deepevalCallbackHandler`. + + Args: + implementation_name: Name of the implementation you want. + metrics: What metrics do you want to track? + + Raises: + ImportError: if the `deepeval` package is not installed. + ConnectionError: if the connection to deepeval fails. + """ + + super().__init__() + + # Import deepeval (not via `import_deepeval` to keep hints in IDEs) + try: + import deepeval # ignore: F401,I001 + except ImportError: + raise ImportError( + """To use the deepeval callback manager you need to have the + `deepeval` Python package installed. Please install it with + `pip install deepeval`""" + ) + + if os.path.exists(".deepeval"): + warnings.warn( + """You are currently not logging anything to the dashboard, we + recommend using `deepeval login`.""" + ) + + # Set the deepeval variables + self.implementation_name = implementation_name + self.metrics = metrics + + warnings.warn( + ( + "The `DeepEvalCallbackHandler` is currently in beta and is subject to" + " change based on updates to `langchain`. Please report any issues to" + f" {self.ISSUES_URL} as an `integration` issue." 
+ ), + ) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Store the prompts""" + self.prompts = prompts + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing when a new token is generated.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Log records to deepeval when an LLM ends.""" + from deepeval.metrics.answer_relevancy import AnswerRelevancy + from deepeval.metrics.bias_classifier import UnBiasedMetric + from deepeval.metrics.metric import Metric + from deepeval.metrics.toxic_classifier import NonToxicMetric + + for metric in self.metrics: + for i, generation in enumerate(response.generations): + # Here, we only measure the first generation's output + output = generation[0].text + query = self.prompts[i] + if isinstance(metric, AnswerRelevancy): + result = metric.measure( + output=output, + query=query, + ) + print(f"Answer Relevancy: {result}") + elif isinstance(metric, UnBiasedMetric): + score = metric.measure(output) + print(f"Bias Score: {score}") + elif isinstance(metric, NonToxicMetric): + score = metric.measure(output) + print(f"Toxic Score: {score}") + else: + raise ValueError( + f"""Metric {metric.__name__} is not supported by deepeval + callbacks.""" + ) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM outputs an error.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Do nothing when chain starts""" + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Do nothing when chain ends.""" + pass + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when LLM chain outputs an error.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing when tool starts.""" + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Do nothing when agent takes a specific action.""" + pass + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Do nothing when tool ends.""" + pass + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when tool outputs an error.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Do nothing""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Do nothing""" + pass diff --git a/libs/community/langchain_community/callbacks/context_callback.py b/libs/community/langchain_community/callbacks/context_callback.py new file mode 100644 index 00000000000..8514976687d --- /dev/null +++ b/libs/community/langchain_community/callbacks/context_callback.py @@ -0,0 +1,194 @@ +"""Callback handler for Context AI""" +import os +from typing import Any, Dict, List +from uuid import UUID + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import BaseMessage +from langchain_core.outputs import LLMResult + + +def import_context() -> Any: + """Import the `getcontext` package.""" + try: + import getcontext # noqa: F401 + from getcontext.generated.models import ( + Conversation, + Message, + MessageRole, + Rating, + ) + from getcontext.token import Credential # noqa: F401 + except ImportError: + raise ImportError( + "To use the context callback 
manager you need to have the " + "`getcontext` python package installed (version >=0.3.0). " + "Please install it with `pip install --upgrade python-context`" + ) + return getcontext, Credential, Conversation, Message, MessageRole, Rating + + +class ContextCallbackHandler(BaseCallbackHandler): + """Callback Handler that records transcripts to the Context service. + + (https://context.ai). + + Keyword Args: + token (optional): The token with which to authenticate requests to Context. + Visit https://with.context.ai/settings to generate a token. + If not provided, the value of the `CONTEXT_TOKEN` environment + variable will be used. + + Raises: + ImportError: if the `context-python` package is not installed. + + Chat Example: + >>> from langchain_community.llms import ChatOpenAI + >>> from langchain_community.callbacks import ContextCallbackHandler + >>> context_callback = ContextCallbackHandler( + ... token="", + ... ) + >>> chat = ChatOpenAI( + ... temperature=0, + ... headers={"user_id": "123"}, + ... callbacks=[context_callback], + ... openai_api_key="API_KEY_HERE", + ... ) + >>> messages = [ + ... SystemMessage(content="You translate English to French."), + ... HumanMessage(content="I love programming with LangChain."), + ... ] + >>> chat(messages) + + Chain Example: + >>> from langchain.chains import LLMChain + >>> from langchain_community.chat_models import ChatOpenAI + >>> from langchain_community.callbacks import ContextCallbackHandler + >>> context_callback = ContextCallbackHandler( + ... token="", + ... ) + >>> human_message_prompt = HumanMessagePromptTemplate( + ... prompt=PromptTemplate( + ... template="What is a good name for a company that makes {product}?", + ... input_variables=["product"], + ... ), + ... ) + >>> chat_prompt_template = ChatPromptTemplate.from_messages( + ... [human_message_prompt] + ... ) + >>> callback = ContextCallbackHandler(token) + >>> # Note: the same callback object must be shared between the + ... LLM and the chain. + >>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback]) + >>> chain = LLMChain( + ... llm=chat, + ... prompt=chat_prompt_template, + ... callbacks=[callback] + ... 
) + >>> chain.run("colorful socks") + """ + + def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None: + ( + self.context, + self.credential, + self.conversation_model, + self.message_model, + self.message_role_model, + self.rating_model, + ) = import_context() + + token = token or os.environ.get("CONTEXT_TOKEN") or "" + + self.client = self.context.ContextAPI(credential=self.credential(token)) + + self.chain_run_id = None + + self.llm_model = None + + self.messages: List[Any] = [] + self.metadata: Dict[str, str] = {} + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: + """Run when the chat model is started.""" + llm_model = kwargs.get("invocation_params", {}).get("model", None) + if llm_model is not None: + self.metadata["model"] = llm_model + + if len(messages) == 0: + return + + for message in messages[0]: + role = self.message_role_model.SYSTEM + if message.type == "human": + role = self.message_role_model.USER + elif message.type == "system": + role = self.message_role_model.SYSTEM + elif message.type == "ai": + role = self.message_role_model.ASSISTANT + + self.messages.append( + self.message_model( + message=message.content, + role=role, + ) + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends.""" + if len(response.generations) == 0 or len(response.generations[0]) == 0: + return + + if not self.chain_run_id: + generation = response.generations[0][0] + self.messages.append( + self.message_model( + message=generation.text, + role=self.message_role_model.ASSISTANT, + ) + ) + + self._log_conversation() + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts.""" + self.chain_run_id = kwargs.get("run_id", None) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends.""" + self.messages.append( + self.message_model( + message=outputs["text"], + role=self.message_role_model.ASSISTANT, + ) + ) + + self._log_conversation() + + self.chain_run_id = None + + def _log_conversation(self) -> None: + """Log the conversation to the context API.""" + if len(self.messages) == 0: + return + + self.client.log.conversation_upsert( + body={ + "conversation": self.conversation_model( + messages=self.messages, + metadata=self.metadata, + ) + } + ) + + self.messages = [] + self.metadata = {} diff --git a/libs/community/langchain_community/callbacks/flyte_callback.py b/libs/community/langchain_community/callbacks/flyte_callback.py new file mode 100644 index 00000000000..23a8f473430 --- /dev/null +++ b/libs/community/langchain_community/callbacks/flyte_callback.py @@ -0,0 +1,371 @@ +"""FlyteKit callback handler.""" +from __future__ import annotations + +import logging +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +from langchain_community.callbacks.utils import ( + BaseMetadataCallbackHandler, + flatten_dict, + import_pandas, + import_spacy, + import_textstat, +) + +if TYPE_CHECKING: + import flytekit + from flytekitplugins.deck import renderer + +logger = logging.getLogger(__name__) + + +def import_flytekit() -> Tuple[flytekit, renderer]: + """Import flytekit and flytekitplugins-deck-standard.""" + try: + import 
flytekit # noqa: F401 + from flytekitplugins.deck import renderer # noqa: F401 + except ImportError: + raise ImportError( + "To use the flyte callback manager you need" + "to have the `flytekit` and `flytekitplugins-deck-standard`" + "packages installed. Please install them with `pip install flytekit`" + "and `pip install flytekitplugins-deck-standard`." + ) + return flytekit, renderer + + +def analyze_text( + text: str, + nlp: Any = None, + textstat: Any = None, +) -> dict: + """Analyze text using textstat and spacy. + + Parameters: + text (str): The text to analyze. + nlp (spacy.lang): The spacy language model to use for visualization. + + Returns: + (dict): A dictionary containing the complexity metrics and visualization + files serialized to HTML string. + """ + resp: Dict[str, Any] = {} + if textstat is not None: + text_complexity_metrics = { + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "smog_index": textstat.smog_index(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + "automated_readability_index": textstat.automated_readability_index(text), + "dale_chall_readability_score": textstat.dale_chall_readability_score(text), + "difficult_words": textstat.difficult_words(text), + "linsear_write_formula": textstat.linsear_write_formula(text), + "gunning_fog": textstat.gunning_fog(text), + "fernandez_huerta": textstat.fernandez_huerta(text), + "szigriszt_pazos": textstat.szigriszt_pazos(text), + "gutierrez_polini": textstat.gutierrez_polini(text), + "crawford": textstat.crawford(text), + "gulpease_index": textstat.gulpease_index(text), + "osman": textstat.osman(text), + } + resp.update({"text_complexity_metrics": text_complexity_metrics}) + resp.update(text_complexity_metrics) + + if nlp is not None: + spacy = import_spacy() + doc = nlp(text) + dep_out = spacy.displacy.render( # type: ignore + doc, style="dep", jupyter=False, page=True + ) + ent_out = spacy.displacy.render( # type: ignore + doc, style="ent", jupyter=False, page=True + ) + text_visualizations = { + "dependency_tree": dep_out, + "entities": ent_out, + } + resp.update(text_visualizations) + + return resp + + +class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): + """This callback handler that is used within a Flyte task.""" + + def __init__(self) -> None: + """Initialize callback handler.""" + flytekit, renderer = import_flytekit() + self.pandas = import_pandas() + + self.textstat = None + try: + self.textstat = import_textstat() + except ImportError: + logger.warning( + "Textstat library is not installed. \ + It may result in the inability to log \ + certain metrics that can be captured with Textstat." + ) + + spacy = None + try: + spacy = import_spacy() + except ImportError: + logger.warning( + "Spacy library is not installed. \ + It may result in the inability to log \ + certain metrics that can be captured with Spacy." + ) + + super().__init__() + + self.nlp = None + if spacy: + try: + self.nlp = spacy.load("en_core_web_sm") + except OSError: + logger.warning( + "FlyteCallbackHandler uses spacy's en_core_web_sm model" + " for certain metrics. 
To download," + " run the following command in your terminal:" + " `python -m spacy download en_core_web_sm`" + ) + + self.table_renderer = renderer.TableRenderer + self.markdown_renderer = renderer.MarkdownRenderer + + self.deck = flytekit.Deck( + "LangChain Metrics", + self.markdown_renderer().to_html("## LangChain Metrics"), + ) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + + self.step += 1 + self.llm_starts += 1 + self.starts += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_llm_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + prompt_responses = [] + for prompt in prompts: + prompt_responses.append(prompt) + + resp.update({"prompts": prompt_responses}) + + self.deck.append(self.markdown_renderer().to_html("### LLM Start")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.step += 1 + self.llm_ends += 1 + self.ends += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_llm_end"}) + resp.update(flatten_dict(response.llm_output or {})) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### LLM End")) + self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp]))) + + for generations in response.generations: + for generation in generations: + generation_resp = deepcopy(resp) + generation_resp.update(flatten_dict(generation.dict())) + if self.nlp or self.textstat: + generation_resp.update( + analyze_text( + generation.text, nlp=self.nlp, textstat=self.textstat + ) + ) + + complexity_metrics: Dict[str, float] = generation_resp.pop( + "text_complexity_metrics" + ) # type: ignore # noqa: E501 + self.deck.append( + self.markdown_renderer().to_html("#### Text Complexity Metrics") + ) + self.deck.append( + self.table_renderer().to_html( + self.pandas.DataFrame([complexity_metrics]) + ) + + "\n" + ) + + dependency_tree = generation_resp["dependency_tree"] + self.deck.append( + self.markdown_renderer().to_html("#### Dependency Tree") + ) + self.deck.append(dependency_tree) + + entities = generation_resp["entities"] + self.deck.append(self.markdown_renderer().to_html("#### Entities")) + self.deck.append(entities) + else: + self.deck.append( + self.markdown_renderer().to_html("#### Generated Response") + ) + self.deck.append(self.markdown_renderer().to_html(generation.text)) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + self.step += 1 + self.errors += 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + self.step += 1 + self.chain_starts += 1 + self.starts += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_chain_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()]) + input_resp = deepcopy(resp) + input_resp["inputs"] = chain_input + + self.deck.append(self.markdown_renderer().to_html("### Chain Start")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n" + ) + + def on_chain_end(self, 
outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + self.step += 1 + self.chain_ends += 1 + self.ends += 1 + + resp: Dict[str, Any] = {} + chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()]) + resp.update({"action": "on_chain_end", "outputs": chain_output}) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### Chain End")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + self.step += 1 + self.errors += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_tool_start", "input_str": input_str}) + resp.update(flatten_dict(serialized)) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### Tool Start")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + self.step += 1 + self.tool_ends += 1 + self.ends += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_tool_end", "output": output}) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### Tool End")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + self.step += 1 + self.errors += 1 + + def on_text(self, text: str, **kwargs: Any) -> None: + """ + Run when agent is ending. 
+ """ + self.step += 1 + self.text_ctr += 1 + + resp: Dict[str, Any] = {} + resp.update({"action": "on_text", "text": text}) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### On Text")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.step += 1 + self.agent_ends += 1 + self.ends += 1 + + resp: Dict[str, Any] = {} + resp.update( + { + "action": "on_agent_finish", + "output": finish.return_values["output"], + "log": finish.log, + } + ) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### Agent Finish")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + resp: Dict[str, Any] = {} + resp.update( + { + "action": "on_agent_action", + "tool": action.tool, + "tool_input": action.tool_input, + "log": action.log, + } + ) + resp.update(self.get_custom_callback_meta()) + + self.deck.append(self.markdown_renderer().to_html("### Agent Action")) + self.deck.append( + self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" + ) diff --git a/libs/community/langchain_community/callbacks/human.py b/libs/community/langchain_community/callbacks/human.py new file mode 100644 index 00000000000..64ea01f99f5 --- /dev/null +++ b/libs/community/langchain_community/callbacks/human.py @@ -0,0 +1,88 @@ +from typing import Any, Awaitable, Callable, Dict, Optional +from uuid import UUID + +from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler + + +def _default_approve(_input: str) -> bool: + msg = ( + "Do you approve of the following input? " + "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no." + ) + msg += "\n\n" + _input + "\n" + resp = input(msg) + return resp.lower() in ("yes", "y") + + +async def _adefault_approve(_input: str) -> bool: + msg = ( + "Do you approve of the following input? " + "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no." + ) + msg += "\n\n" + _input + "\n" + resp = input(msg) + return resp.lower() in ("yes", "y") + + +def _default_true(_: Dict[str, Any]) -> bool: + return True + + +class HumanRejectedException(Exception): + """Exception to raise when a person manually review and rejects a value.""" + + +class HumanApprovalCallbackHandler(BaseCallbackHandler): + """Callback for manually validating values.""" + + raise_error: bool = True + + def __init__( + self, + approve: Callable[[Any], bool] = _default_approve, + should_check: Callable[[Dict[str, Any]], bool] = _default_true, + ): + self._approve = approve + self._should_check = should_check + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if self._should_check(serialized) and not self._approve(input_str): + raise HumanRejectedException( + f"Inputs {input_str} to tool {serialized} were rejected." 
+ ) + + +class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler): + """Asynchronous callback for manually validating values.""" + + raise_error: bool = True + + def __init__( + self, + approve: Callable[[Any], Awaitable[bool]] = _adefault_approve, + should_check: Callable[[Dict[str, Any]], bool] = _default_true, + ): + self._approve = approve + self._should_check = should_check + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if self._should_check(serialized) and not await self._approve(input_str): + raise HumanRejectedException( + f"Inputs {input_str} to tool {serialized} were rejected." + ) diff --git a/libs/community/langchain_community/callbacks/infino_callback.py b/libs/community/langchain_community/callbacks/infino_callback.py new file mode 100644 index 00000000000..57d756948ef --- /dev/null +++ b/libs/community/langchain_community/callbacks/infino_callback.py @@ -0,0 +1,266 @@ +import time +from typing import Any, Dict, List, Optional, cast + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import BaseMessage +from langchain_core.outputs import LLMResult + + +def import_infino() -> Any: + """Import the infino client.""" + try: + from infinopy import InfinoClient + except ImportError: + raise ImportError( + "To use the Infino callbacks manager you need to have the" + " `infinopy` python package installed." + "Please install it with `pip install infinopy`" + ) + return InfinoClient() + + +def import_tiktoken() -> Any: + """Import tiktoken for counting tokens for OpenAI models.""" + try: + import tiktoken + except ImportError: + raise ImportError( + "To use the ChatOpenAI model with Infino callback manager, you need to " + "have the `tiktoken` python package installed." + "Please install it with `pip install tiktoken`" + ) + return tiktoken + + +def get_num_tokens(string: str, openai_model_name: str) -> int: + """Calculate num tokens for OpenAI with tiktoken package. + + Official documentation: https://github.com/openai/openai-cookbook/blob/main + /examples/How_to_count_tokens_with_tiktoken.ipynb + """ + tiktoken = import_tiktoken() + + encoding = tiktoken.encoding_for_model(openai_model_name) + num_tokens = len(encoding.encode(string)) + return num_tokens + + +class InfinoCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs to Infino.""" + + def __init__( + self, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + verbose: bool = False, + ) -> None: + # Set Infino client + self.client = import_infino() + self.model_id = model_id + self.model_version = model_version + self.verbose = verbose + self.is_chat_openai_model = False + self.chat_openai_model_name = "gpt-3.5-turbo" + + def _send_to_infino( + self, + key: str, + value: Any, + is_ts: bool = True, + ) -> None: + """Send the key-value to Infino. + + Parameters: + key (str): the key to send to Infino. + value (Any): the value to send to Infino. + is_ts (bool): if True, the value is part of a time series, else it + is sent as a log message. + """ + payload = { + "date": int(time.time()), + key: value, + "labels": { + "model_id": self.model_id, + "model_version": self.model_version, + }, + } + if self.verbose: + print(f"Tracking {key} with Infino: {payload}") + + # Append to Infino time series only if is_ts is True, otherwise + # append to Infino log. 
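+        # (Note: callers in this handler send scalar values such as latency and
+        # token counts to the Infino time series API, while free-form text such
+        # as prompts and responses is sent to the Infino log API.)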
+ if is_ts: + self.client.append_ts(payload) + else: + self.client.append_log(payload) + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> None: + """Log the prompts to Infino, and set start time and error flag.""" + for prompt in prompts: + self._send_to_infino("prompt", prompt, is_ts=False) + + # Set the error flag to indicate no error (this will get overridden + # in on_llm_error if an error occurs). + self.error = 0 + + # Set the start time (so that we can calculate the request + # duration in on_llm_end). + self.start_time = time.time() + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing when a new token is generated.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Log the latency, error, token usage, and response to Infino.""" + # Calculate and track the request latency. + self.end_time = time.time() + duration = self.end_time - self.start_time + self._send_to_infino("latency", duration) + + # Track success or error flag. + self._send_to_infino("error", self.error) + + # Track prompt response. + for generations in response.generations: + for generation in generations: + self._send_to_infino("prompt_response", generation.text, is_ts=False) + + # Track token usage (for non-chat models). + if (response.llm_output is not None) and isinstance(response.llm_output, Dict): + token_usage = response.llm_output["token_usage"] + if token_usage is not None: + prompt_tokens = token_usage["prompt_tokens"] + total_tokens = token_usage["total_tokens"] + completion_tokens = token_usage["completion_tokens"] + self._send_to_infino("prompt_tokens", prompt_tokens) + self._send_to_infino("total_tokens", total_tokens) + self._send_to_infino("completion_tokens", completion_tokens) + + # Track completion token usage (for openai chat models). 
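+        # (Chat completions may not report token_usage in llm_output, so the
+        # completion token count below is approximated with tiktoken instead.)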
+ if self.is_chat_openai_model: + messages = " ".join( + generation.message.content # type: ignore[attr-defined] + for generation in generations + ) + completion_tokens = get_num_tokens( + messages, openai_model_name=self.chat_openai_model_name + ) + self._send_to_infino("completion_tokens", completion_tokens) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Set the error flag.""" + self.error = 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Do nothing when LLM chain starts.""" + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Do nothing when LLM chain ends.""" + pass + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Need to log the error.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing when tool starts.""" + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Do nothing when agent takes a specific action.""" + pass + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Do nothing when tool ends.""" + pass + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing when tool outputs an error.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> None: + """Run when LLM starts running.""" + + # Currently, for chat models, we only support input prompts for ChatOpenAI. + # Check if this model is a ChatOpenAI model. + values = serialized.get("id") + if values: + for value in values: + if value == "ChatOpenAI": + self.is_chat_openai_model = True + break + + # Track prompt tokens for ChatOpenAI model. + if self.is_chat_openai_model: + invocation_params = kwargs.get("invocation_params") + if invocation_params: + model_name = invocation_params.get("model_name") + if model_name: + self.chat_openai_model_name = model_name + prompt_tokens = 0 + for message_list in messages: + message_string = " ".join( + cast(str, msg.content) for msg in message_list + ) + num_tokens = get_num_tokens( + message_string, + openai_model_name=self.chat_openai_model_name, + ) + prompt_tokens += num_tokens + + self._send_to_infino("prompt_tokens", prompt_tokens) + + if self.verbose: + print( + f"on_chat_model_start: is_chat_openai_model= \ + {self.is_chat_openai_model}, \ + chat_openai_model_name={self.chat_openai_model_name}" + ) + + # Send the prompt to infino + prompt = " ".join( + cast(str, msg.content) for sublist in messages for msg in sublist + ) + self._send_to_infino("prompt", prompt, is_ts=False) + + # Set the error flag to indicate no error (this will get overridden + # in on_llm_error if an error occurs). + self.error = 0 + + # Set the start time (so that we can calculate the request + # duration in on_llm_end). 
+ self.start_time = time.time() diff --git a/libs/community/langchain_community/callbacks/labelstudio_callback.py b/libs/community/langchain_community/callbacks/labelstudio_callback.py new file mode 100644 index 00000000000..73954820b23 --- /dev/null +++ b/libs/community/langchain_community/callbacks/labelstudio_callback.py @@ -0,0 +1,390 @@ +import os +import warnings +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union +from uuid import UUID + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import BaseMessage, ChatMessage +from langchain_core.outputs import Generation, LLMResult + + +class LabelStudioMode(Enum): + """Label Studio mode enumerator.""" + + PROMPT = "prompt" + CHAT = "chat" + + +def get_default_label_configs( + mode: Union[str, LabelStudioMode], +) -> Tuple[str, LabelStudioMode]: + """Get default Label Studio configs for the given mode. + + Parameters: + mode: Label Studio mode ("prompt" or "chat") + + Returns: Tuple of Label Studio config and mode + """ + _default_label_configs = { + LabelStudioMode.PROMPT.value: """ + + + + + + +