Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-11 03:30:09 +00:00)

Compare commits: v0.0.350 ... harrison/a (53 commits)
| SHA1 |
|---|
| d05246462c |
| 858f4cbce4 |
| 231891706b |
| 2bef45074d |
| ea2616ae23 |
| 7e6ca3c2b9 |
| db04580dfa |
| eb179eb4f3 |
| 6038e03c44 |
| fc174c1e1a |
| 7178a565f4 |
| 74782694f4 |
| fa6ae6410f |
| 673ce6aa60 |
| 90f3424a65 |
| b9ef92f2f4 |
| df95abb7e7 |
| 06abff41da |
| 414bddd5f0 |
| 0be7e1e397 |
| e780433f6b |
| 6080c98108 |
| 5da79e150b |
| b4e3e47c92 |
| d31ff30df6 |
| 158dda440b |
| 0dc432aa95 |
| b092bfbb3c |
| e84a350791 |
| 1bf84c3056 |
| a4992ffada |
| a019183a01 |
| e5bd88383f |
| b885880344 |
| 1830d5e138 |
| aa3d534db9 |
| 945f6eb5d6 |
| 159b5cab16 |
| bf9853418f |
| a8532c176d |
| 282362382c |
| a47f210b38 |
| f337284bce |
| d6e8cd1641 |
| fb73fdf47a |
| 43b1c3c384 |
| 12fbd5f670 |
| 262579ffc3 |
| 0d71b98f49 |
| ca7da8f7ef |
| 2a10cabf66 |
| b72b19b593 |
| c32554a3e0 |
.github/scripts/check_diff.py (vendored): 8 lines changed

@@ -1,7 +1,7 @@
 import json
 import sys
 
-ALL_DIRS = {
+LANGCHAIN_DIRS = {
     "libs/core",
     "libs/langchain",
     "libs/experimental",
@@ -23,8 +23,7 @@ if __name__ == "__main__":
                 ".github/scripts/check_diff.py",
             )
         ):
-            dirs_to_run = ALL_DIRS
-            break
+            dirs_to_run.update(LANGCHAIN_DIRS)
         elif "libs/community" in file:
             dirs_to_run.update(
                 ("libs/community", "libs/langchain", "libs/experimental")
@@ -39,8 +38,7 @@ if __name__ == "__main__":
         elif "libs/experimental" in file:
             dirs_to_run.add("libs/experimental")
         elif file.startswith("libs/"):
-            dirs_to_run = ALL_DIRS
-            break
+            dirs_to_run.update(LANGCHAIN_DIRS)
         else:
             pass
     print(json.dumps(list(dirs_to_run)))
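
With this change, touching an infra file adds the core LANGCHAIN_DIRS to the test matrix instead of forcing every directory. A rough sketch of how the script's output is consumed (hypothetical file path; the argv handling sits outside the hunks shown, but the script imports sys and prints a JSON list):

```python
import json
import subprocess

# Hypothetical invocation: changed file paths are assumed to be passed as
# command-line arguments, and the script prints a JSON list of directories
# for the CI matrix to run against.
result = subprocess.run(
    [
        "python",
        ".github/scripts/check_diff.py",
        "libs/community/langchain_community/tools/example.py",  # hypothetical path
    ],
    capture_output=True,
    text=True,
    check=True,
)
print(json.loads(result.stdout))
# expected along the lines of: ["libs/community", "libs/langchain", "libs/experimental"]
```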
.github/workflows/_all_ci.yml (vendored): 1 line changed

@@ -72,6 +72,7 @@ jobs:
     defaults:
       run:
         working-directory: ${{ inputs.working-directory }}
+    if: ${{ ! startsWith(inputs.working-directory, 'libs/partners/') }}
     steps:
       - uses: actions/checkout@v4
 
@@ -1,9 +0,0 @@
"""Main entrypoint into package."""
from importlib import metadata

try:
    __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
    # Case where package metadata is not available.
    __version__ = ""
del metadata  # optional, avoids polluting the results of dir(__package__)
@@ -1,79 +0,0 @@
"""Agent toolkits contain integrations with various resources and services.

LangChain has a large ecosystem of integrations with various external resources
like local and remote file systems, APIs and databases.

These integrations allow developers to create versatile applications that combine the
power of LLMs with the ability to access, interact with and manipulate external
resources.

When developing an application, developers should inspect the capabilities and
permissions of the tools that underlie the given agent toolkit, and determine
whether permissions of the given toolkit are appropriate for the application.

See [Security](https://python.langchain.com/docs/security) for more information.
"""
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit
from langchain_community.agent_toolkits.azure_cognitive_services import (
    AzureCognitiveServicesToolkit,
)
from langchain_community.agent_toolkits.conversational_retrieval.openai_functions import (  # noqa: E501
    create_conversational_retrieval_agent,
)
from langchain_community.agent_toolkits.file_management.toolkit import (
    FileManagementToolkit,
)
from langchain_community.agent_toolkits.gmail.toolkit import GmailToolkit
from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit
from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit
from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit
from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit
from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_community.agent_toolkits.playwright.toolkit import (
    PlayWrightBrowserToolkit,
)
from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent
from langchain_community.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_community.agent_toolkits.slack.toolkit import SlackToolkit
from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit
from langchain_community.agent_toolkits.zapier.toolkit import ZapierToolkit


__all__ = [
    "AINetworkToolkit",
    "AmadeusToolkit",
    "AzureCognitiveServicesToolkit",
    "FileManagementToolkit",
    "GmailToolkit",
    "JiraToolkit",
    "JsonToolkit",
    "MultionToolkit",
    "NasaToolkit",
    "NLAToolkit",
    "O365Toolkit",
    "OpenAPIToolkit",
    "PlayWrightBrowserToolkit",
    "PowerBIToolkit",
    "SlackToolkit",
    "SteamToolkit",
    "SQLDatabaseToolkit",
    "SparkSQLToolkit",
    "ZapierToolkit",
    "create_json_agent",
    "create_openapi_agent",
    "create_pbi_agent",
    "create_pbi_chat_agent",
    "create_spark_sql_agent",
    "create_sql_agent",
    "create_conversational_retrieval_agent",
]
@@ -1,53 +0,0 @@
"""Json agent."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel

from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit

if TYPE_CHECKING:
    from langchain.agents.agent import AgentExecutor


def create_json_agent(
    llm: BaseLanguageModel,
    toolkit: JsonToolkit,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = JSON_PREFIX,
    suffix: str = JSON_SUFFIX,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a json agent from an LLM and tools."""
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain
    tools = toolkit.get_tools()
    prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {}
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=input_variables,
        **prompt_params,
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **(agent_executor_kwargs or {}),
    )
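
The removed create_json_agent helper above now lives under langchain_community. A minimal usage sketch, assuming an OpenAI-compatible LLM and a small in-memory JsonSpec (the dict_ payload and the question are illustrative, not from the diff):

```python
from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec

# Illustrative data only; JsonToolkit(spec=...) mirrors how OpenAPIToolkit.from_llm
# wires the JSON agent later in this changeset.
spec = JsonSpec(dict_={"pets": [{"name": "Rex", "species": "dog"}]}, max_value_length=200)
toolkit = JsonToolkit(spec=spec)

executor = create_json_agent(llm=OpenAI(temperature=0), toolkit=toolkit, verbose=True)
executor.run("How many pets are listed, and what are their names?")
```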
@@ -1,57 +0,0 @@
"""Tool for interacting with a single API with natural language definition."""

from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING

from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import Tool

from langchain_community.tools.openapi.utils.api_models import APIOperation
from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain_community.utilities.requests import Requests

if TYPE_CHECKING:
    from langchain.chains.api.openapi.chain import OpenAPIEndpointChain


class NLATool(Tool):
    """Natural Language API Tool."""

    @classmethod
    def from_open_api_endpoint_chain(
        cls, chain: OpenAPIEndpointChain, api_title: str
    ) -> "NLATool":
        """Convert an endpoint chain to an API endpoint tool."""
        expanded_name = (
            f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}'
        )
        description = (
            f"I'm an AI from {api_title}. Instruct what you want,"
            " and I'll assist via an API with description:"
            f" {chain.api_operation.description}"
        )
        return cls(name=expanded_name, func=chain.run, description=description)

    @classmethod
    def from_llm_and_method(
        cls,
        llm: BaseLanguageModel,
        path: str,
        method: str,
        spec: OpenAPISpec,
        requests: Optional[Requests] = None,
        verbose: bool = False,
        return_intermediate_steps: bool = False,
        **kwargs: Any,
    ) -> "NLATool":
        """Instantiate the tool from the specified path and method."""
        api_operation = APIOperation.from_openapi_spec(spec, path, method)
        chain = OpenAPIEndpointChain.from_api_operation(
            api_operation,
            llm,
            requests=requests,
            verbose=verbose,
            return_intermediate_steps=return_intermediate_steps,
            **kwargs,
        )
        return cls.from_open_api_endpoint_chain(chain, spec.info.title)
@@ -1,77 +0,0 @@
"""OpenAPI spec agent."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel

from langchain_community.agent_toolkits.openapi.prompt import (
    OPENAPI_PREFIX,
    OPENAPI_SUFFIX,
)
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit

if TYPE_CHECKING:
    from langchain.agents.agent import AgentExecutor


def create_openapi_agent(
    llm: BaseLanguageModel,
    toolkit: OpenAPIToolkit,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = OPENAPI_PREFIX,
    suffix: str = OPENAPI_SUFFIX,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct an OpenAPI agent from an LLM and tools.

    *Security Note*: When creating an OpenAPI agent, check the permissions
        and capabilities of the underlying toolkit.

        For example, the default implementation of OpenAPIToolkit uses the
        RequestsToolkit, which contains tools to make arbitrary network requests
        against any URL (e.g., GET, POST, PATCH, PUT, DELETE).

        Control access to who can submit requests using this toolkit and
        what network access it has.

        See https://python.langchain.com/docs/security for more information.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain
    tools = toolkit.get_tools()
    prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {}
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=input_variables,
        **prompt_params
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
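
For context, a minimal usage sketch of the spec-exploring agent above, assuming a local OpenAPI spec saved as YAML and an OpenAI-compatible LLM (the file name and request headers are placeholders, not from the diff):

```python
import yaml

from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_community.llms import OpenAI
from langchain_community.tools.json.tool import JsonSpec
from langchain_community.utilities.requests import TextRequestsWrapper

with open("openapi.yaml") as f:  # hypothetical spec file
    raw_spec = yaml.safe_load(f)

llm = OpenAI(temperature=0)
toolkit = OpenAPIToolkit.from_llm(
    llm,
    JsonSpec(dict_=raw_spec, max_value_length=4000),
    TextRequestsWrapper(headers={}),
    verbose=True,
)
agent = create_openapi_agent(llm=llm, toolkit=toolkit, verbose=True)
agent.run("What endpoints does this API expose?")
```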
@@ -1,370 +0,0 @@
|
||||
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_community.llms import OpenAI
|
||||
|
||||
from langchain_community.agent_toolkits.openapi.planner_prompt import (
|
||||
API_CONTROLLER_PROMPT,
|
||||
API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
API_CONTROLLER_TOOL_NAME,
|
||||
API_ORCHESTRATOR_PROMPT,
|
||||
API_PLANNER_PROMPT,
|
||||
API_PLANNER_TOOL_DESCRIPTION,
|
||||
API_PLANNER_TOOL_NAME,
|
||||
PARSING_DELETE_PROMPT,
|
||||
PARSING_GET_PROMPT,
|
||||
PARSING_PATCH_PROMPT,
|
||||
PARSING_POST_PROMPT,
|
||||
PARSING_PUT_PROMPT,
|
||||
REQUESTS_DELETE_TOOL_DESCRIPTION,
|
||||
REQUESTS_GET_TOOL_DESCRIPTION,
|
||||
REQUESTS_PATCH_TOOL_DESCRIPTION,
|
||||
REQUESTS_POST_TOOL_DESCRIPTION,
|
||||
REQUESTS_PUT_TOOL_DESCRIPTION,
|
||||
)
|
||||
from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
|
||||
from langchain_community.tools.requests.tool import BaseRequestsTool
|
||||
from langchain_community.utilities.requests import RequestsWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory import ReadOnlySharedMemory
|
||||
|
||||
#
|
||||
# Requests tools with LLM-instructed extraction of truncated responses.
|
||||
#
|
||||
# Of course, truncating so bluntly may lose a lot of valuable
|
||||
# information in the response.
|
||||
# However, the goal for now is to have only a single inference step.
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
"""Maximum length of the response to be returned."""
|
||||
|
||||
|
||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
||||
from langchain.chains.llm import LLMChain
|
||||
return LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def _get_default_llm_chain_factory(
|
||||
prompt: BasePromptTemplate,
|
||||
) -> Callable[[], LLMChain]:
|
||||
"""Returns a default LLMChain factory."""
|
||||
return partial(_get_default_llm_chain, prompt)
|
||||
|
||||
|
||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_get"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
try:
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
data_params = data.get("params")
|
||||
response = self.requests_wrapper.get(data["url"], params=data_params)
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_post"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
try:
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.post(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_patch"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
try:
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.patch(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests PUT tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_put"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
try:
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.put(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""A tool that sends a DELETE request and parses the response."""
|
||||
|
||||
name: str = "requests_delete"
|
||||
"""The name of the tool."""
|
||||
description = REQUESTS_DELETE_TOOL_DESCRIPTION
|
||||
"""The description of the tool."""
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""The maximum length of the response."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
|
||||
)
|
||||
"""The LLM chain used to parse the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
try:
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.delete(data["url"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
#
|
||||
# Orchestrator, planner, controller.
|
||||
#
|
||||
def _create_api_planner_tool(
|
||||
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
|
||||
) -> Tool:
|
||||
from langchain.chains.llm import LLMChain
|
||||
endpoint_descriptions = [
|
||||
f"{name} {description}" for name, description, _ in api_spec.endpoints
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_PLANNER_PROMPT,
|
||||
input_variables=["query"],
|
||||
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
tool = Tool(
|
||||
name=API_PLANNER_TOOL_NAME,
|
||||
description=API_PLANNER_TOOL_DESCRIPTION,
|
||||
func=chain.run,
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _create_api_controller_agent(
|
||||
api_url: str,
|
||||
api_docs: str,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> AgentExecutor:
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.llm import LLMChain
|
||||
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||
tools: List[BaseTool] = [
|
||||
RequestsGetToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
||||
),
|
||||
RequestsPostToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
||||
),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_CONTROLLER_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"api_url": api_url,
|
||||
"api_docs": api_docs,
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
||||
def _create_api_controller_tool(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> Tool:
|
||||
"""Expose controller as a tool.
|
||||
|
||||
The tool is invoked with a plan from the planner, and dynamically
|
||||
creates a controller agent with relevant documentation only to
|
||||
constrain the context.
|
||||
"""
|
||||
|
||||
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
||||
|
||||
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
||||
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
||||
matches = re.findall(pattern, plan_str)
|
||||
endpoint_names = [
|
||||
"{method} {route}".format(method=method, route=route.split("?")[0])
|
||||
for method, route in matches
|
||||
]
|
||||
docs_str = ""
|
||||
for endpoint_name in endpoint_names:
|
||||
found_match = False
|
||||
for name, _, docs in api_spec.endpoints:
|
||||
regex_name = re.compile(re.sub(r"\{.*?\}", ".*", name))
|
||||
if regex_name.match(endpoint_name):
|
||||
found_match = True
|
||||
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
|
||||
if not found_match:
|
||||
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
||||
|
||||
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
|
||||
return agent.run(plan_str)
|
||||
|
||||
return Tool(
|
||||
name=API_CONTROLLER_TOOL_NAME,
|
||||
func=_create_and_run_api_controller_agent,
|
||||
description=API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
)
|
||||
|
||||
|
||||
def create_openapi_agent(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
shared_memory: Optional[ReadOnlySharedMemory] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
verbose: bool = True,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Instantiate OpenAI API planner and controller for a given spec.
|
||||
|
||||
Inject credentials via requests_wrapper.
|
||||
|
||||
We use a top-level "orchestrator" agent to invoke the planner and controller,
|
||||
rather than a top-level planner
|
||||
that invokes a controller with its plan. This is to keep the planner simple.
|
||||
"""
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.llm import LLMChain
|
||||
tools = [
|
||||
_create_api_planner_tool(api_spec, llm),
|
||||
_create_api_controller_tool(api_spec, requests_wrapper, llm),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_ORCHESTRATOR_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Requests toolkit."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.tools import Tool
|
||||
|
||||
from langchain_community.agent_toolkits.base import BaseToolkit
|
||||
from langchain_community.agent_toolkits.json.base import create_json_agent
|
||||
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain_community.agent_toolkits.openapi.prompt import DESCRIPTION
|
||||
from langchain_community.tools import BaseTool
|
||||
from langchain_community.tools.json.tool import JsonSpec
|
||||
from langchain_community.tools.requests.tool import (
|
||||
RequestsDeleteTool,
|
||||
RequestsGetTool,
|
||||
RequestsPatchTool,
|
||||
RequestsPostTool,
|
||||
RequestsPutTool,
|
||||
)
|
||||
from langchain_community.utilities.requests import TextRequestsWrapper
|
||||
|
||||
|
||||
class RequestsToolkit(BaseToolkit):
|
||||
"""Toolkit for making REST requests.
|
||||
|
||||
*Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
|
||||
and DELETE requests to an API.
|
||||
|
||||
Exercise care in who is allowed to use this toolkit. If exposing
|
||||
to end users, consider that users will be able to make arbitrary
|
||||
requests on behalf of the server hosting the code. For example,
|
||||
users could ask the server to make a request to a private API
|
||||
that is only accessible from the server.
|
||||
|
||||
Control access to who can submit requests using this toolkit and
|
||||
what network access it has.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Return a list of tools."""
|
||||
return [
|
||||
RequestsGetTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPostTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPatchTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPutTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
|
||||
]
|
||||
|
||||
|
||||
class OpenAPIToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with an OpenAPI API.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by creating, deleting, or updating,
|
||||
reading underlying data.
|
||||
|
||||
For example, this toolkit can be used to delete data exposed via
|
||||
an OpenAPI compliant API.
|
||||
"""
|
||||
|
||||
json_agent: Any
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
json_agent_tool = Tool(
|
||||
name="json_explorer",
|
||||
func=self.json_agent.run,
|
||||
description=DESCRIPTION,
|
||||
)
|
||||
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
|
||||
return [*request_toolkit.get_tools(), json_agent_tool]
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
json_spec: JsonSpec,
|
||||
requests_wrapper: TextRequestsWrapper,
|
||||
**kwargs: Any,
|
||||
) -> OpenAPIToolkit:
|
||||
"""Create json agent from llm, then initialize."""
|
||||
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
|
||||
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Power BI agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_community.agent_toolkits.powerbi.prompt import (
|
||||
POWERBI_PREFIX,
|
||||
POWERBI_SUFFIX,
|
||||
)
|
||||
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain_community.utilities.powerbi import PowerBIDataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents import AgentExecutor
|
||||
|
||||
|
||||
def create_pbi_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: Optional[PowerBIToolkit] = None,
|
||||
powerbi: Optional[PowerBIDataset] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = POWERBI_PREFIX,
|
||||
suffix: str = POWERBI_SUFFIX,
|
||||
format_instructions: Optional[str] = None,
|
||||
examples: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
top_k: int = 10,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Power BI agent from an LLM and tools."""
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.llm import LLMChain
|
||||
if toolkit is None:
|
||||
if powerbi is None:
|
||||
raise ValueError("Must provide either a toolkit or powerbi dataset")
|
||||
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
|
||||
tools = toolkit.get_tools()
|
||||
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
|
||||
prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {}
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(
|
||||
llm=llm,
|
||||
prompt=ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix.format(top_k=top_k).format(tables=tables),
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
**prompt_params,
|
||||
),
|
||||
callback_manager=callback_manager, # type: ignore
|
||||
verbose=verbose,
|
||||
),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""Power BI agent."""
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
from langchain_community.agent_toolkits.powerbi.prompt import (
|
||||
POWERBI_CHAT_PREFIX,
|
||||
POWERBI_CHAT_SUFFIX,
|
||||
)
|
||||
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain_community.utilities.powerbi import PowerBIDataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
def create_pbi_chat_agent(
|
||||
llm: BaseChatModel,
|
||||
toolkit: Optional[PowerBIToolkit] = None,
|
||||
powerbi: Optional[PowerBIDataset] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = POWERBI_CHAT_PREFIX,
|
||||
suffix: str = POWERBI_CHAT_SUFFIX,
|
||||
examples: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory: Optional[BaseChatMemory] = None,
|
||||
top_k: int = 10,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Power BI agent from a Chat LLM and tools.
|
||||
|
||||
If you supply only a toolkit and no Power BI dataset, the same LLM is used for both.
|
||||
"""
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
if toolkit is None:
|
||||
if powerbi is None:
|
||||
raise ValueError("Must provide either a toolkit or powerbi dataset")
|
||||
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
|
||||
tools = toolkit.get_tools()
|
||||
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
|
||||
agent = ConversationalChatAgent.from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
system_message=prefix.format(top_k=top_k).format(tables=tables),
|
||||
human_message=suffix,
|
||||
input_variables=input_variables,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
memory=memory
|
||||
or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Toolkit for interacting with a Power BI dataset."""
|
||||
from __future__ import annotations
|
||||
from typing import List, Optional, Union, TYPE_CHECKING
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_community.agent_toolkits.base import BaseToolkit
|
||||
from langchain_community.tools import BaseTool
|
||||
from langchain_community.tools.powerbi.prompt import (
|
||||
QUESTION_TO_QUERY_BASE,
|
||||
SINGLE_QUESTION_TO_QUERY,
|
||||
USER_INPUT,
|
||||
)
|
||||
from langchain_community.tools.powerbi.tool import (
|
||||
InfoPowerBITool,
|
||||
ListPowerBITool,
|
||||
QueryPowerBITool,
|
||||
)
|
||||
from langchain_community.utilities.powerbi import PowerBIDataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
class PowerBIToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Power BI dataset.
|
||||
|
||||
*Security Note*: This toolkit interacts with an external service.
|
||||
|
||||
Control access to who can use this toolkit.
|
||||
|
||||
Make sure that the capabilities given by this toolkit to the calling
|
||||
code are appropriately scoped to the application.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
powerbi: PowerBIDataset = Field(exclude=True)
|
||||
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
|
||||
examples: Optional[str] = None
|
||||
max_iterations: int = 5
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
output_token_limit: Optional[int] = None
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
QueryPowerBITool(
|
||||
llm_chain=self._get_chain(),
|
||||
powerbi=self.powerbi,
|
||||
examples=self.examples,
|
||||
max_iterations=self.max_iterations,
|
||||
output_token_limit=self.output_token_limit,
|
||||
tiktoken_model_name=self.tiktoken_model_name,
|
||||
),
|
||||
InfoPowerBITool(powerbi=self.powerbi),
|
||||
ListPowerBITool(powerbi=self.powerbi),
|
||||
]
|
||||
|
||||
def _get_chain(self) -> LLMChain:
|
||||
"""Construct the chain based on the callback manager and model type."""
|
||||
from langchain.chains.llm import LLMChain
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager
|
||||
if self.callback_manager
|
||||
else None,
|
||||
prompt=PromptTemplate(
|
||||
template=SINGLE_QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = SystemMessagePromptTemplate(
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY_BASE,
|
||||
input_variables=["tables", "schemas", "examples"],
|
||||
)
|
||||
)
|
||||
human_prompt = HumanMessagePromptTemplate(
|
||||
prompt=PromptTemplate(
|
||||
template=USER_INPUT,
|
||||
input_variables=["tool_input"],
|
||||
)
|
||||
)
|
||||
return LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager if self.callback_manager else None,
|
||||
prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
|
||||
)
|
||||
@@ -1,64 +0,0 @@
"""Spark SQL agent."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from langchain_core.callbacks import BaseCallbackManager, Callbacks
from langchain_core.language_models import BaseLanguageModel

from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit

if TYPE_CHECKING:
    from langchain.agents.agent import AgentExecutor


def create_spark_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SparkSQLToolkit,
    callback_manager: Optional[BaseCallbackManager] = None,
    callbacks: Callbacks = None,
    prefix: str = SQL_PREFIX,
    suffix: str = SQL_SUFFIX,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark SQL agent from an LLM and tools."""
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain
    tools = toolkit.get_tools()
    prefix = prefix.format(top_k=top_k)
    prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {}
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=input_variables,
        **prompt_params,
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        callback_manager=callback_manager,
        callbacks=callbacks,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        callbacks=callbacks,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
@@ -1,102 +0,0 @@
"""SQL agent."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING

from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

from langchain_community.agent_toolkits.sql.prompt import (
    SQL_FUNCTIONS_SUFFIX,
    SQL_PREFIX,
    SQL_SUFFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools import BaseTool

if TYPE_CHECKING:
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.agent_types import AgentType


def create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SQLDatabaseToolkit,
    agent_type: Optional[AgentType] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = SQL_PREFIX,
    suffix: Optional[str] = None,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Any,
) -> AgentExecutor:
    """Construct an SQL agent from an LLM and tools."""
    from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
    from langchain.agents.agent_types import AgentType
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
    from langchain.chains.llm import LLMChain
    agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
    tools = toolkit.get_tools() + list(extra_tools)
    prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
    agent: BaseSingleActionAgent

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt_params = {"format_instructions": format_instructions} if format_instructions is not None else {}
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix or SQL_SUFFIX,
            input_variables=input_variables,
            **prompt_params,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        messages = [
            SystemMessage(content=prefix),
            HumanMessagePromptTemplate.from_template("{input}"),
            AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        input_variables = ["input", "agent_scratchpad"]
        _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)

        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")

    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
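
A minimal usage sketch of create_sql_agent, assuming a local SQLite file and the SQLDatabase utility from the same community package (the database path and question are placeholders):

```python
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.llms import OpenAI
from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///example.db")  # hypothetical database
llm = OpenAI(temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

executor = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
executor.run("How many rows does the largest table contain?")
```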
@@ -1,66 +0,0 @@
"""**Callback handlers** allow listening to events in LangChain.

**Class hierarchy:**

.. code-block::

    BaseCallbackHandler --> <name>CallbackHandler  # Example: AimCallbackHandler
"""

from langchain_community.callbacks.aim_callback import AimCallbackHandler
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
from langchain_community.callbacks.context_callback import ContextCallbackHandler
from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
from langchain_community.callbacks.labelstudio_callback import (
    LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain_community.callbacks.manager import (
    get_openai_callback,
    wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.promptlayer_callback import (
    PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain_community.callbacks.streamlit import (
    LLMThoughtLabeler,
    StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler
from langchain_community.callbacks.wandb_callback import WandbCallbackHandler
from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler

__all__ = [
    "AimCallbackHandler",
    "ArgillaCallbackHandler",
    "ArizeCallbackHandler",
    "PromptLayerCallbackHandler",
    "ArthurCallbackHandler",
    "ClearMLCallbackHandler",
    "CometCallbackHandler",
    "ContextCallbackHandler",
    "HumanApprovalCallbackHandler",
    "InfinoCallbackHandler",
    "MlflowCallbackHandler",
    "LLMonitorCallbackHandler",
    "OpenAICallbackHandler",
    "LLMThoughtLabeler",
    "StreamlitCallbackHandler",
    "WandbCallbackHandler",
    "WhyLabsCallbackHandler",
    "get_openai_callback",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",
    "LabelStudioCallbackHandler",
    "TrubricsCallbackHandler",
]
@@ -1,69 +0,0 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
    Generator,
    Optional,
)

from langchain_core.tracers.context import register_configure_hook

from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.tracers.wandb import WandbTracer

logger = logging.getLogger(__name__)

openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(  # noqa: E501
    "tracing_wandb_callback", default=None
)

register_configure_hook(openai_callback_var, True)
register_configure_hook(
    wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get the OpenAI callback handler in a context manager,
    which conveniently exposes token and cost information.

    Returns:
        OpenAICallbackHandler: The OpenAI callback handler.

    Example:
        >>> with get_openai_callback() as cb:
        ...     # Use the OpenAI callback handler
    """
    cb = OpenAICallbackHandler()
    openai_callback_var.set(cb)
    yield cb
    openai_callback_var.set(None)


@contextmanager
def wandb_tracing_enabled(
    session_name: str = "default",
) -> Generator[None, None, None]:
    """Get the WandbTracer in a context manager.

    Args:
        session_name (str, optional): The name of the session.
            Defaults to "default".

    Returns:
        None

    Example:
        >>> with wandb_tracing_enabled() as session:
        ...     # Use the WandbTracer session
    """
    cb = WandbTracer()
    wandb_tracing_callback_var.set(cb)
    yield None
    wandb_tracing_callback_var.set(None)
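
The get_openai_callback context manager above is typically used to read token and cost counters after a call; a short sketch, assuming an OpenAI LLM is configured (attribute names follow OpenAICallbackHandler):

```python
from langchain_community.callbacks import get_openai_callback
from langchain_community.llms import OpenAI

llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
    llm.invoke("Tell me a one-line joke.")
    # Counters accumulate on the handler while the context is active.
    print(cb.total_tokens, cb.prompt_tokens, cb.completion_tokens, cb.total_cost)
```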
@@ -1,18 +0,0 @@
"""Tracers that record execution of LangChain runs."""

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
    ConsoleCallbackHandler,
    FunctionCallbackHandler,
)

from langchain_community.callbacks.tracers.wandb import WandbTracer

__all__ = [
    "ConsoleCallbackHandler",
    "FunctionCallbackHandler",
    "LangChainTracer",
    "LangChainTracerV1",
    "WandbTracer",
]
@@ -1,101 +0,0 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
class BaseLoader(ABC):
|
||||
"""Interface for Document Loader.
|
||||
|
||||
Implementations should implement the lazy-loading method using generators
|
||||
to avoid loading all Documents into memory at once.
|
||||
|
||||
The `load` method will remain as is for backwards compatibility, but its
|
||||
implementation should be just `list(self.lazy_load())`.
|
||||
"""
|
||||
|
||||
# Sub-classes should implement this method
|
||||
# as return list(self.lazy_load()).
|
||||
# This method returns a List which is materialized in memory.
|
||||
@abstractmethod
|
||||
def load(self) -> List[Document]:
|
||||
"""Load data into Document objects."""
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
||||
|
||||
# Attention: This method will be upgraded into an abstractmethod once it's
|
||||
# implemented in all the existing subclasses.
|
||||
def lazy_load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""A lazy loader for Documents."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement lazy_load()"
|
||||
)
|
||||
|
||||
|
||||
class BaseBlobParser(ABC):
|
||||
"""Abstract interface for blob parsers.
|
||||
|
||||
A blob parser provides a way to parse raw data stored in a blob into one
|
||||
or more documents.
|
||||
|
||||
The parser can be composed with blob loaders, making it easy to reuse
|
||||
a parser independent of how the blob was originally loaded.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazy parsing interface.
|
||||
|
||||
Subclasses are required to implement this method.
|
||||
|
||||
Args:
|
||||
blob: Blob instance
|
||||
|
||||
Returns:
|
||||
Generator of documents
|
||||
"""
|
||||
|
||||
def parse(self, blob: Blob) -> List[Document]:
|
||||
"""Eagerly parse the blob into a document or documents.
|
||||
|
||||
This is a convenience method for interactive development environment.
|
||||
|
||||
Production applications should favor the lazy_parse method instead.
|
||||
|
||||
Subclasses should generally not over-ride this parse method.
|
||||
|
||||
Args:
|
||||
blob: Blob instance
|
||||
|
||||
Returns:
|
||||
List of documents
|
||||
"""
|
||||
return list(self.lazy_parse(blob))
|
||||
@@ -1,147 +0,0 @@
"""Use to load blobs from the local file system."""
from pathlib import Path
from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union

from langchain_community.document_loaders.blob_loaders.schema import Blob, BlobLoader

T = TypeVar("T")


def _make_iterator(
    length_func: Callable[[], int], show_progress: bool = False
) -> Callable[[Iterable[T]], Iterator[T]]:
    """Create a function that optionally wraps an iterable in tqdm."""
    if show_progress:
        try:
            from tqdm.auto import tqdm
        except ImportError:
            raise ImportError(
                "You must install tqdm to use show_progress=True."
                "You can install tqdm with `pip install tqdm`."
            )

        # Make sure to provide `total` here so that tqdm can show
        # a progress bar that takes into account the total number of files.
        def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
            """Wrap an iterable in a tqdm progress bar."""
            return tqdm(iterable, total=length_func())

        iterator = _with_tqdm
    else:
        iterator = iter  # type: ignore

    return iterator


# PUBLIC API


class FileSystemBlobLoader(BlobLoader):
    """Load blobs in the local file system.

    Example:

    .. code-block:: python

        from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader
        loader = FileSystemBlobLoader("/path/to/directory")
        for blob in loader.yield_blobs():
            print(blob)
    """  # noqa: E501

    def __init__(
        self,
        path: Union[str, Path],
        *,
        glob: str = "**/[!.]*",
        exclude: Sequence[str] = (),
        suffixes: Optional[Sequence[str]] = None,
        show_progress: bool = False,
    ) -> None:
        """Initialize with a path to directory and how to glob over it.

        Args:
            path: Path to directory to load from or path to file to load.
                  If a path to a file is provided, glob/exclude/suffixes are ignored.
            glob: Glob pattern relative to the specified path
                  by default set to pick up all non-hidden files
            exclude: patterns to exclude from results, use glob syntax
            suffixes: Provide to keep only files with these suffixes
                      Useful when wanting to keep files with different suffixes
                      Suffixes must include the dot, e.g. ".txt"
            show_progress: If true, will show a progress bar as the files are loaded.
                           This forces an iteration through all matching files
                           to count them prior to loading them.

        Examples:

        .. code-block:: python
            from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader

            # Load a single file.
            loader = FileSystemBlobLoader("/path/to/file.txt")

            # Recursively load all text files in a directory.
            loader = FileSystemBlobLoader("/path/to/directory", glob="**/*.txt")

            # Recursively load all non-hidden files in a directory.
            loader = FileSystemBlobLoader("/path/to/directory", glob="**/[!.]*")

            # Load all files in a directory without recursion.
            loader = FileSystemBlobLoader("/path/to/directory", glob="*")

            # Recursively load all files in a directory, except for py or pyc files.
            loader = FileSystemBlobLoader(
                "/path/to/directory",
                glob="**/*.txt",
                exclude=["**/*.py", "**/*.pyc"]
            )
        """  # noqa: E501
        if isinstance(path, Path):
            _path = path
        elif isinstance(path, str):
            _path = Path(path)
        else:
            raise TypeError(f"Expected str or Path, got {type(path)}")

        self.path = _path.expanduser()  # Expand user to handle ~
        self.glob = glob
        self.suffixes = set(suffixes or [])
        self.show_progress = show_progress
        self.exclude = exclude

    def yield_blobs(
        self,
    ) -> Iterable[Blob]:
        """Yield blobs that match the requested pattern."""
        iterator = _make_iterator(
            length_func=self.count_matching_files, show_progress=self.show_progress
        )

        for path in iterator(self._yield_paths()):
            yield Blob.from_path(path)

    def _yield_paths(self) -> Iterable[Path]:
        """Yield paths that match the requested pattern."""
        if self.path.is_file():
            yield self.path
            return

        paths = self.path.glob(self.glob)
        for path in paths:
            if self.exclude:
                if any(path.match(glob) for glob in self.exclude):
                    continue
            if path.is_file():
                if self.suffixes and path.suffix not in self.suffixes:
                    continue
                yield path

    def count_matching_files(self) -> int:
        """Count files that match the pattern without loading them."""
        # Carry out a full iteration to count the files without
        # materializing anything expensive in memory.
        num = 0
        for _ in self._yield_paths():
            num += 1
        return num
@@ -1,190 +0,0 @@
from __future__ import annotations

from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_community.document_loaders.blob_loaders import (
    BlobLoader,
    FileSystemBlobLoader,
)
from langchain_community.document_loaders.parsers.registry import get_parser

if TYPE_CHECKING:
    from langchain.text_splitter import TextSplitter

_PathLike = Union[str, Path]

DEFAULT = Literal["default"]


class GenericLoader(BaseLoader):
    """Generic Document Loader.

    A generic document loader that allows combining an arbitrary blob loader with
    a blob parser.

    Examples:

        Parse a specific PDF file:

        .. code-block:: python

            from langchain_community.document_loaders import GenericLoader
            from langchain_community.document_loaders.parsers.pdf import PyPDFParser

            # Recursively load all text files in a directory.
            loader = GenericLoader.from_filesystem(
                "my_lovely_pdf.pdf",
                parser=PyPDFParser()
            )

        .. code-block:: python

            from langchain_community.document_loaders import GenericLoader
            from langchain_community.document_loaders.blob_loaders import FileSystemBlobLoader


            loader = GenericLoader.from_filesystem(
                path="path/to/directory",
                glob="**/[!.]*",
                suffixes=[".pdf"],
                show_progress=True,
            )

            docs = loader.lazy_load()
            next(docs)

    Example instantiations to change which files are loaded:

    .. code-block:: python

        # Recursively load all text files in a directory.
        loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/*.txt")

        # Recursively load all non-hidden files in a directory.
        loader = GenericLoader.from_filesystem("/path/to/dir", glob="**/[!.]*")

        # Load all files in a directory without recursion.
        loader = GenericLoader.from_filesystem("/path/to/dir", glob="*")

    Example instantiations to change which parser is used:

    .. code-block:: python

        from langchain_community.document_loaders.parsers.pdf import PyPDFParser

        # Recursively load all text files in a directory.
        loader = GenericLoader.from_filesystem(
            "/path/to/dir",
            glob="**/*.pdf",
            parser=PyPDFParser()
        )

    """  # noqa: E501

    def __init__(
        self,
        blob_loader: BlobLoader,
        blob_parser: BaseBlobParser,
    ) -> None:
        """A generic document loader.

        Args:
            blob_loader: A blob loader which knows how to yield blobs
            blob_parser: A blob parser which knows how to parse blobs into documents
        """
        self.blob_loader = blob_loader
        self.blob_parser = blob_parser

    def lazy_load(
        self,
    ) -> Iterator[Document]:
        """Load documents lazily. Use this when working at a large scale."""
        for blob in self.blob_loader.yield_blobs():
            yield from self.blob_parser.lazy_parse(blob)

    def load(self) -> List[Document]:
        """Load all documents."""
        return list(self.lazy_load())

    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
    ) -> List[Document]:
        """Load all documents and split them into sentences."""
        raise NotImplementedError(
            "Loading and splitting is not yet implemented for generic loaders. "
            "When they will be implemented they will be added via the initializer. "
            "This method should not be used going forward."
        )

    @classmethod
    def from_filesystem(
        cls,
        path: _PathLike,
        *,
        glob: str = "**/[!.]*",
        exclude: Sequence[str] = (),
        suffixes: Optional[Sequence[str]] = None,
        show_progress: bool = False,
        parser: Union[DEFAULT, BaseBlobParser] = "default",
        parser_kwargs: Optional[dict] = None,
    ) -> GenericLoader:
        """Create a generic document loader using a filesystem blob loader.

        Args:
            path: The path to the directory to load documents from OR the path to a
                  single file to load. If this is a file, glob, exclude, suffixes
                  will be ignored.
            glob: The glob pattern to use to find documents.
            suffixes: The suffixes to use to filter documents. If None, all files
                      matching the glob will be loaded.
            exclude: A list of patterns to exclude from the loader.
            show_progress: Whether to show a progress bar or not (requires tqdm).
                           Proxies to the file system loader.
            parser: A blob parser which knows how to parse blobs into documents,
                    will instantiate a default parser if not provided.
                    The default can be overridden by either passing a parser or
                    setting the class attribute `blob_parser` (the latter
                    should be used with inheritance).
            parser_kwargs: Keyword arguments to pass to the parser.

        Returns:
            A generic document loader.
        """
        blob_loader = FileSystemBlobLoader(
            path,
            glob=glob,
            exclude=exclude,
            suffixes=suffixes,
            show_progress=show_progress,
        )
        if isinstance(parser, str):
            if parser == "default":
                try:
                    # If there is an implementation of get_parser on the class, use it.
                    blob_parser = cls.get_parser(**(parser_kwargs or {}))
                except NotImplementedError:
                    # if not then use the global registry.
                    blob_parser = get_parser(parser)
            else:
                blob_parser = get_parser(parser)
        else:
            blob_parser = parser
        return cls(blob_loader, blob_parser)

    @staticmethod
    def get_parser(**kwargs: Any) -> BaseBlobParser:
        """Override this method to associate a default parser with the class."""
        raise NotImplementedError()
@@ -1,70 +0,0 @@
"""Code for generic / auxiliary parsers.

This module contains some logic to help assemble more sophisticated parsers.
"""
from typing import Iterator, Mapping, Optional

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders.schema import Blob


class MimeTypeBasedParser(BaseBlobParser):
    """Parser that uses `mime`-types to parse a blob.

    This parser is useful for simple pipelines where the mime-type is sufficient
    to determine how to parse a blob.

    To use, configure handlers based on mime-types and pass them to the initializer.

    Example:

    .. code-block:: python

        from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser

        parser = MimeTypeBasedParser(
            handlers={
                "application/pdf": ...,
            },
            fallback_parser=...,
        )
    """  # noqa: E501

    def __init__(
        self,
        handlers: Mapping[str, BaseBlobParser],
        *,
        fallback_parser: Optional[BaseBlobParser] = None,
    ) -> None:
        """Define a parser that uses mime-types to determine how to parse a blob.

        Args:
            handlers: A mapping from mime-types to functions that take a blob, parse it
                      and return a document.
            fallback_parser: A fallback_parser parser to use if the mime-type is not
                             found in the handlers. If provided, this parser will be
                             used to parse blobs with all mime-types not found in
                             the handlers.
                             If not provided, a ValueError will be raised if the
                             mime-type is not found in the handlers.
        """
        self.handlers = handlers
        self.fallback_parser = fallback_parser

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Load documents from a blob."""
        mimetype = blob.mimetype

        if mimetype is None:
            raise ValueError(f"{blob} does not have a mimetype.")

        if mimetype in self.handlers:
            handler = self.handlers[mimetype]
            yield from handler.lazy_parse(blob)
        else:
            if self.fallback_parser is not None:
                yield from self.fallback_parser.lazy_parse(blob)
            else:
                raise ValueError(f"Unsupported mime type: {mimetype}")
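A minimal sketch of how a custom blob parser composes with the MimeTypeBasedParser above; the PlainTextParser below is illustrative only and is not part of this diff.

# Hypothetical text parser composed with MimeTypeBasedParser.
from typing import Iterator

from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders import Blob
from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser


class PlainTextParser(BaseBlobParser):
    """Illustrative parser: one Document per blob, decoded as text."""

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        yield Document(page_content=blob.as_string(), metadata={"source": blob.source})


# Blobs with other mime-types will raise ValueError because no fallback is set.
parser = MimeTypeBasedParser(handlers={"text/plain": PlainTextParser()})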
@@ -1,157 +0,0 @@
from __future__ import annotations

from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders import Blob
from langchain_community.document_loaders.parsers.language.cobol import CobolSegmenter
from langchain_community.document_loaders.parsers.language.javascript import (
    JavaScriptSegmenter,
)
from langchain_community.document_loaders.parsers.language.python import PythonSegmenter

if TYPE_CHECKING:
    from langchain.text_splitter import Language

try:
    from langchain.text_splitter import Language

    LANGUAGE_EXTENSIONS: Dict[str, str] = {
        "py": Language.PYTHON,
        "js": Language.JS,
        "cobol": Language.COBOL,
    }

    LANGUAGE_SEGMENTERS: Dict[str, Any] = {
        Language.PYTHON: PythonSegmenter,
        Language.JS: JavaScriptSegmenter,
        Language.COBOL: CobolSegmenter,
    }
except ImportError:
    LANGUAGE_EXTENSIONS = {}
    LANGUAGE_SEGMENTERS = {}


class LanguageParser(BaseBlobParser):
    """Parse using the respective programming language syntax.

    Each top-level function and class in the code is loaded into separate documents.
    Furthermore, an extra document is generated, containing the remaining top-level code
    that excludes the already segmented functions and classes.

    This approach can potentially improve the accuracy of QA models over source code.

    Currently, the supported languages for code parsing are Python and JavaScript.

    The language used for parsing can be configured, along with the minimum number of
    lines required to activate the splitting based on syntax.

    Examples:

        .. code-block:: python

            from langchain.text_splitter import Language
            from langchain_community.document_loaders.generic import GenericLoader
            from langchain_community.document_loaders.parsers import LanguageParser

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py", ".js"],
                parser=LanguageParser()
            )
            docs = loader.load()

        Example instantiations to manually select the language:

        .. code-block:: python

            from langchain.text_splitter import Language

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py"],
                parser=LanguageParser(language=Language.PYTHON)
            )

        Example instantiations to set number of lines threshold:

        .. code-block:: python

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py"],
                parser=LanguageParser(parser_threshold=200)
            )
    """

    def __init__(self, language: Optional[Language] = None, parser_threshold: int = 0):
        """
        Language parser that splits code using the respective language syntax.

        Args:
            language: If None (default), it will try to infer language from source.
            parser_threshold: Minimum lines needed to activate parsing (0 by default).
        """
        self.language = language
        self.parser_threshold = parser_threshold

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        code = blob.as_string()

        language = self.language or (
            LANGUAGE_EXTENSIONS.get(blob.source.rsplit(".", 1)[-1])
            if isinstance(blob.source, str)
            else None
        )

        if language is None:
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                },
            )
            return

        if self.parser_threshold >= len(code.splitlines()):
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                    "language": language,
                },
            )
            return

        self.Segmenter = LANGUAGE_SEGMENTERS[language]
        segmenter = self.Segmenter(blob.as_string())
        if not segmenter.is_valid():
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                },
            )
            return

        for functions_classes in segmenter.extract_functions_classes():
            yield Document(
                page_content=functions_classes,
                metadata={
                    "source": blob.source,
                    "content_type": "functions_classes",
                    "language": language,
                },
            )
        yield Document(
            page_content=segmenter.simplify_code(),
            metadata={
                "source": blob.source,
                "content_type": "simplified_code",
                "language": language,
            },
        )
@@ -1,262 +0,0 @@
from __future__ import annotations

import asyncio
import json
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader

if TYPE_CHECKING:
    import pandas as pd
    from telethon.hints import EntityLike


def concatenate_rows(row: dict) -> str:
    """Combine message information in a readable format ready to be used."""
    date = row["date"]
    sender = row["from"]
    text = row["text"]
    return f"{sender} on {date}: {text}\n\n"


class TelegramChatFileLoader(BaseLoader):
    """Load from `Telegram chat` dump."""

    def __init__(self, path: str):
        """Initialize with a path."""
        self.file_path = path

    def load(self) -> List[Document]:
        """Load documents."""
        p = Path(self.file_path)

        with open(p, encoding="utf8") as f:
            d = json.load(f)

        text = "".join(
            concatenate_rows(message)
            for message in d["messages"]
            if message["type"] == "message" and isinstance(message["text"], str)
        )
        metadata = {"source": str(p)}

        return [Document(page_content=text, metadata=metadata)]


def text_to_docs(text: Union[str, List[str]]) -> List[Document]:
    """Convert a string or list of strings to a list of Documents with metadata."""
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    if isinstance(text, str):
        # Take a single string as one page
        text = [text]
    page_docs = [Document(page_content=page) for page in text]

    # Add page numbers as metadata
    for i, doc in enumerate(page_docs):
        doc.metadata["page"] = i + 1

    # Split pages into chunks
    doc_chunks = []

    for doc in page_docs:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
            chunk_overlap=20,
        )
        chunks = text_splitter.split_text(doc.page_content)
        for i, chunk in enumerate(chunks):
            doc = Document(
                page_content=chunk, metadata={"page": doc.metadata["page"], "chunk": i}
            )
            # Add sources as metadata
            doc.metadata["source"] = f"{doc.metadata['page']}-{doc.metadata['chunk']}"
            doc_chunks.append(doc)
    return doc_chunks


class TelegramChatApiLoader(BaseLoader):
    """Load `Telegram` chat json directory dump."""

    def __init__(
        self,
        chat_entity: Optional[EntityLike] = None,
        api_id: Optional[int] = None,
        api_hash: Optional[str] = None,
        username: Optional[str] = None,
        file_path: str = "telegram_data.json",
    ):
        """Initialize with API parameters.

        Args:
            chat_entity: The chat entity to fetch data from.
            api_id: The API ID.
            api_hash: The API hash.
            username: The username.
            file_path: The file path to save the data to. Defaults to
                "telegram_data.json".
        """
        self.chat_entity = chat_entity
        self.api_id = api_id
        self.api_hash = api_hash
        self.username = username
        self.file_path = file_path

    async def fetch_data_from_telegram(self) -> None:
        """Fetch data from Telegram API and save it as a JSON file."""
        from telethon.sync import TelegramClient

        data = []
        async with TelegramClient(self.username, self.api_id, self.api_hash) as client:
            async for message in client.iter_messages(self.chat_entity):
                is_reply = message.reply_to is not None
                reply_to_id = message.reply_to.reply_to_msg_id if is_reply else None
                data.append(
                    {
                        "sender_id": message.sender_id,
                        "text": message.text,
                        "date": message.date.isoformat(),
                        "message.id": message.id,
                        "is_reply": is_reply,
                        "reply_to_id": reply_to_id,
                    }
                )

        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

    def _get_message_threads(self, data: pd.DataFrame) -> dict:
        """Create a dictionary of message threads from the given data.

        Args:
            data (pd.DataFrame): A DataFrame containing the conversation \
                data with columns:
                - message.sender_id
                - text
                - date
                - message.id
                - is_reply
                - reply_to_id

        Returns:
            dict: A dictionary where the key is the parent message ID and \
                the value is a list of message IDs in ascending order.
        """

        def find_replies(parent_id: int, reply_data: pd.DataFrame) -> List[int]:
            """
            Recursively find all replies to a given parent message ID.

            Args:
                parent_id (int): The parent message ID.
                reply_data (pd.DataFrame): A DataFrame containing reply messages.

            Returns:
                list: A list of message IDs that are replies to the parent message ID.
            """
            # Find direct replies to the parent message ID
            direct_replies = reply_data[reply_data["reply_to_id"] == parent_id][
                "message.id"
            ].tolist()

            # Recursively find replies to the direct replies
            all_replies = []
            for reply_id in direct_replies:
                all_replies += [reply_id] + find_replies(reply_id, reply_data)

            return all_replies

        # Filter out parent messages
        parent_messages = data[~data["is_reply"]]

        # Filter out reply messages and drop rows with NaN in 'reply_to_id'
        reply_messages = data[data["is_reply"]].dropna(subset=["reply_to_id"])

        # Convert 'reply_to_id' to integer
        reply_messages["reply_to_id"] = reply_messages["reply_to_id"].astype(int)

        # Create a dictionary of message threads with parent message IDs as keys and
        # lists of reply message IDs as values
        message_threads = {
            parent_id: [parent_id] + find_replies(parent_id, reply_messages)
            for parent_id in parent_messages["message.id"]
        }

        return message_threads

    def _combine_message_texts(
        self, message_threads: Dict[int, List[int]], data: pd.DataFrame
    ) -> str:
        """
        Combine the message texts for each parent message ID based \
            on the list of message threads.

        Args:
            message_threads (dict): A dictionary where the key is the parent message \
                ID and the value is a list of message IDs in ascending order.
            data (pd.DataFrame): A DataFrame containing the conversation data:
                - message.sender_id
                - text
                - date
                - message.id
                - is_reply
                - reply_to_id

        Returns:
            str: A combined string of message texts sorted by date.
        """
        combined_text = ""

        # Iterate through sorted parent message IDs
        for parent_id, message_ids in message_threads.items():
            # Get the message texts for the message IDs and sort them by date
            message_texts = (
                data[data["message.id"].isin(message_ids)]
                .sort_values(by="date")["text"]
                .tolist()
            )
            message_texts = [str(elem) for elem in message_texts]

            # Combine the message texts
            combined_text += " ".join(message_texts) + ".\n"

        return combined_text.strip()

    def load(self) -> List[Document]:
        """Load documents."""

        if self.chat_entity is not None:
            try:
                import nest_asyncio

                nest_asyncio.apply()
                asyncio.run(self.fetch_data_from_telegram())
            except ImportError:
                raise ImportError(
                    """`nest_asyncio` package not found.
                    please install with `pip install nest_asyncio`
                    """
                )

        p = Path(self.file_path)

        with open(p, encoding="utf8") as f:
            d = json.load(f)
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                """`pandas` package not found.
                please install with `pip install pandas`
                """
            )
        normalized_messages = pd.json_normalize(d)
        df = pd.DataFrame(normalized_messages)

        message_threads = self._get_message_threads(df)
        combined_texts = self._combine_message_texts(message_threads, df)

        return text_to_docs(combined_texts)
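For orientation, a minimal use of the file-based loader above might look like the following; the dump path is a placeholder, not something referenced in this diff.

# Hypothetical usage of TelegramChatFileLoader; the path is a placeholder.
from langchain_community.document_loaders import TelegramChatFileLoader

loader = TelegramChatFileLoader("path/to/telegram_export.json")
docs = loader.load()  # one Document with the concatenated chat text
print(docs[0].metadata["source"])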
@@ -1,149 +0,0 @@
from typing import Any, Iterator, List, Sequence, cast

from langchain_core.documents import BaseDocumentTransformer, Document


class BeautifulSoupTransformer(BaseDocumentTransformer):
    """Transform HTML content by extracting specific tags and removing unwanted ones.

    Example:
        .. code-block:: python

            from langchain_community.document_transformers import BeautifulSoupTransformer

            bs4_transformer = BeautifulSoupTransformer()
            docs_transformed = bs4_transformer.transform_documents(docs)
    """  # noqa: E501

    def __init__(self) -> None:
        """
        Initialize the transformer.

        This checks if the BeautifulSoup4 package is installed.
        If not, it raises an ImportError.
        """
        try:
            import bs4  # noqa:F401
        except ImportError:
            raise ImportError(
                "BeautifulSoup4 is required for BeautifulSoupTransformer. "
                "Please install it with `pip install beautifulsoup4`."
            )

    def transform_documents(
        self,
        documents: Sequence[Document],
        unwanted_tags: List[str] = ["script", "style"],
        tags_to_extract: List[str] = ["p", "li", "div", "a"],
        remove_lines: bool = True,
        **kwargs: Any,
    ) -> Sequence[Document]:
        """
        Transform a list of Document objects by cleaning their HTML content.

        Args:
            documents: A sequence of Document objects containing HTML content.
            unwanted_tags: A list of tags to be removed from the HTML.
            tags_to_extract: A list of tags whose content will be extracted.
            remove_lines: If set to True, unnecessary lines will be
                removed from the HTML content.

        Returns:
            A sequence of Document objects with transformed content.
        """
        for doc in documents:
            cleaned_content = doc.page_content

            cleaned_content = self.remove_unwanted_tags(cleaned_content, unwanted_tags)

            cleaned_content = self.extract_tags(cleaned_content, tags_to_extract)

            if remove_lines:
                cleaned_content = self.remove_unnecessary_lines(cleaned_content)

            doc.page_content = cleaned_content

        return documents

    @staticmethod
    def remove_unwanted_tags(html_content: str, unwanted_tags: List[str]) -> str:
        """
        Remove unwanted tags from a given HTML content.

        Args:
            html_content: The original HTML content string.
            unwanted_tags: A list of tags to be removed from the HTML.

        Returns:
            A cleaned HTML string with unwanted tags removed.
        """
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html_content, "html.parser")
        for tag in unwanted_tags:
            for element in soup.find_all(tag):
                element.decompose()
        return str(soup)

    @staticmethod
    def extract_tags(html_content: str, tags: List[str]) -> str:
        """
        Extract specific tags from a given HTML content.

        Args:
            html_content: The original HTML content string.
            tags: A list of tags to be extracted from the HTML.

        Returns:
            A string combining the content of the extracted tags.
        """
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html_content, "html.parser")
        text_parts: List[str] = []
        for element in soup.find_all():
            if element.name in tags:
                # Extract all navigable strings recursively from this element.
                text_parts += get_navigable_strings(element)

                # To avoid duplicate text, remove all descendants from the soup.
                element.decompose()

        return " ".join(text_parts)

    @staticmethod
    def remove_unnecessary_lines(content: str) -> str:
        """
        Clean up the content by removing unnecessary lines.

        Args:
            content: A string, which may contain unnecessary lines or spaces.

        Returns:
            A cleaned string with unnecessary lines removed.
        """
        lines = content.split("\n")
        stripped_lines = [line.strip() for line in lines]
        non_empty_lines = [line for line in stripped_lines if line]
        cleaned_content = " ".join(non_empty_lines)
        return cleaned_content

    async def atransform_documents(
        self,
        documents: Sequence[Document],
        **kwargs: Any,
    ) -> Sequence[Document]:
        raise NotImplementedError


def get_navigable_strings(element: Any) -> Iterator[str]:
    from bs4 import NavigableString, Tag

    for child in cast(Tag, element).children:
        if isinstance(child, Tag):
            yield from get_navigable_strings(child)
        elif isinstance(child, NavigableString):
            if (element.name == "a") and (href := element.get("href")):
                yield f"{child.strip()} ({href})"
            else:
                yield child.strip()
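A small sketch of the transformer above applied to a single Document; the HTML string is a toy example chosen for illustration, and the expected output noted in the comment assumes bs4 parses it as written.

# Hypothetical usage of BeautifulSoupTransformer; requires beautifulsoup4.
from langchain_core.documents import Document
from langchain_community.document_transformers import BeautifulSoupTransformer

html = (
    "<html><body><p>Hello</p>"
    "<script>ignored()</script>"
    "<a href='https://example.com'>link</a></body></html>"
)
docs = [Document(page_content=html)]
bs4_transformer = BeautifulSoupTransformer()
cleaned = bs4_transformer.transform_documents(docs, tags_to_extract=["p", "a"])
# Expected shape of the result: "Hello link (https://example.com)"
print(cleaned[0].page_content)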
@@ -1,140 +0,0 @@
"""Document transformers that use OpenAI Functions models"""
from typing import Any, Dict, Optional, Sequence, Type, Union

from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel


class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
    """Extract metadata tags from document contents using OpenAI functions.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatOpenAI
            from langchain_community.document_transformers import OpenAIMetadataTagger
            from langchain_core.documents import Document

            schema = {
                "properties": {
                    "movie_title": { "type": "string" },
                    "critic": { "type": "string" },
                    "tone": {
                        "type": "string",
                        "enum": ["positive", "negative"]
                    },
                    "rating": {
                        "type": "integer",
                        "description": "The number of stars the critic rated the movie"
                    }
                },
                "required": ["movie_title", "critic", "tone"]
            }

            # Must be an OpenAI model that supports functions
            llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
            tagging_chain = create_tagging_chain(schema, llm)
            document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
            original_documents = [
                Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
                Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
            ]

            enhanced_documents = document_transformer.transform_documents(original_documents)
    """  # noqa: E501

    tagging_chain: Any
    """The chain used to extract metadata from each document."""

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Automatically extract and populate metadata
        for each document according to the provided schema."""

        new_documents = []

        for document in documents:
            extracted_metadata: Dict = self.tagging_chain.run(document.page_content)  # type: ignore[assignment]  # noqa: E501
            new_document = Document(
                page_content=document.page_content,
                metadata={**extracted_metadata, **document.metadata},
            )
            new_documents.append(new_document)
        return new_documents

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        raise NotImplementedError


def create_metadata_tagger(
    metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
    llm: BaseLanguageModel,
    prompt: Optional[ChatPromptTemplate] = None,
    *,
    tagging_chain_kwargs: Optional[Dict] = None,
) -> OpenAIMetadataTagger:
    """Create a DocumentTransformer that uses an OpenAI function chain to automatically
    tag documents with metadata based on their content and an input schema.

    Args:
        metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
            is passed in, it's assumed to already be a valid JsonSchema.
            For best results, pydantic.BaseModels should have docstrings describing what
            the schema represents and descriptions for the parameters.
        llm: Language model to use, assumed to support the OpenAI function-calling API.
            Defaults to use "gpt-3.5-turbo-0613"
        prompt: BasePromptTemplate to pass to the model.

    Returns:
        An LLMChain that will pass the given function to the model.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatOpenAI
            from langchain_community.document_transformers import create_metadata_tagger
            from langchain_core.documents import Document

            schema = {
                "properties": {
                    "movie_title": { "type": "string" },
                    "critic": { "type": "string" },
                    "tone": {
                        "type": "string",
                        "enum": ["positive", "negative"]
                    },
                    "rating": {
                        "type": "integer",
                        "description": "The number of stars the critic rated the movie"
                    }
                },
                "required": ["movie_title", "critic", "tone"]
            }

            # Must be an OpenAI model that supports functions
            llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

            document_transformer = create_metadata_tagger(schema, llm)
            original_documents = [
                Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
                Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
            ]

            enhanced_documents = document_transformer.transform_documents(original_documents)
    """  # noqa: E501
    from langchain.chains.openai_functions import create_tagging_chain

    metadata_schema = (
        metadata_schema
        if isinstance(metadata_schema, dict)
        else metadata_schema.schema()
    )
    _tagging_chain_kwargs = tagging_chain_kwargs or {}
    tagging_chain = create_tagging_chain(
        metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
    )
    return OpenAIMetadataTagger(tagging_chain=tagging_chain)
@@ -1,161 +0,0 @@
"""**Embedding models** are wrappers around embedding models
from different APIs and services.

**Embedding models** can be LLMs or not.

**Class hierarchy:**

.. code-block::

    Embeddings --> <name>Embeddings  # Examples: OpenAIEmbeddings, HuggingFaceEmbeddings
"""


import logging
from typing import Any

from langchain_community.embeddings.aleph_alpha import (
    AlephAlphaAsymmetricSemanticEmbedding,
    AlephAlphaSymmetricSemanticEmbedding,
)
from langchain_community.embeddings.awa import AwaEmbeddings
from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain_community.embeddings.baidu_qianfan_endpoint import (
    QianfanEmbeddingsEndpoint,
)
from langchain_community.embeddings.bedrock import BedrockEmbeddings
from langchain_community.embeddings.bookend import BookendEmbeddings
from langchain_community.embeddings.clarifai import ClarifaiEmbeddings
from langchain_community.embeddings.cohere import CohereEmbeddings
from langchain_community.embeddings.dashscope import DashScopeEmbeddings
from langchain_community.embeddings.databricks import DatabricksEmbeddings
from langchain_community.embeddings.deepinfra import DeepInfraEmbeddings
from langchain_community.embeddings.edenai import EdenAiEmbeddings
from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain_community.embeddings.embaas import EmbaasEmbeddings
from langchain_community.embeddings.ernie import ErnieEmbeddings
from langchain_community.embeddings.fake import (
    DeterministicFakeEmbedding,
    FakeEmbeddings,
)
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.embeddings.google_palm import GooglePalmEmbeddings
from langchain_community.embeddings.gpt4all import GPT4AllEmbeddings
from langchain_community.embeddings.gradient_ai import GradientEmbeddings
from langchain_community.embeddings.huggingface import (
    HuggingFaceBgeEmbeddings,
    HuggingFaceEmbeddings,
    HuggingFaceInferenceAPIEmbeddings,
    HuggingFaceInstructEmbeddings,
)
from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain_community.embeddings.infinity import InfinityEmbeddings
from langchain_community.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
from langchain_community.embeddings.jina import JinaEmbeddings
from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
from langchain_community.embeddings.localai import LocalAIEmbeddings
from langchain_community.embeddings.minimax import MiniMaxEmbeddings
from langchain_community.embeddings.mlflow import MlflowEmbeddings
from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings
from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings
from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings
from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import (
    SagemakerEndpointEmbeddings,
)
from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings
from langchain_community.embeddings.self_hosted_hugging_face import (
    SelfHostedHuggingFaceEmbeddings,
    SelfHostedHuggingFaceInstructEmbeddings,
)
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain_community.embeddings.vertexai import VertexAIEmbeddings
from langchain_community.embeddings.voyageai import VoyageEmbeddings
from langchain_community.embeddings.xinference import XinferenceEmbeddings

logger = logging.getLogger(__name__)

__all__ = [
    "OpenAIEmbeddings",
    "AzureOpenAIEmbeddings",
    "ClarifaiEmbeddings",
    "CohereEmbeddings",
    "DatabricksEmbeddings",
    "ElasticsearchEmbeddings",
    "FastEmbedEmbeddings",
    "HuggingFaceEmbeddings",
    "HuggingFaceInferenceAPIEmbeddings",
    "InfinityEmbeddings",
    "GradientEmbeddings",
    "JinaEmbeddings",
    "LlamaCppEmbeddings",
    "HuggingFaceHubEmbeddings",
    "MlflowEmbeddings",
    "MlflowAIGatewayEmbeddings",
    "ModelScopeEmbeddings",
    "TensorflowHubEmbeddings",
    "SagemakerEndpointEmbeddings",
    "HuggingFaceInstructEmbeddings",
    "MosaicMLInstructorEmbeddings",
    "SelfHostedEmbeddings",
    "SelfHostedHuggingFaceEmbeddings",
    "SelfHostedHuggingFaceInstructEmbeddings",
    "FakeEmbeddings",
    "DeterministicFakeEmbedding",
    "AlephAlphaAsymmetricSemanticEmbedding",
    "AlephAlphaSymmetricSemanticEmbedding",
    "SentenceTransformerEmbeddings",
    "GooglePalmEmbeddings",
    "MiniMaxEmbeddings",
    "VertexAIEmbeddings",
    "BedrockEmbeddings",
    "DeepInfraEmbeddings",
    "EdenAiEmbeddings",
    "DashScopeEmbeddings",
    "EmbaasEmbeddings",
    "OctoAIEmbeddings",
    "SpacyEmbeddings",
    "NLPCloudEmbeddings",
    "GPT4AllEmbeddings",
    "XinferenceEmbeddings",
    "LocalAIEmbeddings",
    "AwaEmbeddings",
    "HuggingFaceBgeEmbeddings",
    "ErnieEmbeddings",
    "JavelinAIGatewayEmbeddings",
    "OllamaEmbeddings",
    "QianfanEmbeddingsEndpoint",
    "JohnSnowLabsEmbeddings",
    "VoyageEmbeddings",
    "BookendEmbeddings",
]


# TODO: this is in here to maintain backwards compatibility
class HypotheticalDocumentEmbedder:
    def __init__(self, *args: Any, **kwargs: Any):
        logger.warning(
            "Using a deprecated class. Please use "
            "`from langchain.chains import HypotheticalDocumentEmbedder` instead"
        )
        from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H

        return H(*args, **kwargs)  # type: ignore

    @classmethod
    def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
        logger.warning(
            "Using a deprecated class. Please use "
            "`from langchain.chains import HypotheticalDocumentEmbedder` instead"
        )
        from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H

        return H.from_llm(*args, **kwargs)
@@ -1,343 +0,0 @@
from typing import Any, Dict, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
    "Represent the question for retrieving supporting documents: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
    "Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"


class HuggingFaceEmbeddings(BaseModel, Embeddings):
    """HuggingFace sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import HuggingFaceEmbeddings

            model_name = "sentence-transformers/all-mpnet-base-v2"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': False}
            hf = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method of the model."""
    multi_process: bool = False
    """Run encode() on multiple GPUs."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        self.client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        import sentence_transformers

        texts = list(map(lambda x: x.replace("\n", " "), texts))
        if self.multi_process:
            pool = self.client.start_multi_process_pool()
            embeddings = self.client.encode_multi_process(texts, pool)
            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
        else:
            embeddings = self.client.encode(texts, **self.encode_kwargs)

        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]


class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
    """Wrapper around sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers``
    and ``InstructorEmbedding`` python packages installed.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import HuggingFaceInstructEmbeddings

            model_name = "hkunlp/instructor-large"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': True}
            hf = HuggingFaceInstructEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_INSTRUCT_MODEL
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method of the model."""
    embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
    """Instruction to use for embedding documents."""
    query_instruction: str = DEFAULT_QUERY_INSTRUCTION
    """Instruction to use for embedding query."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        try:
            from InstructorEmbedding import INSTRUCTOR

            self.client = INSTRUCTOR(
                self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
            )
        except ImportError as e:
            raise ImportError("Dependencies for InstructorEmbedding not found.") from e

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace instruct model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        instruction_pairs = [[self.embed_instruction, text] for text in texts]
        embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace instruct model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        instruction_pair = [self.query_instruction, text]
        embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
        return embedding.tolist()


class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
    """HuggingFace BGE sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import HuggingFaceBgeEmbeddings

            model_name = "BAAI/bge-large-en"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': True}
            hf = HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_BGE_MODEL
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method of the model."""
    query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
    """Instruction to use for embedding query."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence_transformers`."
            ) from exc

        self.client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )
        if "-zh" in self.model_name:
            self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        texts = [t.replace("\n", " ") for t in texts]
        embeddings = self.client.encode(texts, **self.encode_kwargs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        embedding = self.client.encode(
            self.query_instruction + text, **self.encode_kwargs
        )
        return embedding.tolist()


class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
    """Embed texts using the HuggingFace API.

    Requires a HuggingFace Inference API key and a model name.
    """

    api_key: str
    """Your API key for the HuggingFace Inference API."""
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    """The name of the model to use for text embeddings."""
    api_url: Optional[str] = None
    """Custom inference endpoint url. None for using default public url."""

    @property
    def _api_url(self) -> str:
        return self.api_url or self._default_api_url

    @property
    def _default_api_url(self) -> str:
        return (
            "https://api-inference.huggingface.co"
            "/pipeline"
            "/feature-extraction"
            f"/{self.model_name}"
        )

    @property
    def _headers(self) -> dict:
        return {"Authorization": f"Bearer {self.api_key}"}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
                corresponds to a single input text.

        Example:
            .. code-block:: python

                from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

                hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
                    api_key="your_api_key",
                    model_name="sentence-transformers/all-MiniLM-l6-v2"
                )
                texts = ["Hello, world!", "How are you?"]
                hf_embeddings.embed_documents(texts)
        """  # noqa: E501
        response = requests.post(
            self._api_url,
            headers=self._headers,
            json={
                "inputs": texts,
                "options": {"wait_for_model": True, "use_cache": True},
            },
        )
        return response.json()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
@@ -1,92 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra
|
||||
|
||||
|
||||
class JohnSnowLabsEmbeddings(BaseModel, Embeddings):
|
||||
"""JohnSnowLabs embedding models
|
||||
|
||||
To use, you should have the ``johnsnowlabs`` python package installed.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
|
||||
|
||||
embedding = JohnSnowLabsEmbeddings(model='embed_sentence.bert')
|
||||
output = embedding.embed_query("foo bar")
|
||||
""" # noqa: E501
|
||||
|
||||
model: Any = "embed_sentence.bert"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any = "embed_sentence.bert",
|
||||
hardware_target: str = "cpu",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the johnsnowlabs model."""
|
||||
super().__init__(**kwargs)
|
||||
# 1) Check imports
|
||||
try:
|
||||
from johnsnowlabs import nlp
|
||||
from nlu.pipe.pipeline import NLUPipeline
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import johnsnowlabs python package. "
|
||||
"Please install it with `pip install johnsnowlabs`."
|
||||
) from exc
|
||||
|
||||
# 2) Start a Spark Session
|
||||
try:
|
||||
os.environ["PYSPARK_PYTHON"] = sys.executable
|
||||
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
|
||||
nlp.start(hardware_target=hardware_target)
|
||||
except Exception as exc:
|
||||
raise Exception("Failure starting Spark Session") from exc
|
||||
|
||||
# 3) Load the model
|
||||
try:
|
||||
if isinstance(model, str):
|
||||
self.model = nlp.load(model)
|
||||
elif isinstance(model, NLUPipeline):
|
||||
self.model = model
|
||||
else:
|
||||
self.model = nlp.to_nlu_pipe(model)
|
||||
except Exception as exc:
|
||||
raise Exception("Failure loading model") from exc
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a JohnSnowLabs transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
df = self.model.predict(texts, output_level="document")
|
||||
emb_col = None
|
||||
for c in df.columns:
|
||||
if "embedding" in c:
|
||||
emb_col = c
|
||||
return [vec.tolist() for vec in df[emb_col].tolist()]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a JohnSnowLabs transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
@@ -1,168 +0,0 @@
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from langchain_community.embeddings.self_hosted import SelfHostedEmbeddings
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
||||
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
|
||||
DEFAULT_QUERY_INSTRUCTION = (
|
||||
"Represent the question for retrieving supporting documents: "
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _embed_documents(client: Any, *args: Any, **kwargs: Any) -> List[List[float]]:
|
||||
"""Inference function to send to the remote hardware.
|
||||
|
||||
Accepts a sentence_transformer model_id and
|
||||
returns a list of embeddings for each document in the batch.
|
||||
"""
|
||||
return client.encode(*args, **kwargs)
|
||||
|
||||
|
||||
def load_embedding_model(model_id: str, instruct: bool = False, device: int = 0) -> Any:
|
||||
"""Load the embedding model."""
|
||||
if not instruct:
|
||||
import sentence_transformers
|
||||
|
||||
client = sentence_transformers.SentenceTransformer(model_id)
|
||||
else:
|
||||
from InstructorEmbedding import INSTRUCTOR
|
||||
|
||||
client = INSTRUCTOR(model_id)
|
||||
|
||||
if importlib.util.find_spec("torch") is not None:
|
||||
import torch
|
||||
|
||||
cuda_device_count = torch.cuda.device_count()
|
||||
if device < -1 or (device >= cuda_device_count):
|
||||
raise ValueError(
|
||||
f"Got device=={device}, "
|
||||
f"device is required to be within [-1, {cuda_device_count})"
|
||||
)
|
||||
if device < 0 and cuda_device_count > 0:
|
||||
logger.warning(
|
||||
"Device has %d GPUs available. "
|
||||
"Provide device={deviceId} to `from_model_id` to use available"
|
||||
"GPUs for execution. deviceId is -1 for CPU and "
|
||||
"can be a positive integer associated with CUDA device id.",
|
||||
cuda_device_count,
|
||||
)
|
||||
|
||||
client = client.to(device)
|
||||
return client
|
||||
|
||||
|
||||
class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings):
|
||||
"""HuggingFace embedding models on self-hosted remote hardware.
|
||||
|
||||
Supported hardware includes auto-launched instances on AWS, GCP, Azure,
|
||||
and Lambda, as well as servers specified
|
||||
by IP address and SSH credentials (such as on-prem, or another cloud
|
||||
like Paperspace, Coreweave, etc.).
|
||||
|
||||
To use, you should have the ``runhouse`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings
|
||||
import runhouse as rh
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
|
||||
hf = SelfHostedHuggingFaceEmbeddings(model_name=model_name, hardware=gpu)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_id: str = DEFAULT_MODEL_NAME
|
||||
"""Model name to use."""
|
||||
model_reqs: List[str] = ["./", "sentence_transformers", "torch"]
|
||||
"""Requirements to install on hardware to inference the model."""
|
||||
hardware: Any
|
||||
"""Remote hardware to send the inference function to."""
|
||||
model_load_fn: Callable = load_embedding_model
|
||||
"""Function to load the model remotely on the server."""
|
||||
load_fn_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments to pass to the model load function."""
|
||||
inference_fn: Callable = _embed_documents
|
||||
"""Inference function to extract the embeddings."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the remote inference function."""
|
||||
load_fn_kwargs = kwargs.pop("load_fn_kwargs", {})
|
||||
load_fn_kwargs["model_id"] = load_fn_kwargs.get("model_id", DEFAULT_MODEL_NAME)
|
||||
load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", False)
|
||||
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
|
||||
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
|
||||
|
||||
|
||||
class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
|
||||
"""HuggingFace InstructEmbedding models on self-hosted remote hardware.
|
||||
|
||||
Supported hardware includes auto-launched instances on AWS, GCP, Azure,
|
||||
and Lambda, as well as servers specified
|
||||
by IP address and SSH credentials (such as on-prem, or another
|
||||
cloud like Paperspace, Coreweave, etc.).
|
||||
|
||||
To use, you should have the ``runhouse`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import SelfHostedHuggingFaceInstructEmbeddings
|
||||
import runhouse as rh
|
||||
model_name = "hkunlp/instructor-large"
|
||||
gpu = rh.cluster(name='rh-a10x', instance_type='A100:1')
|
||||
hf = SelfHostedHuggingFaceInstructEmbeddings(
|
||||
model_name=model_name, hardware=gpu)
|
||||
""" # noqa: E501
|
||||
|
||||
model_id: str = DEFAULT_INSTRUCT_MODEL
|
||||
"""Model name to use."""
|
||||
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
|
||||
"""Instruction to use for embedding documents."""
|
||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||
"""Instruction to use for embedding query."""
|
||||
model_reqs: List[str] = ["./", "InstructorEmbedding", "torch"]
|
||||
"""Requirements to install on hardware to inference the model."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the remote inference function."""
|
||||
load_fn_kwargs = kwargs.pop("load_fn_kwargs", {})
|
||||
load_fn_kwargs["model_id"] = load_fn_kwargs.get(
|
||||
"model_id", DEFAULT_INSTRUCT_MODEL
|
||||
)
|
||||
load_fn_kwargs["instruct"] = load_fn_kwargs.get("instruct", True)
|
||||
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
|
||||
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
instruction_pairs = []
|
||||
for text in texts:
|
||||
instruction_pairs.append([self.embed_instruction, text])
|
||||
embeddings = self.client(self.pipeline_ref, instruction_pairs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
instruction_pair = [self.query_instruction, text]
|
||||
embedding = self.client(self.pipeline_ref, [instruction_pair])[0]
|
||||
return embedding.tolist()
|
||||
@@ -1,351 +0,0 @@
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
|
||||
|
||||
|
||||
class _AnthropicCommon(BaseLanguageModel):
|
||||
client: Any = None #: :meta private:
|
||||
async_client: Any = None #: :meta private:
|
||||
model: str = Field(default="claude-2", alias="model_name")
|
||||
"""Model name to use."""
|
||||
|
||||
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
|
||||
top_k: Optional[int] = None
|
||||
"""Number of most likely tokens to consider at each step."""
|
||||
|
||||
top_p: Optional[float] = None
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results."""
|
||||
|
||||
default_request_timeout: Optional[float] = None
|
||||
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
|
||||
|
||||
anthropic_api_url: Optional[str] = None
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = None
|
||||
|
||||
HUMAN_PROMPT: Optional[str] = None
|
||||
AI_PROMPT: Optional[str] = None
|
||||
count_tokens: Optional[Callable[[str], int]] = None
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict) -> Dict:
|
||||
extra = values.get("model_kwargs", {})
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anthropic_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["anthropic_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anthropic_api_url",
|
||||
"ANTHROPIC_API_URL",
|
||||
default="https://api.anthropic.com",
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
check_package_version("anthropic", gte_version="0.3")
|
||||
values["client"] = anthropic.Anthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"].get_secret_value(),
|
||||
timeout=values["default_request_timeout"],
|
||||
)
|
||||
values["async_client"] = anthropic.AsyncAnthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"].get_secret_value(),
|
||||
timeout=values["default_request_timeout"],
|
||||
)
|
||||
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
||||
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
||||
values["count_tokens"] = values["client"].count_tokens
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import anthropic python package. "
|
||||
"Please it install it with `pip install anthropic`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters for calling Anthropic API."""
|
||||
d = {
|
||||
"max_tokens_to_sample": self.max_tokens_to_sample,
|
||||
"model": self.model,
|
||||
}
|
||||
if self.temperature is not None:
|
||||
d["temperature"] = self.temperature
|
||||
if self.top_k is not None:
|
||||
d["top_k"] = self.top_k
|
||||
if self.top_p is not None:
|
||||
d["top_p"] = self.top_p
|
||||
return {**d, **self.model_kwargs}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{}, **self._default_params}
|
||||
|
||||
def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
|
||||
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
|
||||
raise NameError("Please ensure the anthropic package is loaded")
|
||||
|
||||
if stop is None:
|
||||
stop = []
|
||||
|
||||
# Never want model to invent new turns of Human / Assistant dialog.
|
||||
stop.extend([self.HUMAN_PROMPT])
|
||||
|
||||
return stop
|
||||
|
||||
|
||||
class Anthropic(LLM, _AnthropicCommon):
|
||||
"""Anthropic large language models.
|
||||
|
||||
To use, you should have the ``anthropic`` python package installed, and the
|
||||
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import anthropic
|
||||
from langchain_community.llms import Anthropic
|
||||
|
||||
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
||||
|
||||
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
|
||||
# and AI_PROMPT.
|
||||
response = model("What are the biggest risks facing humanity?")
|
||||
|
||||
# Or if you want to use the chat mode, build a few-shot-prompt, or
|
||||
# put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
|
||||
raw_prompt = "What are the biggest risks facing humanity?"
|
||||
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
|
||||
response = model(prompt)
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def raise_warning(cls, values: Dict) -> Dict:
|
||||
"""Raise warning that this class is deprecated."""
|
||||
warnings.warn(
|
||||
"This Anthropic LLM is deprecated. "
|
||||
"Please use `from langchain_community.chat_models import ChatAnthropic` "
|
||||
"instead"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "anthropic-llm"
|
||||
|
||||
def _wrap_prompt(self, prompt: str) -> str:
|
||||
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
|
||||
raise NameError("Please ensure the anthropic package is loaded")
|
||||
|
||||
if prompt.startswith(self.HUMAN_PROMPT):
|
||||
return prompt # Already wrapped.
|
||||
|
||||
# Guard against common errors in specifying wrong number of newlines.
|
||||
corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
|
||||
if n_subs == 1:
|
||||
return corrected_prompt
|
||||
|
||||
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
|
||||
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
r"""Call out to Anthropic's completion endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt = "What are the biggest risks facing humanity?"
|
||||
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
|
||||
response = model(prompt)
|
||||
|
||||
"""
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
stop = self._get_anthropic_stop(stop)
|
||||
params = {**self._default_params, **kwargs}
|
||||
response = self.client.completions.create(
|
||||
prompt=self._wrap_prompt(prompt),
|
||||
stop_sequences=stop,
|
||||
**params,
|
||||
)
|
||||
return response.completion
|
||||
|
||||
def convert_prompt(self, prompt: PromptValue) -> str:
|
||||
return self._wrap_prompt(prompt.to_string())
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Anthropic's completion endpoint asynchronously."""
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
async for chunk in self._astream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
stop = self._get_anthropic_stop(stop)
|
||||
params = {**self._default_params, **kwargs}
|
||||
|
||||
response = await self.async_client.completions.create(
|
||||
prompt=self._wrap_prompt(prompt),
|
||||
stop_sequences=stop,
|
||||
**params,
|
||||
)
|
||||
return response.completion
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
r"""Call Anthropic completion_stream and return the resulting generator.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
Returns:
|
||||
A generator representing the stream of tokens from Anthropic.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt = "Write a poem about a stream."
|
||||
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
|
||||
generator = anthropic.stream(prompt)
|
||||
for token in generator:
|
||||
yield token
|
||||
"""
|
||||
stop = self._get_anthropic_stop(stop)
|
||||
params = {**self._default_params, **kwargs}
|
||||
|
||||
for token in self.client.completions.create(
|
||||
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
|
||||
):
|
||||
chunk = GenerationChunk(text=token.completion)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
r"""Call Anthropic completion_stream and return the resulting generator.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
Returns:
|
||||
A generator representing the stream of tokens from Anthropic.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
prompt = "Write a poem about a stream."
|
||||
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
|
||||
generator = anthropic.stream(prompt)
|
||||
for token in generator:
|
||||
yield token
|
||||
"""
|
||||
stop = self._get_anthropic_stop(stop)
|
||||
params = {**self._default_params, **kwargs}
|
||||
|
||||
async for token in await self.async_client.completions.create(
|
||||
prompt=self._wrap_prompt(prompt),
|
||||
stop_sequences=stop,
|
||||
stream=True,
|
||||
**params,
|
||||
):
|
||||
chunk = GenerationChunk(text=token.completion)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
if not self.count_tokens:
|
||||
raise NameError("Please ensure the anthropic package is loaded")
|
||||
return self.count_tokens(text)
|
||||
@@ -1,126 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CloudflareWorkersAI(LLM):
|
||||
"""Langchain LLM class to help to access Cloudflare Workers AI service.
|
||||
|
||||
To use, you must provide an API token and
|
||||
account ID to access Cloudflare Workers AI, and
|
||||
pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI
|
||||
|
||||
my_account_id = "my_account_id"
|
||||
my_api_token = "my_secret_api_token"
|
||||
llm_model = "@cf/meta/llama-2-7b-chat-int8"
|
||||
|
||||
cf_ai = CloudflareWorkersAI(
|
||||
account_id=my_account_id,
|
||||
api_token=my_api_token,
|
||||
model=llm_model
|
||||
)
|
||||
""" # noqa: E501
|
||||
|
||||
account_id: str
|
||||
api_token: str
|
||||
model: str = "@cf/meta/llama-2-7b-chat-int8"
|
||||
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
|
||||
streaming: bool = False
|
||||
endpoint_url: str = ""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the Cloudflare Workers AI class."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of LLM."""
|
||||
return "cloudflare"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Default parameters"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Identifying parameters"""
|
||||
return {
|
||||
"account_id": self.account_id,
|
||||
"api_token": self.api_token,
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
}
|
||||
|
||||
def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
|
||||
"""Call Cloudflare Workers API"""
|
||||
headers = {"Authorization": f"Bearer {self.api_token}"}
|
||||
data = {"prompt": prompt, "stream": self.streaming, **params}
|
||||
response = requests.post(self.endpoint_url, headers=headers, json=data)
|
||||
return response
|
||||
|
||||
def _process_response(self, response: requests.Response) -> str:
|
||||
"""Process API response"""
|
||||
if response.ok:
|
||||
data = response.json()
|
||||
return data["result"]["response"]
|
||||
else:
|
||||
raise ValueError(f"Request failed with status {response.status_code}")
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Streaming prediction"""
|
||||
original_steaming: bool = self.streaming
|
||||
self.streaming = True
|
||||
_response_prefix_count = len("data: ")
|
||||
_response_stream_end = b"data: [DONE]"
|
||||
for chunk in self._call_api(prompt, kwargs).iter_lines():
|
||||
if chunk == _response_stream_end:
|
||||
break
|
||||
if len(chunk) > _response_prefix_count:
|
||||
try:
|
||||
data = json.loads(chunk[_response_prefix_count:])
|
||||
except Exception as e:
|
||||
logger.debug(chunk)
|
||||
raise e
|
||||
if data is not None and "response" in data:
|
||||
yield GenerationChunk(text=data["response"])
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(data["response"])
|
||||
logger.debug("stream end")
|
||||
self.streaming = original_steaming
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Regular prediction"""
|
||||
if self.streaming:
|
||||
return "".join(
|
||||
[c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
|
||||
)
|
||||
else:
|
||||
response = self._call_api(prompt, kwargs)
|
||||
return self._process_response(response)
|
||||
@@ -1,106 +0,0 @@
|
||||
"""**Retriever** class returns Documents given a text **query**.
|
||||
|
||||
It is more general than a vector store. A retriever does not need to be able to
|
||||
store documents, only to return (or retrieve) it. Vector stores can be used as
|
||||
the backbone of a retriever, but there are other types of retrievers as well.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseRetriever --> <name>Retriever # Examples: ArxivRetriever, MergerRetriever
|
||||
|
||||
**Main helpers:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
Document, Serializable, Callbacks,
|
||||
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
|
||||
"""
|
||||
|
||||
from langchain_community.retrievers.arcee import ArceeRetriever
|
||||
from langchain_community.retrievers.arxiv import ArxivRetriever
|
||||
from langchain_community.retrievers.azure_cognitive_search import (
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.bedrock import AmazonKnowledgeBasesRetriever
|
||||
from langchain_community.retrievers.bm25 import BM25Retriever
|
||||
from langchain_community.retrievers.chaindesk import ChaindeskRetriever
|
||||
from langchain_community.retrievers.chatgpt_plugin_retriever import (
|
||||
ChatGPTPluginRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.cohere_rag_retriever import CohereRagRetriever
|
||||
from langchain_community.retrievers.docarray import DocArrayRetriever
|
||||
from langchain_community.retrievers.elastic_search_bm25 import (
|
||||
ElasticSearchBM25Retriever,
|
||||
)
|
||||
from langchain_community.retrievers.embedchain import EmbedchainRetriever
|
||||
from langchain_community.retrievers.google_cloud_documentai_warehouse import (
|
||||
GoogleDocumentAIWarehouseRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.google_vertex_ai_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.kay import KayAiRetriever
|
||||
from langchain_community.retrievers.kendra import AmazonKendraRetriever
|
||||
from langchain_community.retrievers.knn import KNNRetriever
|
||||
from langchain_community.retrievers.llama_index import (
|
||||
LlamaIndexGraphRetriever,
|
||||
LlamaIndexRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.metal import MetalRetriever
|
||||
from langchain_community.retrievers.milvus import MilvusRetriever
|
||||
from langchain_community.retrievers.outline import OutlineRetriever
|
||||
from langchain_community.retrievers.pinecone_hybrid_search import (
|
||||
PineconeHybridSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.pubmed import PubMedRetriever
|
||||
from langchain_community.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||
from langchain_community.retrievers.svm import SVMRetriever
|
||||
from langchain_community.retrievers.tavily_search_api import TavilySearchAPIRetriever
|
||||
from langchain_community.retrievers.tfidf import TFIDFRetriever
|
||||
from langchain_community.retrievers.weaviate_hybrid_search import (
|
||||
WeaviateHybridSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.wikipedia import WikipediaRetriever
|
||||
from langchain_community.retrievers.zep import ZepRetriever
|
||||
from langchain_community.retrievers.zilliz import ZillizRetriever
|
||||
|
||||
__all__ = [
|
||||
"AmazonKendraRetriever",
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ChaindeskRetriever",
|
||||
"CohereRagRetriever",
|
||||
"ElasticSearchBM25Retriever",
|
||||
"EmbedchainRetriever",
|
||||
"GoogleDocumentAIWarehouseRetriever",
|
||||
"GoogleCloudEnterpriseSearchRetriever",
|
||||
"GoogleVertexAIMultiTurnSearchRetriever",
|
||||
"GoogleVertexAISearchRetriever",
|
||||
"KayAiRetriever",
|
||||
"KNNRetriever",
|
||||
"LlamaIndexGraphRetriever",
|
||||
"LlamaIndexRetriever",
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"OutlineRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"PubMedRetriever",
|
||||
"RemoteLangChainRetriever",
|
||||
"SVMRetriever",
|
||||
"TavilySearchAPIRetriever",
|
||||
"TFIDFRetriever",
|
||||
"BM25Retriever",
|
||||
"VespaRetriever",
|
||||
"WeaviateHybridSearchRetriever",
|
||||
"WikipediaRetriever",
|
||||
"ZepRetriever",
|
||||
"ZillizRetriever",
|
||||
"DocArrayRetriever",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Implementations of key-value stores and storage helpers.
|
||||
|
||||
Module provides implementations of various key-value stores that conform
|
||||
to a simple key-value interface.
|
||||
|
||||
The primary goal of these storages is to support implementation of caching.
|
||||
"""
|
||||
|
||||
from langchain_community.storage.redis import RedisStore
|
||||
from langchain_community.storage.upstash_redis import (
|
||||
UpstashRedisByteStore,
|
||||
UpstashRedisStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RedisStore",
|
||||
"UpstashRedisByteStore",
|
||||
"UpstashRedisStore",
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_community.tools.amadeus.base import AmadeusBaseTool
|
||||
|
||||
|
||||
class ClosestAirportSchema(BaseModel):
|
||||
"""Schema for the AmadeusClosestAirport tool."""
|
||||
|
||||
location: str = Field(
|
||||
description=(
|
||||
" The location for which you would like to find the nearest airport "
|
||||
" along with optional details such as country, state, region, or "
|
||||
" province, allowing for easy processing and identification of "
|
||||
" the closest airport. Examples of the format are the following:\n"
|
||||
" Cali, Colombia\n "
|
||||
" Lincoln, Nebraska, United States\n"
|
||||
" New York, United States\n"
|
||||
" Sydney, New South Wales, Australia\n"
|
||||
" Rome, Lazio, Italy\n"
|
||||
" Toronto, Ontario, Canada\n"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AmadeusClosestAirport(AmadeusBaseTool):
|
||||
"""Tool for finding the closest airport to a particular location."""
|
||||
|
||||
name: str = "closest_airport"
|
||||
description: str = (
|
||||
"Use this tool to find the closest airport to a particular location."
|
||||
)
|
||||
args_schema: Type[ClosestAirportSchema] = ClosestAirportSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
location: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
content = (
|
||||
f" What is the nearest airport to {location}? Please respond with the "
|
||||
" airport's International Air Transport Association (IATA) Location "
|
||||
' Identifier in the following JSON format. JSON: "iataCode": "IATA '
|
||||
' Location Identifier" '
|
||||
)
|
||||
|
||||
return ChatOpenAI(temperature=0).predict(content)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
This tool allows agents to interact with the clickup library
|
||||
and operate on a Clickup instance.
|
||||
To use this tool, you must first set as environment variables:
|
||||
client_secret
|
||||
client_id
|
||||
code
|
||||
|
||||
Below is a sample script that uses the Clickup tool:
|
||||
|
||||
```python
|
||||
from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit
|
||||
from langchain_community.utilities.clickup import ClickupAPIWrapper
|
||||
|
||||
clickup = ClickupAPIWrapper()
|
||||
toolkit = ClickupToolkit.from_clickup_api_wrapper(clickup)
|
||||
```
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utilities.clickup import ClickupAPIWrapper
|
||||
|
||||
|
||||
class ClickupAction(BaseTool):
|
||||
"""Tool that queries the Clickup API."""
|
||||
|
||||
api_wrapper: ClickupAPIWrapper = Field(default_factory=ClickupAPIWrapper)
|
||||
mode: str
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
instructions: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Clickup API to run an operation."""
|
||||
return self.api_wrapper.run(self.mode, instructions)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""
|
||||
This tool allows agents to interact with the atlassian-python-api library
|
||||
and operate on a Jira instance. For more information on the
|
||||
atlassian-python-api library, see https://atlassian-python-api.readthedocs.io/jira.html
|
||||
|
||||
To use this tool, you must first set as environment variables:
|
||||
JIRA_API_TOKEN
|
||||
JIRA_USERNAME
|
||||
JIRA_INSTANCE_URL
|
||||
|
||||
Below is a sample script that uses the Jira tool:
|
||||
|
||||
```python
|
||||
from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
|
||||
from langchain_community.utilities.jira import JiraAPIWrapper
|
||||
|
||||
jira = JiraAPIWrapper()
|
||||
toolkit = JiraToolkit.from_jira_api_wrapper(jira)
|
||||
```
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utilities.jira import JiraAPIWrapper
|
||||
|
||||
|
||||
class JiraAction(BaseTool):
|
||||
"""Tool that queries the Atlassian Jira API."""
|
||||
|
||||
api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper)
|
||||
mode: str
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
instructions: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Atlassian Jira API to run an operation."""
|
||||
return self.api_wrapper.run(self.mode, instructions)
|
||||
@@ -1,276 +0,0 @@
|
||||
"""Tools for interacting with a Power BI dataset."""
|
||||
import logging
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, validator
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_community.chat_models.openai import _import_tiktoken
|
||||
|
||||
from langchain_community.tools.powerbi.prompt import (
|
||||
BAD_REQUEST_RESPONSE,
|
||||
DEFAULT_FEWSHOT_EXAMPLES,
|
||||
RETRY_RESPONSE,
|
||||
)
|
||||
from langchain_community.utilities.powerbi import PowerBIDataset, json_to_md
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryPowerBITool(BaseTool):
|
||||
"""Tool for querying a Power BI Dataset."""
|
||||
|
||||
name: str = "query_powerbi"
|
||||
description: str = """
|
||||
Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification.
|
||||
|
||||
Example Input: "How many rows are in table1?"
|
||||
""" # noqa: E501
|
||||
llm_chain: Any
|
||||
powerbi: PowerBIDataset = Field(exclude=True)
|
||||
examples: Optional[str] = DEFAULT_FEWSHOT_EXAMPLES
|
||||
session_cache: Dict[str, Any] = Field(default_factory=dict, exclude=True)
|
||||
max_iterations: int = 5
|
||||
output_token_limit: int = 4000
|
||||
tiktoken_model_name: Optional[str] = None # "cl100k_base"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("llm_chain")
|
||||
def validate_llm_chain_input_variables( # pylint: disable=E0213
|
||||
cls, llm_chain: Any
|
||||
) -> Any:
|
||||
"""Make sure the LLM chain has the correct input variables."""
|
||||
for var in llm_chain.prompt.input_variables:
|
||||
if var not in ["tool_input", "tables", "schemas", "examples"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: C0301 E501 # pylint: disable=C0301
|
||||
llm_chain.prompt.input_variables,
|
||||
)
|
||||
return llm_chain
|
||||
|
||||
def _check_cache(self, tool_input: str) -> Optional[str]:
|
||||
"""Check if the input is present in the cache.
|
||||
|
||||
If the value is a bad request, overwrite with the escalated version,
|
||||
if not present return None."""
|
||||
if tool_input not in self.session_cache:
|
||||
return None
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
if cache := self._check_cache(tool_input):
|
||||
logger.debug("Found cached result for %s: %s", tool_input, cache)
|
||||
return cache
|
||||
|
||||
try:
|
||||
logger.info("Running PBI Query Tool with input: %s", tool_input)
|
||||
query = self.llm_chain.predict(
|
||||
tool_input=tool_input,
|
||||
tables=self.powerbi.get_table_names(),
|
||||
schemas=self.powerbi.get_schemas(),
|
||||
examples=self.examples,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
|
||||
return self.session_cache[tool_input]
|
||||
if query == "I cannot answer this":
|
||||
self.session_cache[tool_input] = query
|
||||
return self.session_cache[tool_input]
|
||||
logger.info("PBI Query:\n%s", query)
|
||||
start_time = perf_counter()
|
||||
pbi_result = self.powerbi.run(command=query)
|
||||
end_time = perf_counter()
|
||||
logger.debug("PBI Result: %s", pbi_result)
|
||||
logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}")
|
||||
result, error = self._parse_output(pbi_result)
|
||||
if error is not None and "TokenExpired" in error:
|
||||
self.session_cache[
|
||||
tool_input
|
||||
] = "Authentication token expired or invalid, please try reauthenticate."
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
iterations = kwargs.get("iterations", 0)
|
||||
if error and iterations < self.max_iterations:
|
||||
return self._run(
|
||||
tool_input=RETRY_RESPONSE.format(
|
||||
tool_input=tool_input, query=query, error=error
|
||||
),
|
||||
run_manager=run_manager,
|
||||
iterations=iterations + 1,
|
||||
)
|
||||
|
||||
self.session_cache[tool_input] = (
|
||||
result if result else BAD_REQUEST_RESPONSE.format(error=error)
|
||||
)
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
if cache := self._check_cache(tool_input):
|
||||
logger.debug("Found cached result for %s: %s", tool_input, cache)
|
||||
return f"{cache}, from cache, you have already asked this question."
|
||||
try:
|
||||
logger.info("Running PBI Query Tool with input: %s", tool_input)
|
||||
query = await self.llm_chain.apredict(
|
||||
tool_input=tool_input,
|
||||
tables=self.powerbi.get_table_names(),
|
||||
schemas=self.powerbi.get_schemas(),
|
||||
examples=self.examples,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
if query == "I cannot answer this":
|
||||
self.session_cache[tool_input] = query
|
||||
return self.session_cache[tool_input]
|
||||
logger.info("PBI Query: %s", query)
|
||||
start_time = perf_counter()
|
||||
pbi_result = await self.powerbi.arun(command=query)
|
||||
end_time = perf_counter()
|
||||
logger.debug("PBI Result: %s", pbi_result)
|
||||
logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}")
|
||||
result, error = self._parse_output(pbi_result)
|
||||
if error is not None and ("TokenExpired" in error or "TokenError" in error):
|
||||
self.session_cache[
|
||||
tool_input
|
||||
] = "Authentication token expired or invalid, please try to reauthenticate or check the scope of the credential." # noqa: E501
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
iterations = kwargs.get("iterations", 0)
|
||||
if error and iterations < self.max_iterations:
|
||||
return await self._arun(
|
||||
tool_input=RETRY_RESPONSE.format(
|
||||
tool_input=tool_input, query=query, error=error
|
||||
),
|
||||
run_manager=run_manager,
|
||||
iterations=iterations + 1,
|
||||
)
|
||||
|
||||
self.session_cache[tool_input] = (
|
||||
result if result else BAD_REQUEST_RESPONSE.format(error=error)
|
||||
)
|
||||
return self.session_cache[tool_input]
|
||||
|
||||
def _parse_output(
|
||||
self, pbi_result: Dict[str, Any]
|
||||
) -> Tuple[Optional[str], Optional[Any]]:
|
||||
"""Parse the output of the query to a markdown table."""
|
||||
if "results" in pbi_result:
|
||||
rows = pbi_result["results"][0]["tables"][0]["rows"]
|
||||
if len(rows) == 0:
|
||||
logger.info("0 records in result, query was valid.")
|
||||
return (
|
||||
None,
|
||||
"0 rows returned, this might be correct, but please validate if all filter values were correct?", # noqa: E501
|
||||
)
|
||||
result = json_to_md(rows)
|
||||
too_long, length = self._result_too_large(result)
|
||||
if too_long:
|
||||
return (
|
||||
f"Result too large, please try to be more specific or use the `TOPN` function. The result is {length} tokens long, the limit is {self.output_token_limit} tokens.", # noqa: E501
|
||||
None,
|
||||
)
|
||||
return result, None
|
||||
|
||||
if "error" in pbi_result:
|
||||
if (
|
||||
"pbi.error" in pbi_result["error"]
|
||||
and "details" in pbi_result["error"]["pbi.error"]
|
||||
):
|
||||
return None, pbi_result["error"]["pbi.error"]["details"][0]["detail"]
|
||||
return None, pbi_result["error"]
|
||||
return None, pbi_result
|
||||
|
||||
def _result_too_large(self, result: str) -> Tuple[bool, int]:
|
||||
"""Tokenize the output of the query."""
|
||||
if self.tiktoken_model_name:
|
||||
tiktoken_ = _import_tiktoken()
|
||||
encoding = tiktoken_.encoding_for_model(self.tiktoken_model_name)
|
||||
length = len(encoding.encode(result))
|
||||
logger.info("Result length: %s", length)
|
||||
return length > self.output_token_limit, length
|
||||
return False, 0
|
||||
|
||||
|
||||
class InfoPowerBITool(BaseTool):
|
||||
"""Tool for getting metadata about a PowerBI Dataset."""
|
||||
|
||||
name: str = "schema_powerbi"
|
||||
description: str = """
|
||||
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
|
||||
Be sure that the tables actually exist by calling list_tables_powerbi first!
|
||||
|
||||
Example Input: "table1, table2, table3"
|
||||
""" # noqa: E501
|
||||
powerbi: PowerBIDataset = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for tables in a comma-separated list."""
|
||||
return self.powerbi.get_table_info(tool_input.split(", "))
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
return await self.powerbi.aget_table_info(tool_input.split(", "))
|
||||
|
||||
|
||||
class ListPowerBITool(BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name: str = "list_tables_powerbi"
|
||||
description: str = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301
|
||||
powerbi: PowerBIDataset = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: Optional[str] = None,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: Optional[str] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
@@ -1,130 +0,0 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with Spark SQL."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_community.utilities.spark_sql import SparkSQL
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_community.tools.spark_sql.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class BaseSparkSQLTool(BaseModel):
|
||||
"""Base tool for interacting with Spark SQL."""
|
||||
|
||||
db: SparkSQL = Field(exclude=True)
|
||||
|
||||
class Config(BaseTool.Config):
|
||||
pass
|
||||
|
||||
|
||||
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||
"""Tool for querying a Spark SQL."""
|
||||
|
||||
name: str = "query_sql_db"
|
||||
description: str = """
|
||||
Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned, rewrite the query, check the query, and try again.
|
||||
"""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
return self.db.run_no_throw(query)
|
||||
|
||||
|
||||
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||
"""Tool for getting metadata about a Spark SQL."""
|
||||
|
||||
name: str = "schema_sql_db"
|
||||
description: str = """
|
||||
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
|
||||
Be sure that the tables actually exist by calling list_tables_sql_db first!
|
||||
|
||||
Example Input: "table1, table2, table3"
|
||||
"""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
table_names: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for tables in a comma-separated list."""
|
||||
return self.db.get_table_info_no_throw(table_names.split(", "))
|
||||
|
||||
|
||||
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name: str = "list_tables_sql_db"
|
||||
description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for a specific table."""
|
||||
return ", ".join(self.db.get_usable_table_names())
|
||||
|
||||
|
||||
class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: Any = Field(init=False)
|
||||
name: str = "query_checker_sql_db"
|
||||
description: str = """
|
||||
Use this tool to double check if your query is correct before executing it.
|
||||
Always use this tool before executing a query with query_sql_db!
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
from langchain.chains.llm import LLMChain
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["query"]
|
||||
),
|
||||
)
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["query"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool need to use ['query'] as input_variables "
|
||||
"for the embedded prompt"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the LLM to check the query."""
|
||||
return self.llm_chain.predict(
|
||||
query=query, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
return await self.llm_chain.apredict(
|
||||
query=query, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
@@ -1,134 +0,0 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
"""Base tool for interacting with a SQL database."""
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
|
||||
class Config(BaseTool.Config):
|
||||
pass
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for querying a SQL database."""
|
||||
|
||||
name: str = "sql_db_query"
|
||||
description: str = """
|
||||
Input to this tool is a detailed and correct SQL query, output is a result from the database.
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned, rewrite the query, check the query, and try again.
|
||||
"""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
return self.db.run_no_throw(query)
|
||||
|
||||
|
||||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting metadata about a SQL database."""
|
||||
|
||||
name: str = "sql_db_schema"
|
||||
description: str = """
|
||||
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
|
||||
|
||||
Example Input: "table1, table2, table3"
|
||||
"""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
table_names: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for tables in a comma-separated list."""
|
||||
return self.db.get_table_info_no_throw(
|
||||
[t.strip() for t in table_names.split(",")]
|
||||
)
|
||||
|
||||
|
||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name: str = "sql_db_list_tables"
|
||||
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for a specific table."""
|
||||
return ", ".join(self.db.get_usable_table_names())
|
||||
|
||||
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: Any = Field(init=False)
|
||||
name: str = "sql_db_query_checker"
|
||||
description: str = """
|
||||
Use this tool to double check if your query is correct before executing it.
|
||||
Always use this tool before executing a query with sql_db_query!
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
from langchain.chains.llm import LLMChain
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
),
|
||||
)
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the LLM to check the query."""
|
||||
return self.llm_chain.predict(
|
||||
query=query,
|
||||
dialect=self.db.dialect,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
return await self.llm_chain.apredict(
|
||||
query=query,
|
||||
dialect=self.db.dialect,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
@@ -1,215 +0,0 @@
|
||||
"""[DEPRECATED]
|
||||
|
||||
## Zapier Natural Language Actions API
|
||||
\
|
||||
Full docs here: https://nla.zapier.com/start/
|
||||
|
||||
**Zapier Natural Language Actions** gives you access to the 5k+ apps, 20k+ actions
|
||||
on Zapier's platform through a natural language API interface.
|
||||
|
||||
NLA supports apps like Gmail, Salesforce, Trello, Slack, Asana, HubSpot, Google Sheets,
|
||||
Microsoft Teams, and thousands more apps: https://zapier.com/apps
|
||||
|
||||
Zapier NLA handles ALL the underlying API auth and translation from
|
||||
natural language --> underlying API call --> return simplified output for LLMs
|
||||
The key idea is you, or your users, expose a set of actions via an oauth-like setup
|
||||
window, which you can then query and execute via a REST API.
|
||||
|
||||
NLA offers both API Key and OAuth for signing NLA API requests.
|
||||
|
||||
1. Server-side (API Key): for quickly getting started, testing, and production scenarios
|
||||
where LangChain will only use actions exposed in the developer's Zapier account
|
||||
(and will use the developer's connected accounts on Zapier.com)
|
||||
|
||||
2. User-facing (Oauth): for production scenarios where you are deploying an end-user
|
||||
facing application and LangChain needs access to end-user's exposed actions and
|
||||
connected accounts on Zapier.com
|
||||
|
||||
This quick start will focus on the server-side use case for brevity.
|
||||
Review [full docs](https://nla.zapier.com/start/) for user-facing oauth developer
|
||||
support.
|
||||
|
||||
Typically, you'd use SequentialChain, here's a basic example:
|
||||
|
||||
1. Use NLA to find an email in Gmail
|
||||
2. Use LLMChain to generate a draft reply to (1)
|
||||
3. Use NLA to send the draft reply (2) to someone in Slack via direct message
|
||||
|
||||
In code, below:
|
||||
|
||||
```python
|
||||
|
||||
import os
|
||||
|
||||
# get from https://platform.openai.com/
|
||||
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "")
|
||||
|
||||
# get from https://nla.zapier.com/docs/authentication/
|
||||
os.environ["ZAPIER_NLA_API_KEY"] = os.environ.get("ZAPIER_NLA_API_KEY", "")
|
||||
|
||||
from langchain_community.agent_toolkits import ZapierToolkit
|
||||
from langchain_community.utilities.zapier import ZapierNLAWrapper
|
||||
|
||||
## step 0. expose gmail 'find email' and slack 'send channel message' actions
|
||||
|
||||
# first go here, log in, expose (enable) the two actions:
|
||||
# https://nla.zapier.com/demo/start
|
||||
# -- for this example, can leave all fields "Have AI guess"
|
||||
# in an oauth scenario, you'd get your own <provider> id (instead of 'demo')
|
||||
# which you route your users through first
|
||||
|
||||
zapier = ZapierNLAWrapper()
|
||||
## To leverage OAuth you may pass the value `nla_oauth_access_token` to
|
||||
## the ZapierNLAWrapper. If you do this there is no need to initialize
|
||||
## the ZAPIER_NLA_API_KEY env variable
|
||||
# zapier = ZapierNLAWrapper(zapier_nla_oauth_access_token="TOKEN_HERE")
|
||||
toolkit = ZapierToolkit.from_zapier_nla_wrapper(zapier)
|
||||
```
|
||||
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
|
||||
from langchain_community.utilities.zapier import ZapierNLAWrapper
|
||||
|
||||
|
||||
class ZapierNLARunAction(BaseTool):
|
||||
"""
|
||||
Args:
|
||||
action_id: a specific action ID (from list actions) of the action to execute
|
||||
(the set api_key must be associated with the action owner)
|
||||
instructions: a natural language instruction string for using the action
|
||||
(eg. "get the latest email from Mike Knoop" for "Gmail: find email" action)
|
||||
params: a dict, optional. Any params provided will *override* AI guesses
|
||||
from `instructions` (see "understanding the AI guessing flow" here:
|
||||
https://nla.zapier.com/docs/using-the-api#ai-guessing)
|
||||
|
||||
"""
|
||||
|
||||
api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
|
||||
action_id: str
|
||||
params: Optional[dict] = None
|
||||
base_prompt: str = BASE_ZAPIER_TOOL_PROMPT
|
||||
zapier_description: str
|
||||
params_schema: Dict[str, str] = Field(default_factory=dict)
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
|
||||
@root_validator
|
||||
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
zapier_description = values["zapier_description"]
|
||||
params_schema = values["params_schema"]
|
||||
if "instructions" in params_schema:
|
||||
del params_schema["instructions"]
|
||||
|
||||
# Ensure base prompt (if overridden) contains necessary input fields
|
||||
necessary_fields = {"{zapier_description}", "{params}"}
|
||||
if not all(field in values["base_prompt"] for field in necessary_fields):
|
||||
raise ValueError(
|
||||
"Your custom base Zapier prompt must contain input fields for "
|
||||
"{zapier_description} and {params}."
|
||||
)
|
||||
|
||||
values["name"] = zapier_description
|
||||
values["description"] = values["base_prompt"].format(
|
||||
zapier_description=zapier_description,
|
||||
params=str(list(params_schema.keys())),
|
||||
)
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self, instructions: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
"""Use the Zapier NLA tool to return a list of all exposed user actions."""
|
||||
warn_deprecated(
|
||||
since="0.0.319",
|
||||
message=(
|
||||
"This tool will be deprecated on 2023-11-17. See "
|
||||
"https://nla.zapier.com/sunset/ for details"
|
||||
),
|
||||
)
|
||||
return self.api_wrapper.run_as_str(self.action_id, instructions, self.params)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
instructions: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Zapier NLA tool to return a list of all exposed user actions."""
|
||||
warn_deprecated(
|
||||
since="0.0.319",
|
||||
message=(
|
||||
"This tool will be deprecated on 2023-11-17. See "
|
||||
"https://nla.zapier.com/sunset/ for details"
|
||||
),
|
||||
)
|
||||
return await self.api_wrapper.arun_as_str(
|
||||
self.action_id,
|
||||
instructions,
|
||||
self.params,
|
||||
)
|
||||
|
||||
|
||||
ZapierNLARunAction.__doc__ = (
|
||||
ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore
|
||||
)
|
||||
|
||||
|
||||
# other useful actions
|
||||
|
||||
|
||||
class ZapierNLAListActions(BaseTool):
|
||||
"""
|
||||
Args:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
name: str = "ZapierNLA_list_actions"
|
||||
description: str = BASE_ZAPIER_TOOL_PROMPT + (
|
||||
"This tool returns a list of the user's exposed actions."
|
||||
)
|
||||
api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
_: str = "",
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Zapier NLA tool to return a list of all exposed user actions."""
|
||||
warn_deprecated(
|
||||
since="0.0.319",
|
||||
message=(
|
||||
"This tool will be deprecated on 2023-11-17. See "
|
||||
"https://nla.zapier.com/sunset/ for details"
|
||||
),
|
||||
)
|
||||
return self.api_wrapper.list_as_str()
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
_: str = "",
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Zapier NLA tool to return a list of all exposed user actions."""
|
||||
warn_deprecated(
|
||||
since="0.0.319",
|
||||
message=(
|
||||
"This tool will be deprecated on 2023-11-17. See "
|
||||
"https://nla.zapier.com/sunset/ for details"
|
||||
),
|
||||
)
|
||||
return await self.api_wrapper.alist_as_str()
|
||||
|
||||
|
||||
ZapierNLAListActions.__doc__ = (
|
||||
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore
|
||||
)
|
||||
@@ -1,283 +0,0 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group
|
||||
from langchain_core.tracers.context import tracing_v2_enabled, tracing_enabled
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_community.llms import OpenAI
|
||||
|
||||
questions = [
|
||||
(
|
||||
"Who won the US Open men's final in 2019? "
|
||||
"What is his age raised to the 0.334 power?"
|
||||
),
|
||||
(
|
||||
"Who is Olivia Wilde's boyfriend? "
|
||||
"What is his current age raised to the 0.23 power?"
|
||||
),
|
||||
(
|
||||
"Who won the most recent formula 1 grand prix? "
|
||||
"What is their age raised to the 0.23 power?"
|
||||
),
|
||||
(
|
||||
"Who won the US Open women's final in 2019? "
|
||||
"What is her age raised to the 0.34 power?"
|
||||
),
|
||||
("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
|
||||
]
|
||||
|
||||
|
||||
def test_tracing_sequential() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
|
||||
for q in questions[:3]:
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(q)
|
||||
|
||||
|
||||
def test_tracing_session_env_var() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
os.environ["LANGCHAIN_SESSION"] = "my_session"
|
||||
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(questions[0])
|
||||
if "LANGCHAIN_SESSION" in os.environ:
|
||||
del os.environ["LANGCHAIN_SESSION"]
|
||||
|
||||
|
||||
async def test_tracing_concurrent() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
aiosession = ClientSession()
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
tasks = [agent.arun(q) for q in questions[:3]]
|
||||
await asyncio.gather(*tasks)
|
||||
await aiosession.close()
|
||||
|
||||
|
||||
async def test_tracing_concurrent_bw_compat_environ() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
if "LANGCHAIN_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
aiosession = ClientSession()
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
tasks = [agent.arun(q) for q in questions[:3]]
|
||||
await asyncio.gather(*tasks)
|
||||
await aiosession.close()
|
||||
if "LANGCHAIN_HANDLER" in os.environ:
|
||||
del os.environ["LANGCHAIN_HANDLER"]
|
||||
|
||||
|
||||
def test_tracing_context_manager() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
with tracing_enabled() as session:
|
||||
assert session
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
||||
|
||||
async def test_tracing_context_manager_async() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
|
||||
# start a background task
|
||||
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
|
||||
with tracing_enabled() as session:
|
||||
assert session
|
||||
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
await task
|
||||
|
||||
|
||||
async def test_tracing_v2_environment_variable() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
|
||||
aiosession = ClientSession()
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
tasks = [agent.arun(q) for q in questions[:3]]
|
||||
await asyncio.gather(*tasks)
|
||||
await aiosession.close()
|
||||
|
||||
|
||||
def test_tracing_v2_context_manager() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||
with tracing_v2_enabled():
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
||||
|
||||
def test_tracing_v2_chain_with_tags() -> None:
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.constitutional_ai.base import ConstitutionalChain
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
llm = OpenAI(temperature=0)
|
||||
chain = ConstitutionalChain.from_llm(
|
||||
llm,
|
||||
chain=LLMChain.from_string(llm, "Q: {question} A:"),
|
||||
tags=["only-root"],
|
||||
constitutional_principles=[
|
||||
ConstitutionalPrinciple(
|
||||
critique_request="Tell if this answer is good.",
|
||||
revision_request="Give a better answer.",
|
||||
)
|
||||
],
|
||||
)
|
||||
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||
with tracing_v2_enabled():
|
||||
chain.run("what is the meaning of life", tags=["a-tag"])
|
||||
|
||||
|
||||
def test_tracing_v2_agent_with_metadata() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0)
|
||||
chat = ChatOpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
chat_agent = initialize_agent(
|
||||
tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
chat_agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
|
||||
|
||||
async def test_tracing_v2_async_agent_with_metadata() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"})
|
||||
chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"})
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
chat_agent = initialize_agent(
|
||||
async_tools,
|
||||
chat,
|
||||
agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
)
|
||||
await agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
await chat_agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
|
||||
|
||||
def test_trace_as_group() -> None:
|
||||
from langchain.chains.llm import LLMChain
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager:
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
|
||||
def test_trace_as_group_with_env_set() -> None:
|
||||
from langchain.chains.llm import LLMChain
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group(
|
||||
"my_group_env_set", inputs={"input": "cars"}
|
||||
) as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
with trace_as_chain_group(
|
||||
"my_group_2_env_set", inputs={"input": "toys"}
|
||||
) as group_manager:
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
|
||||
async def test_trace_as_group_async() -> None:
|
||||
from langchain.chains.llm import LLMChain
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
async with atrace_as_chain_group("my_async_group") as group_manager:
|
||||
await chain.arun(product="cars", callbacks=group_manager)
|
||||
await chain.arun(product="computers", callbacks=group_manager)
|
||||
await chain.arun(product="toys", callbacks=group_manager)
|
||||
|
||||
async with atrace_as_chain_group(
|
||||
"my_async_group_2", inputs={"input": "toys"}
|
||||
) as group_manager:
|
||||
res = await asyncio.gather(
|
||||
*[
|
||||
chain.arun(product="toys", callbacks=group_manager),
|
||||
chain.arun(product="computers", callbacks=group_manager),
|
||||
chain.arun(product="cars", callbacks=group_manager),
|
||||
]
|
||||
)
|
||||
await group_manager.on_chain_end({"output": res})
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
|
||||
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_community.llms import OpenAI
|
||||
|
||||
|
||||
async def test_openai_callback() -> None:
|
||||
llm = OpenAI(temperature=0)
|
||||
with get_openai_callback() as cb:
|
||||
llm("What is the square root of 4?")
|
||||
|
||||
total_tokens = cb.total_tokens
|
||||
assert total_tokens > 0
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
llm("What is the square root of 4?")
|
||||
llm("What is the square root of 4?")
|
||||
|
||||
assert cb.total_tokens == total_tokens * 2
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
await asyncio.gather(
|
||||
*[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)]
|
||||
)
|
||||
|
||||
assert cb.total_tokens == total_tokens * 3
|
||||
|
||||
task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"]))
|
||||
with get_openai_callback() as cb:
|
||||
await llm.agenerate(["What is the square root of 4?"])
|
||||
|
||||
await task
|
||||
assert cb.total_tokens == total_tokens
|
||||
|
||||
|
||||
def test_openai_callback_batch_llm() -> None:
|
||||
llm = OpenAI(temperature=0)
|
||||
with get_openai_callback() as cb:
|
||||
llm.generate(["What is the square root of 4?", "What is the square root of 4?"])
|
||||
|
||||
assert cb.total_tokens > 0
|
||||
total_tokens = cb.total_tokens
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
llm("What is the square root of 4?")
|
||||
llm("What is the square root of 4?")
|
||||
|
||||
assert cb.total_tokens == total_tokens
|
||||
|
||||
|
||||
def test_openai_callback_agent() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
with get_openai_callback() as cb:
|
||||
agent.run(
|
||||
"Who is Olivia Wilde's boyfriend? "
|
||||
"What is his current age raised to the 0.23 power?"
|
||||
)
|
||||
print(f"Total Tokens: {cb.total_tokens}")
|
||||
print(f"Prompt Tokens: {cb.prompt_tokens}")
|
||||
print(f"Completion Tokens: {cb.completion_tokens}")
|
||||
print(f"Total Cost (USD): ${cb.total_cost}")
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Integration tests for the StreamlitCallbackHandler module."""
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the internal StreamlitCallbackHandler from its module - and not from
|
||||
# the `langchain_community.callbacks.streamlit` package - so that we don't end up using
|
||||
# Streamlit's externally-provided callback handler.
|
||||
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
|
||||
StreamlitCallbackHandler,
|
||||
)
|
||||
from langchain_community.llms import OpenAI
|
||||
|
||||
|
||||
@pytest.mark.requires("streamlit")
|
||||
def test_streamlit_callback_agent() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
import streamlit as st
|
||||
|
||||
streamlit_callback = StreamlitCallbackHandler(st.container())
|
||||
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(
|
||||
"Who is Olivia Wilde's boyfriend? "
|
||||
"What is his current age raised to the 0.23 power?",
|
||||
callbacks=[streamlit_callback],
|
||||
)
|
||||
@@ -1,118 +0,0 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_community.callbacks import wandb_tracing_enabled
|
||||
|
||||
from langchain_community.llms import OpenAI
|
||||
|
||||
questions = [
|
||||
(
|
||||
"Who won the US Open men's final in 2019? "
|
||||
"What is his age raised to the 0.334 power?"
|
||||
),
|
||||
(
|
||||
"Who is Olivia Wilde's boyfriend? "
|
||||
"What is his current age raised to the 0.23 power?"
|
||||
),
|
||||
(
|
||||
"Who won the most recent formula 1 grand prix? "
|
||||
"What is their age raised to the 0.23 power?"
|
||||
),
|
||||
(
|
||||
"Who won the US Open women's final in 2019? "
|
||||
"What is her age raised to the 0.34 power?"
|
||||
),
|
||||
("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
|
||||
]
|
||||
|
||||
|
||||
def test_tracing_sequential() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||
os.environ["WANDB_PROJECT"] = "langchain-tracing"
|
||||
|
||||
for q in questions[:3]:
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(
|
||||
["llm-math", "serpapi"],
|
||||
llm=llm,
|
||||
)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(q)
|
||||
|
||||
|
||||
def test_tracing_session_env_var() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(
|
||||
["llm-math", "serpapi"],
|
||||
llm=llm,
|
||||
)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(questions[0])
|
||||
|
||||
|
||||
async def test_tracing_concurrent() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||
aiosession = ClientSession()
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(
|
||||
["llm-math", "serpapi"],
|
||||
llm=llm,
|
||||
aiosession=aiosession,
|
||||
)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
tasks = [agent.arun(q) for q in questions[:3]]
|
||||
await asyncio.gather(*tasks)
|
||||
await aiosession.close()
|
||||
|
||||
|
||||
def test_tracing_context_manager() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(
|
||||
["llm-math", "serpapi"],
|
||||
llm=llm,
|
||||
)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_WANDB_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_WANDB_TRACING"]
|
||||
with wandb_tracing_enabled():
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
||||
|
||||
async def test_tracing_context_manager_async() -> None:
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(
|
||||
["llm-math", "serpapi"],
|
||||
llm=llm,
|
||||
)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_WANDB_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
|
||||
# start a background task
|
||||
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
|
||||
with wandb_tracing_enabled():
|
||||
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
await task
|
||||
@@ -1,333 +0,0 @@
|
||||
"""Test ChatOpenAI wrapper."""
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_community.chat_models.openai import ChatOpenAI
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai() -> None:
|
||||
"""Test ChatOpenAI wrapper."""
|
||||
chat = ChatOpenAI(
|
||||
temperature=0.7,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
openai_proxy=None,
|
||||
timeout=10.0,
|
||||
max_retries=3,
|
||||
http_client=None,
|
||||
n=1,
|
||||
max_tokens=10,
|
||||
default_headers=None,
|
||||
default_query=None,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_openai_model() -> None:
|
||||
"""Test ChatOpenAI wrapper handles model_name."""
|
||||
chat = ChatOpenAI(model="foo")
|
||||
assert chat.model_name == "foo"
|
||||
chat = ChatOpenAI(model_name="bar")
|
||||
assert chat.model_name == "bar"
|
||||
|
||||
|
||||
def test_chat_openai_system_message() -> None:
|
||||
"""Test ChatOpenAI wrapper with system message."""
|
||||
chat = ChatOpenAI(max_tokens=10)
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
human_message = HumanMessage(content="Hello")
|
||||
response = chat([system_message, human_message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_generate() -> None:
|
||||
"""Test ChatOpenAI wrapper with generate."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.generate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
assert response.llm_output
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_multiple_completions() -> None:
|
||||
"""Test ChatOpenAI wrapper with multiple completions."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=5)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat._generate([message])
|
||||
assert isinstance(response, ChatResult)
|
||||
assert len(response.generations) == 5
|
||||
for generation in response.generations:
|
||||
assert isinstance(generation.message, BaseMessage)
|
||||
assert isinstance(generation.message.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, BaseMessage)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming_generation_info() -> None:
|
||||
"""Test that generation info is preserved when streaming."""
|
||||
|
||||
class _FakeCallback(FakeCallbackHandler):
|
||||
saved_things: dict = {}
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# Save the generation
|
||||
self.saved_things["generation"] = args[0]
|
||||
|
||||
callback = _FakeCallback()
|
||||
callback_manager = CallbackManager([callback])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=2,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
list(chat.stream("hi"))
|
||||
generation = callback.saved_things["generation"]
|
||||
# `Hello!` is two tokens, assert that that is what is returned
|
||||
assert generation.generations[0][0].text == "Hello!"
|
||||
|
||||
|
||||
def test_chat_openai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatOpenAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatOpenAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_openai_invalid_streaming_params() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
n=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai() -> None:
|
||||
"""Test async generation."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
assert response.llm_output
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai_bind_functions() -> None:
|
||||
"""Test ChatOpenAI wrapper with multiple completions."""
|
||||
|
||||
class Person(BaseModel):
|
||||
"""Identifying information about a person."""
|
||||
|
||||
name: str = Field(..., title="Name", description="The person's name")
|
||||
age: int = Field(..., title="Age", description="The person's age")
|
||||
fav_food: Optional[str] = Field(
|
||||
default=None, title="Fav Food", description="The person's favorite food"
|
||||
)
|
||||
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=30,
|
||||
n=1,
|
||||
streaming=True,
|
||||
).bind_functions(functions=[Person], function_call="Person")
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "Use the provided Person function"),
|
||||
("user", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
message = HumanMessage(content="Sally is 13 years old")
|
||||
response = await chain.abatch([{"input": message}])
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
for generation in response:
|
||||
assert isinstance(generation, AIMessage)
|
||||
|
||||
|
||||
def test_chat_openai_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat openai."""
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatOpenAI(foo=3, max_tokens=10)
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch() -> None:
|
||||
"""Test streaming tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
@@ -1,219 +0,0 @@
|
||||
"""Test Baidu Qianfan Chat Endpoint."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
|
||||
from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
_FUNCTIONS: Any = [
|
||||
{
|
||||
"name": "format_person_info",
|
||||
"description": (
|
||||
"Output formatter. Should always be used to format your response to the"
|
||||
" user."
|
||||
),
|
||||
"parameters": {
|
||||
"title": "Person",
|
||||
"description": "Identifying information about a person.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "The person's name",
|
||||
"type": "string",
|
||||
},
|
||||
"age": {
|
||||
"title": "Age",
|
||||
"description": "The person's age",
|
||||
"type": "integer",
|
||||
},
|
||||
"fav_food": {
|
||||
"title": "Fav Food",
|
||||
"description": "The person's favorite food",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_current_temperature",
|
||||
"description": ("Used to get the location's temperature."),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "city name",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["centigrade", "Fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location", "unit"],
|
||||
},
|
||||
"responses": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "integer",
|
||||
"description": "city temperature",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["centigrade", "Fahrenheit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_default_call() -> None:
|
||||
"""Test default model(`ERNIE-Bot`) call."""
|
||||
chat = QianfanChatEndpoint()
|
||||
response = chat(messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_model() -> None:
|
||||
"""Test model kwarg works."""
|
||||
chat = QianfanChatEndpoint(model="BLOOMZ-7B")
|
||||
response = chat(messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_model_param() -> None:
|
||||
"""Test model params works."""
|
||||
chat = QianfanChatEndpoint()
|
||||
response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_endpoint() -> None:
|
||||
"""Test user custom model deployments like some open source models."""
|
||||
chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
|
||||
response = chat(messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_endpoint_param() -> None:
|
||||
"""Test user custom model deployments like some open source models."""
|
||||
chat = QianfanChatEndpoint()
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello")
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_history() -> None:
|
||||
"""Tests multiple history works."""
|
||||
chat = QianfanChatEndpoint()
|
||||
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="How are you doing?"),
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test that stream works."""
|
||||
chat = QianfanChatEndpoint(streaming=True)
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="Who are you?"),
|
||||
],
|
||||
stream=True,
|
||||
callbacks=callback_manager,
|
||||
)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_messages() -> None:
|
||||
"""Tests multiple messages works."""
|
||||
chat = QianfanChatEndpoint()
|
||||
message = HumanMessage(content="Hi, how are you.")
|
||||
response = chat.generate([[message], [message]])
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
def test_functions_call_thoughts() -> None:
|
||||
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||
|
||||
prompt_tmpl = "Use the given functions to answer following question: {input}"
|
||||
prompt_msgs = [
|
||||
HumanMessagePromptTemplate.from_template(prompt_tmpl),
|
||||
]
|
||||
prompt = ChatPromptTemplate(messages=prompt_msgs)
|
||||
|
||||
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||
|
||||
message = HumanMessage(content="What's the temperature in Shanghai today?")
|
||||
response = chain.batch([{"input": message}])
|
||||
assert isinstance(response[0], AIMessage)
|
||||
assert "function_call" in response[0].additional_kwargs
|
||||
|
||||
|
||||
def test_functions_call() -> None:
|
||||
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||
|
||||
prompt = ChatPromptTemplate(
|
||||
messages=[
|
||||
HumanMessage(content="What's the temperature in Shanghai today?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "get_current_temperature",
|
||||
"thoughts": "i will use get_current_temperature "
|
||||
"to resolve the questions",
|
||||
"arguments": '{"location":"Shanghai","unit":"centigrade"}',
|
||||
}
|
||||
},
|
||||
),
|
||||
FunctionMessage(
|
||||
name="get_current_weather",
|
||||
content='{"temperature": "25", \
|
||||
"unit": "摄氏度", "description": "晴朗"}',
|
||||
),
|
||||
]
|
||||
)
|
||||
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||
resp = chain.invoke({})
|
||||
assert isinstance(resp, AIMessage)
|
||||
@@ -1,182 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.document_loaders.concurrent import ConcurrentLoader
|
||||
from langchain_community.document_loaders.generic import GenericLoader
|
||||
from langchain_community.document_loaders.parsers import LanguageParser
|
||||
|
||||
|
||||
def test_language_loader_for_python() -> None:
|
||||
"""Test Python loader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 2
|
||||
|
||||
metadata = docs[0].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.py")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "python"
|
||||
metadata = docs[1].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.py")
|
||||
assert metadata["content_type"] == "simplified_code"
|
||||
assert metadata["language"] == "python"
|
||||
|
||||
assert (
|
||||
docs[0].page_content
|
||||
== """def main():
|
||||
print("Hello World!")
|
||||
|
||||
return 0"""
|
||||
)
|
||||
assert (
|
||||
docs[1].page_content
|
||||
== """#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
# Code for: def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())"""
|
||||
)
|
||||
|
||||
|
||||
def test_language_loader_for_python_with_parser_threshold() -> None:
|
||||
"""Test Python loader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.py",
|
||||
parser=LanguageParser(language="python", parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
def esprima_installed() -> bool:
|
||||
try:
|
||||
import esprima # noqa: F401
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"esprima not installed, skipping test {e}")
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
|
||||
def test_language_loader_for_javascript() -> None:
|
||||
"""Test JavaScript loader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 3
|
||||
|
||||
metadata = docs[0].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "js"
|
||||
metadata = docs[1].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "js"
|
||||
metadata = docs[2].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "simplified_code"
|
||||
assert metadata["language"] == "js"
|
||||
|
||||
assert (
|
||||
docs[0].page_content
|
||||
== """class HelloWorld {
|
||||
sayHello() {
|
||||
console.log("Hello World!");
|
||||
}
|
||||
}"""
|
||||
)
|
||||
assert (
|
||||
docs[1].page_content
|
||||
== """function main() {
|
||||
const hello = new HelloWorld();
|
||||
hello.sayHello();
|
||||
}"""
|
||||
)
|
||||
assert (
|
||||
docs[2].page_content
|
||||
== """// Code for: class HelloWorld {
|
||||
|
||||
// Code for: function main() {
|
||||
|
||||
main();"""
|
||||
)
|
||||
|
||||
|
||||
def test_language_loader_for_javascript_with_parser_threshold() -> None:
|
||||
"""Test JavaScript loader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.js",
|
||||
parser=LanguageParser(language="js", parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
def test_concurrent_language_loader_for_javascript_with_parser_threshold() -> None:
|
||||
"""Test JavaScript ConcurrentLoader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = ConcurrentLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.js",
|
||||
parser=LanguageParser(language="js", parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
def test_concurrent_language_loader_for_python_with_parser_threshold() -> None:
|
||||
"""Test Python ConcurrentLoader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = ConcurrentLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.py",
|
||||
parser=LanguageParser(language="python", parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
|
||||
def test_concurrent_language_loader_for_javascript() -> None:
|
||||
"""Test JavaScript ConcurrentLoader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = ConcurrentLoader.from_filesystem(
|
||||
file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 3
|
||||
|
||||
|
||||
def test_concurrent_language_loader_for_python() -> None:
|
||||
"""Test Python ConcurrentLoader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = ConcurrentLoader.from_filesystem(
|
||||
file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 2
|
||||
@@ -1,136 +0,0 @@
"""Test Fireworks AI API Wrapper."""
from typing import Generator

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.llms.fireworks import Fireworks

@pytest.fixture
def llm() -> Fireworks:
    return Fireworks(model_kwargs={"temperature": 0, "max_tokens": 512})


@pytest.mark.scheduled
def test_fireworks_call(llm: Fireworks) -> None:
    """Test valid call to fireworks."""
    output = llm("How is the weather in New York today?")
    assert isinstance(output, str)


@pytest.mark.scheduled
def test_fireworks_model_param() -> None:
    """Tests model parameters for Fireworks"""
    llm = Fireworks(model="foo")
    assert llm.model == "foo"


@pytest.mark.scheduled
def test_fireworks_invoke(llm: Fireworks) -> None:
    """Tests completion with invoke"""
    output = llm.invoke("How is the weather in New York today?", stop=[","])
    assert isinstance(output, str)
    assert output[-1] == ","


@pytest.mark.scheduled
async def test_fireworks_ainvoke(llm: Fireworks) -> None:
    """Tests completion with ainvoke"""
    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
    assert isinstance(output, str)
    assert output[-1] == ","


@pytest.mark.scheduled
def test_fireworks_batch(llm: Fireworks) -> None:
    """Tests completion with batch"""
    llm = Fireworks()
    output = llm.batch(
        [
            "How is the weather in New York today?",
            "How is the weather in New York today?",
        ],
        stop=[","],
    )
    for token in output:
        assert isinstance(token, str)
        assert token[-1] == ","


@pytest.mark.scheduled
async def test_fireworks_abatch(llm: Fireworks) -> None:
    """Tests completion with abatch"""
    output = await llm.abatch(
        [
            "How is the weather in New York today?",
            "How is the weather in New York today?",
        ],
        stop=[","],
    )
    for token in output:
        assert isinstance(token, str)
        assert token[-1] == ","


@pytest.mark.scheduled
def test_fireworks_multiple_prompts(
    llm: Fireworks,
) -> None:
    """Test completion with multiple prompts."""
    output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


@pytest.mark.scheduled
def test_fireworks_streaming(llm: Fireworks) -> None:
    """Test stream completion."""
    generator = llm.stream("Who's the best quarterback in the NFL?")
    assert isinstance(generator, Generator)

    for token in generator:
        assert isinstance(token, str)


@pytest.mark.scheduled
def test_fireworks_streaming_stop_words(llm: Fireworks) -> None:
    """Test stream completion with stop words."""
    generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
    assert isinstance(generator, Generator)

    last_token = ""
    for token in generator:
        last_token = token
        assert isinstance(token, str)
    assert last_token[-1] == ","


@pytest.mark.scheduled
async def test_fireworks_streaming_async(llm: Fireworks) -> None:
    """Test async stream completion."""

    last_token = ""
    async for token in llm.astream(
        "Who's the best quarterback in the NFL?", stop=[","]
    ):
        last_token = token
        assert isinstance(token, str)
    assert last_token[-1] == ","


@pytest.mark.scheduled
async def test_fireworks_async_agenerate(llm: Fireworks) -> None:
    """Test async generation."""
    output = await llm.agenerate(["What is the best city to live in California?"])
    assert isinstance(output, LLMResult)


@pytest.mark.scheduled
async def test_fireworks_multiple_prompts_async_agenerate(llm: Fireworks) -> None:
    output = await llm.agenerate(
        ["How is the weather in New York today?", "I'm pickle rick"]
    )
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2
@@ -1,77 +0,0 @@
|
||||
import langchain_community.utilities.opaqueprompts as op
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_community.llms.opaqueprompts import OpaquePrompts
|
||||
|
||||
prompt_template = """
|
||||
As an AI assistant, you will answer questions according to given context.
|
||||
|
||||
Sensitive personal information in the question is masked for privacy.
|
||||
For instance, if the original text says "Giana is good," it will be changed
|
||||
to "PERSON_998 is good."
|
||||
|
||||
Here's how to handle these changes:
|
||||
* Consider these masked phrases just as placeholders, but still refer to
|
||||
them in a relevant way when answering.
|
||||
* It's possible that different masked terms might mean the same thing.
|
||||
Stick with the given term and don't modify it.
|
||||
* All masked terms follow the "TYPE_ID" pattern.
|
||||
* Please don't invent new masked terms. For instance, if you see "PERSON_998,"
|
||||
don't come up with "PERSON_997" or "PERSON_999" unless they're already in the question.
|
||||
|
||||
Conversation History: ```{history}```
|
||||
Context : ```During our recent meeting on February 23, 2023, at 10:30 AM,
|
||||
John Doe provided me with his personal details. His email is johndoe@example.com
|
||||
and his contact number is 650-456-7890. He lives in New York City, USA, and
|
||||
belongs to the American nationality with Christian beliefs and a leaning towards
|
||||
the Democratic party. He mentioned that he recently made a transaction using his
|
||||
credit card 4111 1111 1111 1111 and transferred bitcoins to the wallet address
|
||||
1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa. While discussing his European travels, he
|
||||
noted down his IBAN as GB29 NWBK 6016 1331 9268 19. Additionally, he provided
|
||||
his website as https://johndoeportfolio.com. John also discussed
|
||||
some of his US-specific details. He said his bank account number is
|
||||
1234567890123456 and his drivers license is Y12345678. His ITIN is 987-65-4321,
|
||||
and he recently renewed his passport,
|
||||
the number for which is 123456789. He emphasized not to share his SSN, which is
|
||||
669-45-6789. Furthermore, he mentioned that he accesses his work files remotely
|
||||
through the IP 192.168.1.1 and has a medical license number MED-123456. ```
|
||||
Question: ```{question}```
|
||||
"""
|
||||
|
||||
|
||||
def test_opaqueprompts() -> None:
|
||||
chain = PromptTemplate.from_template(prompt_template) | OpaquePrompts(llm=OpenAI())
|
||||
output = chain.invoke(
|
||||
{
|
||||
"question": "Write a text message to remind John to do password reset \
|
||||
for his website through his email to stay secure."
|
||||
}
|
||||
)
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_opaqueprompts_functions() -> None:
|
||||
prompt = (PromptTemplate.from_template(prompt_template),)
|
||||
llm = OpenAI()
|
||||
pg_chain = (
|
||||
op.sanitize
|
||||
| RunnableParallel(
|
||||
secure_context=lambda x: x["secure_context"], # type: ignore
|
||||
response=(lambda x: x["sanitized_input"]) # type: ignore
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser(),
|
||||
)
|
||||
| (lambda x: op.desanitize(x["response"], x["secure_context"]))
|
||||
)
|
||||
|
||||
pg_chain.invoke(
|
||||
{
|
||||
"question": "Write a text message to remind John to do password reset\
|
||||
for his website through his email to stay secure.",
|
||||
"history": "",
|
||||
}
|
||||
)
|
||||
@@ -1,42 +0,0 @@
"""Test Nebula API wrapper."""
from langchain_community.llms.symblai_nebula import Nebula


def test_symblai_nebula_call() -> None:
    """Test valid call to Nebula."""
    conversation = """Sam: Good morning, team! Let's keep this standup concise.
We'll go in the usual order: what you did yesterday,
what you plan to do today, and any blockers. Alex, kick us off.
Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard.
The new charts and widgets are now responsive.
I also had a sync with the design team to ensure the final touchups are in
line with the brand guidelines. Today, I'll start integrating the frontend with
the new API endpoints Rhea was working on.
The only blocker is waiting for some final API documentation,
but I guess Rhea can update on that.
Rhea: Hey, all! Yep, about the API documentation - I completed the majority of
the backend work for user data retrieval yesterday.
The endpoints are mostly set up, but I need to do a bit more testing today.
I'll finalize the API documentation by noon, so that should unblock Alex.
After that, I’ll be working on optimizing the database queries
for faster data fetching. No other blockers on my end.
Sam: Great, thanks Rhea. Do reach out if you need any testing assistance
or if there are any hitches with the database.
Now, my update: Yesterday, I coordinated with the client to get clarity
on some feature requirements. Today, I'll be updating our project roadmap
and timelines based on their feedback. Additionally, I'll be sitting with
the QA team in the afternoon for preliminary testing.
Blocker: I might need both of you to be available for a quick call
in case the client wants to discuss the changes live.
Alex: Sounds good, Sam. Just let us know a little in advance for the call.
Rhea: Agreed. We can make time for that.
Sam: Perfect! Let's keep the momentum going. Reach out if there are any
sudden issues or support needed. Have a productive day!
Alex: You too.
Rhea: Thanks, bye!"""
    llm = Nebula(nebula_api_key="<your_api_key>")

    instruction = """Identify the main objectives mentioned in this
conversation."""
    output = llm.invoke(f"{instruction}\n{conversation}")
    assert isinstance(output, str)
@@ -1,151 +0,0 @@
|
||||
"""Test Vertex AI API wrapper.
|
||||
In order to run this test, you need to install VertexAI SDK:
|
||||
pip install google-cloud-aiplatform>=1.36.0
|
||||
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
`gcloud auth login` first).
|
||||
"""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.llms import VertexAI, VertexAIModelGarden
|
||||
|
||||
|
||||
def test_vertex_initialization() -> None:
|
||||
llm = VertexAI()
|
||||
assert llm._llm_type == "vertexai"
|
||||
assert llm.model_name == llm.client._model_id
|
||||
|
||||
|
||||
def test_vertex_call() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_generate() -> None:
|
||||
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_generate_code() -> None:
|
||||
llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
|
||||
output = llm.generate(["generate a python method that says foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_vertex_agenerate() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_stream() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
|
||||
|
||||
async def test_vertex_consistency() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
output = llm.generate(["Please say foo:"])
|
||||
streaming_output = llm.generate(["Please say foo:"], stream=True)
|
||||
async_output = await llm.agenerate(["Please say foo:"])
|
||||
assert output.generations[0][0].text == streaming_output.generations[0][0].text
|
||||
assert output.generations[0][0].text == async_output.generations[0][0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint_os_variable_name,result_arg",
|
||||
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
|
||||
)
|
||||
def test_model_garden(
|
||||
endpoint_os_variable_name: str, result_arg: Optional[str]
|
||||
) -> None:
|
||||
"""In order to run this test, you should provide endpoint names.
|
||||
|
||||
Example:
|
||||
export FALCON_ENDPOINT_ID=...
|
||||
export LLAMA_ENDPOINT_ID=...
|
||||
export PROJECT=...
|
||||
"""
|
||||
endpoint_id = os.environ[endpoint_os_variable_name]
|
||||
project = os.environ["PROJECT"]
|
||||
location = "europe-west4"
|
||||
llm = VertexAIModelGarden(
|
||||
endpoint_id=endpoint_id,
|
||||
project=project,
|
||||
result_arg=result_arg,
|
||||
location=location,
|
||||
)
|
||||
output = llm("What is the meaning of life?")
|
||||
assert isinstance(output, str)
|
||||
assert llm._llm_type == "vertexai_model_garden"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint_os_variable_name,result_arg",
|
||||
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
|
||||
)
|
||||
def test_model_garden_generate(
|
||||
endpoint_os_variable_name: str, result_arg: Optional[str]
|
||||
) -> None:
|
||||
"""In order to run this test, you should provide endpoint names.
|
||||
|
||||
Example:
|
||||
export FALCON_ENDPOINT_ID=...
|
||||
export LLAMA_ENDPOINT_ID=...
|
||||
export PROJECT=...
|
||||
"""
|
||||
endpoint_id = os.environ[endpoint_os_variable_name]
|
||||
project = os.environ["PROJECT"]
|
||||
location = "europe-west4"
|
||||
llm = VertexAIModelGarden(
|
||||
endpoint_id=endpoint_id,
|
||||
project=project,
|
||||
result_arg=result_arg,
|
||||
location=location,
|
||||
)
|
||||
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint_os_variable_name,result_arg",
|
||||
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
|
||||
)
|
||||
async def test_model_garden_agenerate(
|
||||
endpoint_os_variable_name: str, result_arg: Optional[str]
|
||||
) -> None:
|
||||
endpoint_id = os.environ[endpoint_os_variable_name]
|
||||
project = os.environ["PROJECT"]
|
||||
location = "europe-west4"
|
||||
llm = VertexAIModelGarden(
|
||||
endpoint_id=endpoint_id,
|
||||
project=project,
|
||||
result_arg=result_arg,
|
||||
location=location,
|
||||
)
|
||||
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_vertex_call_count_tokens() -> None:
|
||||
llm = VertexAI()
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Integration test for Arxiv API Wrapper."""
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.tools import ArxivQueryRun
|
||||
from langchain_community.utilities import ArxivAPIWrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client() -> ArxivAPIWrapper:
|
||||
return ArxivAPIWrapper()
|
||||
|
||||
|
||||
def test_run_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of paper name that returns the correct answer"""
|
||||
|
||||
output = api_client.run("Heat-bath random walks with Markov bases")
|
||||
assert "Probability distributions for Markov chains based quantum walks" in output
|
||||
assert (
|
||||
"Transformations of random walks on groups via Markov stopping times" in output
|
||||
)
|
||||
assert (
|
||||
"Recurrence of Multidimensional Persistent Random Walks. Fourier and Series "
|
||||
"Criteria" in output
|
||||
)
|
||||
|
||||
|
||||
def test_run_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of an arxiv identifier returns the correct answer"""
|
||||
|
||||
output = api_client.run("1605.08386v1")
|
||||
assert "Heat-bath random walks with Markov bases" in output
|
||||
|
||||
|
||||
def test_run_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of multiple arxiv identifiers that returns the correct answer"""
|
||||
|
||||
output = api_client.run("1605.08386v1 2212.00794v2 2308.07912")
|
||||
assert "Heat-bath random walks with Markov bases" in output
|
||||
assert "Scaling Language-Image Pre-training via Masking" in output
|
||||
assert (
|
||||
"Ultra-low mass PBHs in the early universe can explain the PTA signal" in output
|
||||
)
|
||||
|
||||
|
||||
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns several docs"""
|
||||
|
||||
output = api_client.run("Caprice Stanley")
|
||||
assert "On Mixing Behavior of a Family of Random Walks" in output
|
||||
|
||||
|
||||
def test_run_returns_no_result(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that gives no result."""
|
||||
|
||||
output = api_client.run("1605.08386WWW")
|
||||
assert "No good Arxiv Result was found" == output
|
||||
|
||||
|
||||
def assert_docs(docs: List[Document]) -> None:
|
||||
for doc in docs:
|
||||
assert doc.page_content
|
||||
assert doc.metadata
|
||||
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
|
||||
|
||||
|
||||
def test_load_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of paper name that returns one document"""
|
||||
|
||||
docs = api_client.load("Heat-bath random walks with Markov bases")
|
||||
assert len(docs) == 3
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of an arxiv identifier that returns one document"""
|
||||
|
||||
docs = api_client.load("1605.08386v1")
|
||||
assert len(docs) == 1
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of arxiv identifiers that returns the correct answer"""
|
||||
|
||||
docs = api_client.load("1605.08386v1 2212.00794v2 2308.07912")
|
||||
assert len(docs) == 3
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns no docs"""
|
||||
|
||||
docs = api_client.load("1605.08386WWW")
|
||||
assert len(docs) == 0
|
||||
|
||||
|
||||
def test_load_returns_limited_docs() -> None:
|
||||
"""Test that returns several docs"""
|
||||
expected_docs = 2
|
||||
api_client = ArxivAPIWrapper(load_max_docs=expected_docs)
|
||||
docs = api_client.load("ChatGPT")
|
||||
assert len(docs) == expected_docs
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_returns_limited_doc_content_chars() -> None:
|
||||
"""Test that returns limited doc_content_chars_max"""
|
||||
|
||||
doc_content_chars_max = 100
|
||||
api_client = ArxivAPIWrapper(doc_content_chars_max=doc_content_chars_max)
|
||||
docs = api_client.load("1605.08386")
|
||||
assert len(docs[0].page_content) == doc_content_chars_max
|
||||
|
||||
|
||||
def test_load_returns_unlimited_doc_content_chars() -> None:
|
||||
"""Test that returns unlimited doc_content_chars_max"""
|
||||
|
||||
doc_content_chars_max = None
|
||||
api_client = ArxivAPIWrapper(doc_content_chars_max=doc_content_chars_max)
|
||||
docs = api_client.load("1605.08386")
|
||||
assert len(docs[0].page_content) == pytest.approx(54338, rel=1e-2)
|
||||
|
||||
|
||||
def test_load_returns_full_set_of_metadata() -> None:
|
||||
"""Test that returns several docs"""
|
||||
api_client = ArxivAPIWrapper(load_max_docs=1, load_all_available_meta=True)
|
||||
docs = api_client.load("ChatGPT")
|
||||
assert len(docs) == 1
|
||||
for doc in docs:
|
||||
assert doc.page_content
|
||||
assert doc.metadata
|
||||
assert set(doc.metadata).issuperset(
|
||||
{"Published", "Title", "Authors", "Summary"}
|
||||
)
|
||||
print(doc.metadata)
|
||||
assert len(set(doc.metadata)) > 4
|
||||
|
||||
|
||||
def _load_arxiv_from_universal_entry(**kwargs: Any) -> BaseTool:
|
||||
from langchain.agents.load_tools import load_tools
|
||||
tools = load_tools(["arxiv"], **kwargs)
|
||||
assert len(tools) == 1, "loaded more than 1 tool"
|
||||
return tools[0]
|
||||
|
||||
|
||||
def test_load_arxiv_from_universal_entry() -> None:
|
||||
arxiv_tool = _load_arxiv_from_universal_entry()
|
||||
output = arxiv_tool("Caprice Stanley")
|
||||
assert (
|
||||
"On Mixing Behavior of a Family of Random Walks" in output
|
||||
), "failed to fetch a valid result"
|
||||
|
||||
|
||||
def test_load_arxiv_from_universal_entry_with_params() -> None:
|
||||
params = {
|
||||
"top_k_results": 1,
|
||||
"load_max_docs": 10,
|
||||
"load_all_available_meta": True,
|
||||
}
|
||||
arxiv_tool = _load_arxiv_from_universal_entry(**params)
|
||||
assert isinstance(arxiv_tool, ArxivQueryRun)
|
||||
wp = arxiv_tool.api_wrapper
|
||||
assert wp.top_k_results == 1, "failed to assert top_k_results"
|
||||
assert wp.load_max_docs == 10, "failed to assert load_max_docs"
|
||||
assert (
|
||||
wp.load_all_available_meta is True
|
||||
), "failed to assert load_all_available_meta"
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Integration test for PubMed API Wrapper."""
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.tools import PubmedQueryRun
|
||||
from langchain_community.utilities import PubMedAPIWrapper
|
||||
|
||||
xmltodict = pytest.importorskip("xmltodict")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client() -> PubMedAPIWrapper:
|
||||
return PubMedAPIWrapper()
|
||||
|
||||
|
||||
def test_run_success(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that returns the correct answer"""
|
||||
|
||||
search_string = (
|
||||
"Examining the Validity of ChatGPT in Identifying "
|
||||
"Relevant Nephrology Literature"
|
||||
)
|
||||
output = api_client.run(search_string)
|
||||
test_string = (
|
||||
"Examining the Validity of ChatGPT in Identifying "
|
||||
"Relevant Nephrology Literature: Findings and Implications"
|
||||
)
|
||||
assert test_string in output
|
||||
assert len(output) == api_client.doc_content_chars_max
|
||||
|
||||
|
||||
def test_run_returns_no_result(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that gives no result."""
|
||||
|
||||
output = api_client.run("1605.08386WWW")
|
||||
assert "No good PubMed Result was found" == output
|
||||
|
||||
|
||||
def test_retrieve_article_returns_book_abstract(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that returns the excerpt of a book."""
|
||||
|
||||
output_nolabel = api_client.retrieve_article("25905357", "")
|
||||
output_withlabel = api_client.retrieve_article("29262144", "")
|
||||
test_string_nolabel = (
|
||||
"Osteoporosis is a multifactorial disorder associated with low bone mass and "
|
||||
"enhanced skeletal fragility. Although"
|
||||
)
|
||||
assert test_string_nolabel in output_nolabel["Summary"]
|
||||
assert (
|
||||
"Wallenberg syndrome was first described in 1808 by Gaspard Vieusseux. However,"
|
||||
in output_withlabel["Summary"]
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve_article_returns_article_abstract(
|
||||
api_client: PubMedAPIWrapper,
|
||||
) -> None:
|
||||
"""Test that returns the abstract of an article."""
|
||||
|
||||
output_nolabel = api_client.retrieve_article("37666905", "")
|
||||
output_withlabel = api_client.retrieve_article("37666551", "")
|
||||
test_string_nolabel = (
|
||||
"This work aims to: (1) Provide maximal hand force data on six different "
|
||||
"grasp types for healthy subjects; (2) detect grasp types with maximal "
|
||||
"force significantly affected by hand osteoarthritis (HOA) in women; (3) "
|
||||
"look for predictors to detect HOA from the maximal forces using discriminant "
|
||||
"analyses."
|
||||
)
|
||||
assert test_string_nolabel in output_nolabel["Summary"]
|
||||
test_string_withlabel = (
|
||||
"OBJECTIVES: To assess across seven hospitals from six different countries "
|
||||
"the extent to which the COVID-19 pandemic affected the volumes of orthopaedic "
|
||||
"hospital admissions and patient outcomes for non-COVID-19 patients admitted "
|
||||
"for orthopaedic care."
|
||||
)
|
||||
assert test_string_withlabel in output_withlabel["Summary"]
|
||||
|
||||
|
||||
def test_retrieve_article_no_abstract_available(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that returns 'No abstract available'."""
|
||||
|
||||
output = api_client.retrieve_article("10766884", "")
|
||||
assert "No abstract available" == output["Summary"]
|
||||
|
||||
|
||||
def assert_docs(docs: List[Document]) -> None:
|
||||
for doc in docs:
|
||||
assert doc.metadata
|
||||
assert set(doc.metadata) == {
|
||||
"Copyright Information",
|
||||
"uid",
|
||||
"Title",
|
||||
"Published",
|
||||
}
|
||||
|
||||
|
||||
def test_load_success(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that returns one document"""
|
||||
|
||||
docs = api_client.load_docs("chatgpt")
|
||||
assert len(docs) == api_client.top_k_results == 3
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_returns_no_result(api_client: PubMedAPIWrapper) -> None:
|
||||
"""Test that returns no docs"""
|
||||
|
||||
docs = api_client.load_docs("1605.08386WWW")
|
||||
assert len(docs) == 0
|
||||
|
||||
|
||||
def test_load_returns_limited_docs() -> None:
|
||||
"""Test that returns several docs"""
|
||||
expected_docs = 2
|
||||
api_client = PubMedAPIWrapper(top_k_results=expected_docs)
|
||||
docs = api_client.load_docs("ChatGPT")
|
||||
assert len(docs) == expected_docs
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_returns_full_set_of_metadata() -> None:
|
||||
"""Test that returns several docs"""
|
||||
api_client = PubMedAPIWrapper(load_max_docs=1, load_all_available_meta=True)
|
||||
docs = api_client.load_docs("ChatGPT")
|
||||
assert len(docs) == 3
|
||||
for doc in docs:
|
||||
assert doc.metadata
|
||||
assert set(doc.metadata).issuperset(
|
||||
{"Copyright Information", "Published", "Title", "uid"}
|
||||
)
|
||||
|
||||
|
||||
def _load_pubmed_from_universal_entry(**kwargs: Any) -> BaseTool:
|
||||
from langchain.agents.load_tools import load_tools
|
||||
tools = load_tools(["pubmed"], **kwargs)
|
||||
assert len(tools) == 1, "loaded more than 1 tool"
|
||||
return tools[0]
|
||||
|
||||
|
||||
def test_load_pupmed_from_universal_entry() -> None:
|
||||
pubmed_tool = _load_pubmed_from_universal_entry()
|
||||
search_string = (
|
||||
"Examining the Validity of ChatGPT in Identifying "
|
||||
"Relevant Nephrology Literature"
|
||||
)
|
||||
output = pubmed_tool(search_string)
|
||||
test_string = (
|
||||
"Examining the Validity of ChatGPT in Identifying "
|
||||
"Relevant Nephrology Literature: Findings and Implications"
|
||||
)
|
||||
assert test_string in output
|
||||
|
||||
|
||||
def test_load_pupmed_from_universal_entry_with_params() -> None:
|
||||
params = {
|
||||
"top_k_results": 1,
|
||||
}
|
||||
pubmed_tool = _load_pubmed_from_universal_entry(**params)
|
||||
assert isinstance(pubmed_tool, PubmedQueryRun)
|
||||
wp = pubmed_tool.api_wrapper
|
||||
assert wp.top_k_results == 1, "failed to assert top_k_results"
|
||||
@@ -1,44 +0,0 @@
import os
from typing import Union

import pytest
from vcr.request import Request

# Those environment variables turn on Deep Lake pytest mode.
# It makes tests run significantly faster.
# Need to run before `import deeplake`
os.environ["BUGGER_OFF"] = "true"
os.environ["DEEPLAKE_DOWNLOAD_PATH"] = "./testing/local_storage"
os.environ["DEEPLAKE_PYTEST_ENABLED"] = "true"


# This fixture returns a dictionary containing filter_headers options
# for replacing certain headers with dummy values during cassette playback.
# Specifically, it replaces the authorization header with a dummy value to
# prevent sensitive data from being recorded in the cassette.
# It also filters requests to certain hosts (specified in the `ignored_hosts` list)
# to prevent data from being recorded in the cassette.
@pytest.fixture(scope="module")
def vcr_config() -> dict:
    skipped_host = ["pinecone.io"]

    def before_record_response(response: dict) -> Union[dict, None]:
        return response

    def before_record_request(request: Request) -> Union[Request, None]:
        for host in skipped_host:
            if request.host.startswith(host) or request.host.endswith(host):
                return None
        return request

    return {
        "before_record_request": before_record_request,
        "before_record_response": before_record_response,
        "filter_headers": [
            ("authorization", "authorization-DUMMY"),
            ("X-OpenAI-Client-User-Agent", "X-OpenAI-Client-User-Agent-DUMMY"),
            ("Api-Key", "Api-Key-DUMMY"),
            ("User-Agent", "User-Agent-DUMMY"),
        ],
        "ignore_localhost": True,
    }
@@ -1,85 +0,0 @@
|
||||
"""Test CallbackManager."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.callbacks.manager import trace_as_chain_group, CallbackManager
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
||||
from langchain_community.llms.openai import BaseOpenAI
|
||||
|
||||
|
||||
def test_callback_manager_configure_context_vars(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test callback manager configuration."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||
with patch.object(LangChainTracer, "_update_run_single"):
|
||||
with patch.object(LangChainTracer, "_persist_run_single"):
|
||||
with trace_as_chain_group("test") as group_manager:
|
||||
assert len(group_manager.handlers) == 1
|
||||
tracer = group_manager.handlers[0]
|
||||
assert isinstance(tracer, LangChainTracer)
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
wait_for_all_tracers()
|
||||
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
|
||||
@@ -1,31 +0,0 @@
from langchain_community.callbacks import __all__

EXPECTED_ALL = [
    "AimCallbackHandler",
    "ArgillaCallbackHandler",
    "ArizeCallbackHandler",
    "PromptLayerCallbackHandler",
    "ArthurCallbackHandler",
    "ClearMLCallbackHandler",
    "CometCallbackHandler",
    "ContextCallbackHandler",
    "HumanApprovalCallbackHandler",
    "InfinoCallbackHandler",
    "MlflowCallbackHandler",
    "LLMonitorCallbackHandler",
    "OpenAICallbackHandler",
    "LLMThoughtLabeler",
    "StreamlitCallbackHandler",
    "WandbCallbackHandler",
    "WhyLabsCallbackHandler",
    "get_openai_callback",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",
    "LabelStudioCallbackHandler",
    "TrubricsCallbackHandler",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
@@ -1,23 +0,0 @@
import pathlib

from langchain_community.chat_loaders import slack, utils


def test_slack_chat_loader() -> None:
    chat_path = (
        pathlib.Path(__file__).parents[2]
        / "examples"
        / "slack_export.zip"
    )
    loader = slack.SlackChatLoader(str(chat_path))

    chat_sessions = list(
        utils.map_ai_messages(loader.lazy_load(), sender="U0500003428")
    )
    assert chat_sessions, "Chat sessions should not be empty"

    assert chat_sessions[1]["messages"], "Chat messages should not be empty"

    assert (
        "Example message" in chat_sessions[1]["messages"][0].content
    ), "Chat content mismatch"
@@ -1,54 +0,0 @@
"""Test Anthropic Chat API wrapper."""
from typing import List
from unittest.mock import MagicMock

import pytest

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)

from langchain_community.chat_models import BedrockChat
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama


@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        ([HumanMessage(content="Hello")], "[INST] Hello [/INST]"),
        (
            [HumanMessage(content="Hello"), AIMessage(content="Answer:")],
            "[INST] Hello [/INST]\nAnswer:",
        ),
        (
            [
                SystemMessage(content="You're an assistant"),
                HumanMessage(content="Hello"),
                AIMessage(content="Answer:"),
            ],
            "<<SYS>> You're an assistant <</SYS>>\n[INST] Hello [/INST]\nAnswer:",
        ),
    ],
)
def test_formatting(messages: List[BaseMessage], expected: str) -> None:
    result = convert_messages_to_prompt_llama(messages)
    assert result == expected


def test_anthropic_bedrock() -> None:
    client = MagicMock()
    respbody = MagicMock(
        read=MagicMock(
            return_value=MagicMock(
                decode=MagicMock(return_value=b'{"completion":"Hi back"}')
            )
        )
    )
    client.invoke_model.return_value = {"body": respbody}
    model = BedrockChat(model_id="anthropic.claude-v2", client=client)

    # should not throw an error
    model.invoke("hello there")
@@ -1,96 +0,0 @@
|
||||
"""Tests for the various PDF parsers."""
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.document_loaders.base import BaseBlobParser
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
from langchain_community.document_loaders.parsers.pdf import (
|
||||
PDFMinerParser,
|
||||
PyMuPDFParser,
|
||||
PyPDFium2Parser,
|
||||
PyPDFParser,
|
||||
)
|
||||
|
||||
_THIS_DIR = Path(__file__).parents[3]
|
||||
|
||||
_EXAMPLES_DIR = _THIS_DIR / "examples"
|
||||
|
||||
# Paths to test PDF files
|
||||
HELLO_PDF = _EXAMPLES_DIR / "hello.pdf"
|
||||
LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf"
|
||||
|
||||
|
||||
def _assert_with_parser(parser: BaseBlobParser, splits_by_page: bool = True) -> None:
|
||||
"""Standard tests to verify that the given parser works.
|
||||
|
||||
Args:
|
||||
parser (BaseBlobParser): The parser to test.
|
||||
splits_by_page (bool): Whether the parser splits by page or not by default.
|
||||
"""
|
||||
blob = Blob.from_path(HELLO_PDF)
|
||||
doc_generator = parser.lazy_parse(blob)
|
||||
assert isinstance(doc_generator, Iterator)
|
||||
docs = list(doc_generator)
|
||||
assert len(docs) == 1
|
||||
page_content = docs[0].page_content
|
||||
assert isinstance(page_content, str)
|
||||
# The different parsers return different amount of whitespace, so using
|
||||
# startswith instead of equals.
|
||||
assert docs[0].page_content.startswith("Hello world!")
|
||||
|
||||
blob = Blob.from_path(LAYOUT_PARSER_PAPER_PDF)
|
||||
doc_generator = parser.lazy_parse(blob)
|
||||
assert isinstance(doc_generator, Iterator)
|
||||
docs = list(doc_generator)
|
||||
|
||||
if splits_by_page:
|
||||
assert len(docs) == 16
|
||||
else:
|
||||
assert len(docs) == 1
|
||||
# Test is imprecise since the parsers yield different parse information depending
|
||||
# on configuration. Each parser seems to yield a slightly different result
|
||||
# for this page!
|
||||
assert "LayoutParser" in docs[0].page_content
|
||||
metadata = docs[0].metadata
|
||||
|
||||
assert metadata["source"] == str(LAYOUT_PARSER_PAPER_PDF)
|
||||
|
||||
if splits_by_page:
|
||||
assert int(metadata["page"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("pypdf")
|
||||
def test_pypdf_parser() -> None:
|
||||
"""Test PyPDF parser."""
|
||||
_assert_with_parser(PyPDFParser())
|
||||
|
||||
|
||||
@pytest.mark.requires("pdfminer")
|
||||
def test_pdfminer_parser() -> None:
|
||||
"""Test PDFMiner parser."""
|
||||
# Unlike the default, this parser does not split by page.
|
||||
_assert_with_parser(PDFMinerParser(), splits_by_page=False)
|
||||
|
||||
|
||||
@pytest.mark.requires("fitz") # package is PyMuPDF
|
||||
def test_pymupdf_loader() -> None:
|
||||
"""Test PyMuPDF loader."""
|
||||
_assert_with_parser(PyMuPDFParser())
|
||||
|
||||
|
||||
@pytest.mark.requires("pypdfium2")
|
||||
def test_pypdfium2_parser() -> None:
|
||||
"""Test PyPDFium2 parser."""
|
||||
# Splits by page by default.
|
||||
_assert_with_parser(PyPDFium2Parser())
|
||||
|
||||
|
||||
@pytest.mark.requires("rapidocr_onnxruntime")
|
||||
def test_extract_images_text_from_pdf() -> None:
|
||||
"""Test extract image from pdf and recognize text with rapid ocr"""
|
||||
_assert_with_parser(PyPDFParser(extract_images=True))
|
||||
_assert_with_parser(PDFMinerParser(extract_images=True))
|
||||
_assert_with_parser(PyMuPDFParser(extract_images=True))
|
||||
_assert_with_parser(PyPDFium2Parser(extract_images=True))
|
||||
@@ -1,60 +0,0 @@
|
||||
from langchain_community.embeddings import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"ClarifaiEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"DatabricksEmbeddings",
|
||||
"ElasticsearchEmbeddings",
|
||||
"FastEmbedEmbeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
"HuggingFaceInferenceAPIEmbeddings",
|
||||
"InfinityEmbeddings",
|
||||
"GradientEmbeddings",
|
||||
"JinaEmbeddings",
|
||||
"LlamaCppEmbeddings",
|
||||
"HuggingFaceHubEmbeddings",
|
||||
"MlflowAIGatewayEmbeddings",
|
||||
"MlflowEmbeddings",
|
||||
"ModelScopeEmbeddings",
|
||||
"TensorflowHubEmbeddings",
|
||||
"SagemakerEndpointEmbeddings",
|
||||
"HuggingFaceInstructEmbeddings",
|
||||
"MosaicMLInstructorEmbeddings",
|
||||
"SelfHostedEmbeddings",
|
||||
"SelfHostedHuggingFaceEmbeddings",
|
||||
"SelfHostedHuggingFaceInstructEmbeddings",
|
||||
"FakeEmbeddings",
|
||||
"DeterministicFakeEmbedding",
|
||||
"AlephAlphaAsymmetricSemanticEmbedding",
|
||||
"AlephAlphaSymmetricSemanticEmbedding",
|
||||
"SentenceTransformerEmbeddings",
|
||||
"GooglePalmEmbeddings",
|
||||
"MiniMaxEmbeddings",
|
||||
"VertexAIEmbeddings",
|
||||
"BedrockEmbeddings",
|
||||
"DeepInfraEmbeddings",
|
||||
"EdenAiEmbeddings",
|
||||
"DashScopeEmbeddings",
|
||||
"EmbaasEmbeddings",
|
||||
"OctoAIEmbeddings",
|
||||
"SpacyEmbeddings",
|
||||
"NLPCloudEmbeddings",
|
||||
"GPT4AllEmbeddings",
|
||||
"XinferenceEmbeddings",
|
||||
"LocalAIEmbeddings",
|
||||
"AwaEmbeddings",
|
||||
"HuggingFaceBgeEmbeddings",
|
||||
"ErnieEmbeddings",
|
||||
"JavelinAIGatewayEmbeddings",
|
||||
"OllamaEmbeddings",
|
||||
"QianfanEmbeddingsEndpoint",
|
||||
"JohnSnowLabsEmbeddings",
|
||||
"VoyageEmbeddings",
|
||||
"BookendEmbeddings",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
@@ -1,56 +0,0 @@
import os

import pytest

from langchain_community.llms.openai import OpenAI
from langchain_community.utils.openai import is_openai_v1

os.environ["OPENAI_API_KEY"] = "foo"


def _openai_v1_installed() -> bool:
    try:
        return is_openai_v1()
    except Exception as _:
        return False


@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
    llm = OpenAI(model="foo")
    assert llm.model_name == "foo"
    llm = OpenAI(model_name="foo")
    assert llm.model_name == "foo"


@pytest.mark.requires("openai")
def test_openai_model_kwargs() -> None:
    llm = OpenAI(model_kwargs={"foo": "bar"})
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    with pytest.raises(ValueError):
        OpenAI(model_kwargs={"model_name": "foo"})


@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    with pytest.warns(match="not default parameter"):
        llm = OpenAI(foo="bar")
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.fixture
def mock_completion() -> dict:
    return {
        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
        "object": "text_completion",
        "created": 1689989000,
        "model": "text-davinci-003",
        "choices": [
            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
        ],
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }
@@ -1,42 +0,0 @@
|
||||
from langchain_community.retrievers import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AmazonKendraRetriever",
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ChaindeskRetriever",
|
||||
"CohereRagRetriever",
|
||||
"ElasticSearchBM25Retriever",
|
||||
"EmbedchainRetriever",
|
||||
"GoogleDocumentAIWarehouseRetriever",
|
||||
"GoogleCloudEnterpriseSearchRetriever",
|
||||
"GoogleVertexAIMultiTurnSearchRetriever",
|
||||
"GoogleVertexAISearchRetriever",
|
||||
"KayAiRetriever",
|
||||
"KNNRetriever",
|
||||
"LlamaIndexGraphRetriever",
|
||||
"LlamaIndexRetriever",
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"OutlineRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"PubMedRetriever",
|
||||
"RemoteLangChainRetriever",
|
||||
"SVMRetriever",
|
||||
"TavilySearchAPIRetriever",
|
||||
"TFIDFRetriever",
|
||||
"BM25Retriever",
|
||||
"VespaRetriever",
|
||||
"WeaviateHybridSearchRetriever",
|
||||
"WikipediaRetriever",
|
||||
"ZepRetriever",
|
||||
"ZillizRetriever",
|
||||
"DocArrayRetriever",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
@@ -1,11 +0,0 @@
from langchain_community.storage import __all__

EXPECTED_ALL = [
    "RedisStore",
    "UpstashRedisByteStore",
    "UpstashRedisStore",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
@@ -1,40 +0,0 @@
from typing import List, Type

from langchain_core.tools import BaseTool, StructuredTool

import langchain_community.tools
from langchain_community.tools import _DEPRECATED_TOOLS
from langchain_community.tools import __all__ as tools_all

_EXCLUDE = {
    BaseTool,
    StructuredTool,
}


def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseTool]]:
    results = []
    for tool_class_name in tools_all:
        if tool_class_name in _DEPRECATED_TOOLS:
            continue
        # Resolve the str to the class
        tool_class = getattr(langchain_community.tools, tool_class_name)
        if isinstance(tool_class, type) and issubclass(tool_class, BaseTool):
            if tool_class in _EXCLUDE:
                continue
            if (
                skip_tools_without_default_names
                and tool_class.__fields__["name"].default  # type: ignore
                in [None, ""]
            ):
                continue
            results.append(tool_class)
    return results


def test_tool_names_unique() -> None:
    """Test that the default names for our core tools are unique."""
    tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
    names = sorted([tool_cls.__fields__["name"].default for tool_cls in tool_classes])
    duplicated_names = [name for name in names if names.count(name) > 1]
    assert not duplicated_names
@@ -1,728 +0,0 @@
|
||||
"""Test FAISS functionality."""
|
||||
import datetime
|
||||
import math
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from typing import Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.docstore.base import Docstore
|
||||
from langchain_community.docstore.in_memory import InMemoryDocstore
|
||||
from langchain_community.vectorstores.faiss import FAISS
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
|
||||
_PAGE_CONTENT = """This is a page about LangChain.
|
||||
|
||||
It is a really cool framework.
|
||||
|
||||
What isn't there to love about langchain?
|
||||
|
||||
Made in 2022."""
|
||||
|
||||
|
||||
class FakeDocstore(Docstore):
|
||||
"""Fake docstore for testing purposes."""
|
||||
|
||||
def search(self, search: str) -> Union[str, Document]:
|
||||
"""Return the fake document."""
|
||||
document = Document(page_content=_PAGE_CONTENT)
|
||||
return document
|
||||
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
def test_faiss() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
async def test_faiss_afrom_texts() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
def test_faiss_vector_sim() -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(query_vec, k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
async def test_faiss_async_vector_sim() -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = await FakeEmbeddings().aembed_query(text="foo")
|
||||
output = await docsearch.asimilarity_search_by_vector(query_vec, k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
def test_faiss_vector_sim_with_score_threshold() -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
async def test_faiss_vector_async_sim_with_score_threshold() -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = await FakeEmbeddings().aembed_query(text="foo")
|
||||
output = await docsearch.asimilarity_search_by_vector(
|
||||
query_vec, k=2, score_threshold=0.2
|
||||
)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
def test_similarity_search_with_score_by_vector() -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
async def test_similarity_async_search_with_score_by_vector() -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = await FakeEmbeddings().aembed_query(text="foo")
|
||||
output = await docsearch.asimilarity_search_with_score_by_vector(query_vec, k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(
|
||||
query_vec,
|
||||
k=2,
|
||||
score_threshold=0.2,
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
assert output[0][1] < 0.2
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
async def test_sim_asearch_with_score_by_vector_with_score_threshold() -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
|
||||
index_to_id = docsearch.index_to_docstore_id
|
||||
expected_docstore = InMemoryDocstore(
|
||||
{
|
||||
index_to_id[0]: Document(page_content="foo"),
|
||||
index_to_id[1]: Document(page_content="bar"),
|
||||
index_to_id[2]: Document(page_content="baz"),
|
||||
}
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
query_vec = await FakeEmbeddings().aembed_query(text="foo")
|
||||
output = await docsearch.asimilarity_search_with_score_by_vector(
|
||||
query_vec,
|
||||
k=2,
|
||||
score_threshold=0.2,
|
||||
)
|
||||
assert len(output) == 1
|
||||
    assert output[0][0] == Document(page_content="foo")
    assert output[0][1] < 0.2


@pytest.mark.requires("faiss")
def test_faiss_mmr() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    query_vec = FakeEmbeddings().embed_query(text="foo")
    # make sure we can have k > docstore size
    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1
    )
    assert len(output) == len(texts)
    assert output[0][0] == Document(page_content="foo")
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo")


@pytest.mark.requires("faiss")
async def test_faiss_async_mmr() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
    query_vec = await FakeEmbeddings().aembed_query(text="foo")
    # make sure we can have k > docstore size
    output = await docsearch.amax_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1
    )
    assert len(output) == len(texts)
    assert output[0][0] == Document(page_content="foo")
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo")


@pytest.mark.requires("faiss")
def test_faiss_mmr_with_metadatas() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1
    )
    assert len(output) == len(texts)
    assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo", metadata={"page": 0})


@pytest.mark.requires("faiss")
async def test_faiss_async_mmr_with_metadatas() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = await FakeEmbeddings().aembed_query(text="foo")
    output = await docsearch.amax_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1
    )
    assert len(output) == len(texts)
    assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo", metadata={"page": 0})


@pytest.mark.requires("faiss")
def test_faiss_mmr_with_metadatas_and_filter() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1, filter={"page": 1}
    )
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
    assert output[0][1] == 0.0


@pytest.mark.requires("faiss")
async def test_faiss_async_mmr_with_metadatas_and_filter() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = await FakeEmbeddings().aembed_query(text="foo")
    output = await docsearch.amax_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1, filter={"page": 1}
    )
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
    assert output[0][1] == 0.0


@pytest.mark.requires("faiss")
def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1, filter={"page": [0, 1, 2]}
    )
    assert len(output) == 3
    assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo", metadata={"page": 0})


@pytest.mark.requires("faiss")
async def test_faiss_async_mmr_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    query_vec = await FakeEmbeddings().aembed_query(text="foo")
    output = await docsearch.amax_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1, filter={"page": [0, 1, 2]}
    )
    assert len(output) == 3
    assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
    assert output[0][1] == 0.0
    assert output[1][0] != Document(page_content="foo", metadata={"page": 0})


@pytest.mark.requires("faiss")
def test_faiss_with_metadatas() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", metadata={"page": 0})]


@pytest.mark.requires("faiss")
async def test_faiss_async_with_metadatas() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = await docsearch.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo", metadata={"page": 0})]


@pytest.mark.requires("faiss")
def test_faiss_with_metadatas_and_filter() -> None:
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
    assert output == [Document(page_content="bar", metadata={"page": 1})]


@pytest.mark.requires("faiss")
async def test_faiss_async_with_metadatas_and_filter() -> None:
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = await docsearch.asimilarity_search("foo", k=1, filter={"page": 1})
    assert output == [Document(page_content="bar", metadata={"page": 1})]


@pytest.mark.requires("faiss")
def test_faiss_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "bar", "baz", "foo", "qux"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
            docsearch.index_to_docstore_id[3]: Document(
                page_content="foo", metadata={"page": 3}
            ),
            docsearch.index_to_docstore_id[4]: Document(
                page_content="qux", metadata={"page": 3}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
    assert output == [Document(page_content="foo", metadata={"page": 0})]


@pytest.mark.requires("faiss")
async def test_faiss_async_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "bar", "baz", "foo", "qux"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
            docsearch.index_to_docstore_id[3]: Document(
                page_content="foo", metadata={"page": 3}
            ),
            docsearch.index_to_docstore_id[4]: Document(
                page_content="qux", metadata={"page": 3}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = await docsearch.asimilarity_search("foor", k=1, filter={"page": [0, 1, 2]})
    assert output == [Document(page_content="foo", metadata={"page": 0})]


@pytest.mark.requires("faiss")
def test_faiss_search_not_found() -> None:
    """Test what happens when document is not found."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    # Get rid of the docstore to purposefully induce errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")


@pytest.mark.requires("faiss")
async def test_faiss_async_search_not_found() -> None:
    """Test what happens when document is not found."""
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
    # Get rid of the docstore to purposefully induce errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        await docsearch.asimilarity_search("foo")


@pytest.mark.requires("faiss")
def test_faiss_add_texts() -> None:
    """Test end to end adding of texts."""
    # Create initial doc store.
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    # Test adding a similar document as before.
    docsearch.add_texts(["foo"])
    output = docsearch.similarity_search("foo", k=2)
    assert output == [Document(page_content="foo"), Document(page_content="foo")]


@pytest.mark.requires("faiss")
async def test_faiss_async_add_texts() -> None:
    """Test end to end adding of texts."""
    # Create initial doc store.
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
    # Test adding a similar document as before.
    await docsearch.aadd_texts(["foo"])
    output = await docsearch.asimilarity_search("foo", k=2)
    assert output == [Document(page_content="foo"), Document(page_content="foo")]


@pytest.mark.requires("faiss")
def test_faiss_add_texts_not_supported() -> None:
    """Test adding of texts to a docstore that doesn't support it."""
    docsearch = FAISS(FakeEmbeddings(), None, FakeDocstore(), {})
    with pytest.raises(ValueError):
        docsearch.add_texts(["foo"])


@pytest.mark.requires("faiss")
async def test_faiss_async_add_texts_not_supported() -> None:
    """Test adding of texts to a docstore that doesn't support it."""
    docsearch = FAISS(FakeEmbeddings(), None, FakeDocstore(), {})
    with pytest.raises(ValueError):
        await docsearch.aadd_texts(["foo"])


@pytest.mark.requires("faiss")
def test_faiss_local_save_load() -> None:
    """Test end to end serialization."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
        docsearch.save_local(temp_folder)
        new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
    assert new_docsearch.index is not None


@pytest.mark.requires("faiss")
async def test_faiss_async_local_save_load() -> None:
    """Test end to end serialization."""
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(texts, FakeEmbeddings())
    temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
        docsearch.save_local(temp_folder)
        new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
    assert new_docsearch.index is not None


@pytest.mark.requires("faiss")
def test_faiss_similarity_search_with_relevance_scores() -> None:
    """Test the similarity search with normalized similarities."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(
        texts,
        FakeEmbeddings(),
        relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2),
    )
    outputs = docsearch.similarity_search_with_relevance_scores("foo", k=1)
    output, score = outputs[0]
    assert output == Document(page_content="foo")
    assert score == 1.0


@pytest.mark.requires("faiss")
async def test_faiss_async_similarity_search_with_relevance_scores() -> None:
    """Test the similarity search with normalized similarities."""
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(
        texts,
        FakeEmbeddings(),
        relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2),
    )
    outputs = await docsearch.asimilarity_search_with_relevance_scores("foo", k=1)
    output, score = outputs[0]
    assert output == Document(page_content="foo")
    assert score == 1.0


@pytest.mark.requires("faiss")
def test_faiss_similarity_search_with_relevance_scores_with_threshold() -> None:
    """Test the similarity search with normalized similarities with score threshold."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(
        texts,
        FakeEmbeddings(),
        relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2),
    )
    outputs = docsearch.similarity_search_with_relevance_scores(
        "foo", k=2, score_threshold=0.5
    )
    assert len(outputs) == 1
    output, score = outputs[0]
    assert output == Document(page_content="foo")
    assert score == 1.0


@pytest.mark.requires("faiss")
async def test_faiss_asimilarity_search_with_relevance_scores_with_threshold() -> None:
    """Test the similarity search with normalized similarities with score threshold."""
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(
        texts,
        FakeEmbeddings(),
        relevance_score_fn=lambda score: 1.0 - score / math.sqrt(2),
    )
    outputs = await docsearch.asimilarity_search_with_relevance_scores(
        "foo", k=2, score_threshold=0.5
    )
    assert len(outputs) == 1
    output, score = outputs[0]
    assert output == Document(page_content="foo")
    assert score == 1.0


@pytest.mark.requires("faiss")
def test_faiss_invalid_normalize_fn() -> None:
    """Test the similarity search with normalized similarities."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(
        texts, FakeEmbeddings(), relevance_score_fn=lambda _: 2.0
    )
    with pytest.warns(Warning, match="scores must be between"):
        docsearch.similarity_search_with_relevance_scores("foo", k=1)


@pytest.mark.requires("faiss")
async def test_faiss_async_invalid_normalize_fn() -> None:
    """Test the similarity search with normalized similarities."""
    texts = ["foo", "bar", "baz"]
    docsearch = await FAISS.afrom_texts(
        texts, FakeEmbeddings(), relevance_score_fn=lambda _: 2.0
    )
    with pytest.warns(Warning, match="scores must be between"):
        await docsearch.asimilarity_search_with_relevance_scores("foo", k=1)


@pytest.mark.requires("faiss")
def test_missing_normalize_score_fn() -> None:
    """Test doesn't perform similarity search without a valid distance strategy."""
    texts = ["foo", "bar", "baz"]
    faiss_instance = FAISS.from_texts(texts, FakeEmbeddings(), distance_strategy="fake")
    with pytest.raises(ValueError):
        faiss_instance.similarity_search_with_relevance_scores("foo", k=2)


@pytest.mark.requires("faiss")
async def test_async_missing_normalize_score_fn() -> None:
    """Test doesn't perform similarity search without a valid distance strategy."""
    texts = ["foo", "bar", "baz"]
    faiss_instance = await FAISS.afrom_texts(
        texts, FakeEmbeddings(), distance_strategy="fake"
    )
    with pytest.raises(ValueError):
        await faiss_instance.asimilarity_search_with_relevance_scores("foo", k=2)


@pytest.mark.requires("faiss")
def test_delete() -> None:
    """Test the similarity search with normalized similarities."""
    ids = ["a", "b", "c"]
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings(), ids=ids)
    docsearch.delete(ids[1:2])

    result = docsearch.similarity_search("bar", k=2)
    assert sorted([d.page_content for d in result]) == ["baz", "foo"]
    assert docsearch.index_to_docstore_id == {0: ids[0], 1: ids[2]}


@pytest.mark.requires("faiss")
async def test_async_delete() -> None:
    """Test the similarity search with normalized similarities."""
    ids = ["a", "b", "c"]
    docsearch = await FAISS.afrom_texts(
        ["foo", "bar", "baz"], FakeEmbeddings(), ids=ids
    )
    docsearch.delete(ids[1:2])

    result = await docsearch.asimilarity_search("bar", k=2)
    assert sorted([d.page_content for d in result]) == ["baz", "foo"]
    assert docsearch.index_to_docstore_id == {0: ids[0], 1: ids[2]}
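The tests above cover the FAISS wrapper's MMR search, metadata filters (a scalar or a list of allowed values per key), save/load round-trips, and deletion. For orientation, a minimal usage sketch outside pytest; it assumes the faiss package is installed and uses the FakeEmbeddings helper from langchain_community.embeddings purely for illustration (any Embeddings implementation would do):

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

emb = FakeEmbeddings(size=4)  # deterministic dummy embeddings, for demonstration only
vs = FAISS.from_texts(
    ["foo", "bar", "baz"], emb, metadatas=[{"page": i} for i in range(3)]
)
# MMR trades relevance against diversity; lambda_mult=1.0 means pure relevance.
mmr_docs = vs.max_marginal_relevance_search("foo", k=2, fetch_k=10, lambda_mult=0.5)
# Metadata filters accept either a single value or a list of allowed values per key.
filtered = vs.similarity_search("foo", k=2, filter={"page": [0, 1]})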
@@ -1,13 +0,0 @@
from langchain_community import vectorstores
from langchain_core.vectorstores import VectorStore


def test_all_imports() -> None:
    """Simple test to make sure all things can be imported."""
    for cls in vectorstores.__all__:
        if cls not in [
            "AlibabaCloudOpenSearchSettings",
            "ClickhouseSettings",
            "MyScaleSettings",
        ]:
            assert issubclass(getattr(vectorstores, cls), VectorStore)
@@ -1,144 +0,0 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional

from langchain_core.load.mapping import SERIALIZABLE_MAPPING
from langchain_core.load.serializable import Serializable

DEFAULT_NAMESPACES = ["langchain", "langchain_core", "langchain_community"]


class Reviver:
    """Reviver for JSON objects."""

    def __init__(
        self,
        secrets_map: Optional[Dict[str, str]] = None,
        valid_namespaces: Optional[List[str]] = None,
    ) -> None:
        self.secrets_map = secrets_map or dict()
        # By default only support langchain, but user can pass in additional namespaces
        self.valid_namespaces = (
            [*DEFAULT_NAMESPACES, *valid_namespaces]
            if valid_namespaces
            else DEFAULT_NAMESPACES
        )

    def __call__(self, value: Dict[str, Any]) -> Any:
        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "secret"
            and value.get("id", None) is not None
        ):
            [key] = value["id"]
            if key in self.secrets_map:
                return self.secrets_map[key]
            else:
                if key in os.environ and os.environ[key]:
                    return os.environ[key]
                raise KeyError(f'Missing key "{key}" in load(secrets_map)')

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "not_implemented"
            and value.get("id", None) is not None
        ):
            raise NotImplementedError(
                "Trying to load an object that doesn't implement "
                f"serialization: {value}"
            )

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "constructor"
            and value.get("id", None) is not None
        ):
            [*namespace, name] = value["id"]

            if namespace[0] not in self.valid_namespaces:
                raise ValueError(f"Invalid namespace: {value}")

            # The root namespace "langchain" is not a valid identifier.
            if len(namespace) == 1 and namespace[0] == "langchain":
                raise ValueError(f"Invalid namespace: {value}")

            # Get the importable path
            key = tuple(namespace + [name])
            if key not in SERIALIZABLE_MAPPING:
                raise ValueError(
                    "Trying to deserialize something that cannot "
                    "be deserialized in current version of langchain-core: "
                    f"{key}"
                )
            import_path = SERIALIZABLE_MAPPING[key]
            # Split into module and name
            import_dir, import_obj = import_path[:-1], import_path[-1]
            # Import module
            mod = importlib.import_module(".".join(import_dir))
            # Import class
            cls = getattr(mod, import_obj)

            # The class must be a subclass of Serializable.
            if not issubclass(cls, Serializable):
                raise ValueError(f"Invalid namespace: {value}")

            # We don't need to recurse on kwargs
            # as json.loads will do that for us.
            kwargs = value.get("kwargs", dict())
            return cls(**kwargs)

        return value


def loads(
    text: str,
    *,
    secrets_map: Optional[Dict[str, str]] = None,
    valid_namespaces: Optional[List[str]] = None,
) -> Any:
    """Revive a LangChain class from a JSON string.
    Equivalent to `load(json.loads(text))`.

    Args:
        text: The string to load.
        secrets_map: A map of secrets to load.
        valid_namespaces: A list of additional namespaces (modules)
            to allow to be deserialized.

    Returns:
        Revived LangChain objects.
    """
    return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))


def load(
    obj: Any,
    *,
    secrets_map: Optional[Dict[str, str]] = None,
    valid_namespaces: Optional[List[str]] = None,
) -> Any:
    """Revive a LangChain class from a JSON object. Use this if you already
    have a parsed JSON object, eg. from `json.load` or `orjson.loads`.

    Args:
        obj: The object to load.
        secrets_map: A map of secrets to load.
        valid_namespaces: A list of additional namespaces (modules)
            to allow to be deserialized.

    Returns:
        Revived LangChain objects.
    """
    reviver = Reviver(secrets_map, valid_namespaces)

    def _load(obj: Any) -> Any:
        if isinstance(obj, dict):
            # Need to revive leaf nodes before reviving this node
            loaded_obj = {k: _load(v) for k, v in obj.items()}
            return reviver(loaded_obj)
        if isinstance(obj, list):
            return [_load(o) for o in obj]
        return obj

    return _load(obj)
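The Reviver above only reconstructs objects whose id is registered in SERIALIZABLE_MAPPING and substitutes secrets from secrets_map or the environment. A minimal round-trip sketch, assuming `dumps` from langchain_core.load.dump is available and that PromptTemplate is registered in the mapping:

from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
serialized = dumps(prompt)  # JSON string with "lc", "type", "id", and "kwargs" fields
revived = loads(serialized)  # reconstructed by looking the id up in SERIALIZABLE_MAPPING
assert revived == prompt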
@@ -1,49 +0,0 @@
"""
**Utility functions** for LangChain.

These functions do not depend on any other LangChain module.
"""

from langchain_core.utils.env import get_from_dict_or_env, get_from_env
from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
    get_bolded_text,
    get_color_mapping,
    get_colored_text,
    print_text,
)
from langchain_core.utils.loading import try_load_from_hub
from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value
from langchain_core.utils.utils import (
    build_extra_kwargs,
    check_package_version,
    convert_to_secret_str,
    get_pydantic_field_names,
    guard_import,
    mock_now,
    raise_for_status_with_text,
    xor_args,
)

__all__ = [
    "StrictFormatter",
    "check_package_version",
    "convert_to_secret_str",
    "formatter",
    "get_bolded_text",
    "get_color_mapping",
    "get_colored_text",
    "get_pydantic_field_names",
    "guard_import",
    "mock_now",
    "print_text",
    "raise_for_status_with_text",
    "xor_args",
    "try_load_from_hub",
    "build_extra_kwargs",
    "get_from_env",
    "get_from_dict_or_env",
    "stringify_dict",
    "comma_list",
    "stringify_value",
]
@@ -1,45 +0,0 @@
from __future__ import annotations

import os
from typing import Any, Dict, Optional


def env_var_is_set(env_var: str) -> bool:
    """Check if an environment variable is set.

    Args:
        env_var (str): The name of the environment variable.

    Returns:
        bool: True if the environment variable is set, False otherwise.
    """
    return env_var in os.environ and os.environ[env_var] not in (
        "",
        "0",
        "false",
        "False",
    )


def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Get a value from a dictionary or an environment variable."""
    if key in data and data[key]:
        return data[key]
    else:
        return get_from_env(key, env_key, default=default)


def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Get a value from a dictionary or an environment variable."""
    if env_key in os.environ and os.environ[env_key]:
        return os.environ[env_key]
    elif default is not None:
        return default
    else:
        raise ValueError(
            f"Did not find {key}, please add an environment variable"
            f" `{env_key}` which contains it, or pass"
            f" `{key}` as a named parameter."
        )
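A short sketch of the lookup order these helpers implement (an explicit value wins, then the environment variable, then the default); the key names below are hypothetical:

import os

from langchain_core.utils.env import get_from_dict_or_env

os.environ["MY_SERVICE_API_KEY"] = "from-env"
# A non-empty dict value takes precedence over the environment.
assert (
    get_from_dict_or_env({"api_key": "explicit"}, "api_key", "MY_SERVICE_API_KEY")
    == "explicit"
)
# A missing/empty dict value falls back to the environment variable, then to `default`.
assert get_from_dict_or_env({}, "api_key", "MY_SERVICE_API_KEY") == "from-env"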
@@ -1,28 +0,0 @@
from langchain_core.utils import __all__

EXPECTED_ALL = [
    "StrictFormatter",
    "check_package_version",
    "convert_to_secret_str",
    "formatter",
    "get_bolded_text",
    "get_color_mapping",
    "get_colored_text",
    "get_pydantic_field_names",
    "guard_import",
    "mock_now",
    "print_text",
    "raise_for_status_with_text",
    "xor_args",
    "try_load_from_hub",
    "build_extra_kwargs",
    "get_from_dict_or_env",
    "get_from_env",
    "stringify_dict",
    "comma_list",
    "stringify_value"
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
@@ -1,83 +0,0 @@
"""**Callback handlers** allow listening to events in LangChain.

**Class hierarchy:**

.. code-block::

    BaseCallbackHandler --> <name>CallbackHandler  # Example: AimCallbackHandler
"""

from langchain_core.callbacks import StdOutCallbackHandler, StreamingStdOutCallbackHandler
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.context import (
    collect_runs,
    tracing_enabled,
    tracing_v2_enabled,
)

from langchain_community.callbacks.aim_callback import AimCallbackHandler
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
from langchain_community.callbacks.context_callback import ContextCallbackHandler
from langchain.callbacks.file import FileCallbackHandler
from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
from langchain_community.callbacks.labelstudio_callback import LabelStudioCallbackHandler
from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.promptlayer_callback import PromptLayerCallbackHandler
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)
from langchain_community.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler
from langchain_community.callbacks.wandb_callback import WandbCallbackHandler
from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler

from langchain_community.callbacks.manager import (
    get_openai_callback,
    wandb_tracing_enabled,
)


__all__ = [
    "AimCallbackHandler",
    "ArgillaCallbackHandler",
    "ArizeCallbackHandler",
    "PromptLayerCallbackHandler",
    "ArthurCallbackHandler",
    "ClearMLCallbackHandler",
    "CometCallbackHandler",
    "ContextCallbackHandler",
    "FileCallbackHandler",
    "HumanApprovalCallbackHandler",
    "InfinoCallbackHandler",
    "MlflowCallbackHandler",
    "LLMonitorCallbackHandler",
    "OpenAICallbackHandler",
    "StdOutCallbackHandler",
    "AsyncIteratorCallbackHandler",
    "StreamingStdOutCallbackHandler",
    "FinalStreamingStdOutCallbackHandler",
    "LLMThoughtLabeler",
    "LangChainTracer",
    "StreamlitCallbackHandler",
    "WandbCallbackHandler",
    "WhyLabsCallbackHandler",
    "get_openai_callback",
    "tracing_enabled",
    "tracing_v2_enabled",
    "collect_runs",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",
    "LabelStudioCallbackHandler",
    "TrubricsCallbackHandler",
]
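Of the utilities re-exported here, get_openai_callback is the one most often used directly. A minimal sketch of the context-manager pattern it follows, with the actual OpenAI-backed chain or LLM call elided:

from langchain_community.callbacks import get_openai_callback

with get_openai_callback() as cb:
    # ... invoke an OpenAI-backed LLM or chain here ...
    pass

# The handler accumulates token counts and estimated cost for calls made inside the block.
print(cb.total_tokens, cb.total_cost)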
@@ -1,68 +0,0 @@
from __future__ import annotations

from langchain_core.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForChainGroup,
    AsyncCallbackManagerForChainRun,
    AsyncCallbackManagerForLLMRun,
    AsyncCallbackManagerForRetrieverRun,
    AsyncCallbackManagerForToolRun,
    AsyncParentRunManager,
    AsyncRunManager,
    BaseRunManager,
    CallbackManager,
    CallbackManagerForChainGroup,
    CallbackManagerForChainRun,
    CallbackManagerForLLMRun,
    CallbackManagerForRetrieverRun,
    CallbackManagerForToolRun,
    Callbacks,
    ParentRunManager,
    RunManager,
    ahandle_event,
    atrace_as_chain_group,
    handle_event,
    trace_as_chain_group,
)
from langchain_core.tracers.context import (
    collect_runs,
    tracing_enabled,
    tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
from langchain_community.callbacks.manager import (
    get_openai_callback,
    wandb_tracing_enabled,
)


__all__ = [
    "BaseRunManager",
    "RunManager",
    "ParentRunManager",
    "AsyncRunManager",
    "AsyncParentRunManager",
    "CallbackManagerForLLMRun",
    "AsyncCallbackManagerForLLMRun",
    "CallbackManagerForChainRun",
    "AsyncCallbackManagerForChainRun",
    "CallbackManagerForToolRun",
    "AsyncCallbackManagerForToolRun",
    "CallbackManagerForRetrieverRun",
    "AsyncCallbackManagerForRetrieverRun",
    "CallbackManager",
    "CallbackManagerForChainGroup",
    "AsyncCallbackManager",
    "AsyncCallbackManagerForChainGroup",
    "tracing_enabled",
    "tracing_v2_enabled",
    "collect_runs",
    "atrace_as_chain_group",
    "trace_as_chain_group",
    "handle_event",
    "ahandle_event",
    "Callbacks",
    "env_var_is_set",
    "get_openai_callback",
    "wandb_tracing_enabled",
]
@@ -1,36 +0,0 @@
from langchain.callbacks.manager import __all__

EXPECTED_ALL = [
    "BaseRunManager",
    "RunManager",
    "ParentRunManager",
    "AsyncRunManager",
    "AsyncParentRunManager",
    "CallbackManagerForLLMRun",
    "AsyncCallbackManagerForLLMRun",
    "CallbackManagerForChainRun",
    "AsyncCallbackManagerForChainRun",
    "CallbackManagerForToolRun",
    "AsyncCallbackManagerForToolRun",
    "CallbackManagerForRetrieverRun",
    "AsyncCallbackManagerForRetrieverRun",
    "CallbackManager",
    "CallbackManagerForChainGroup",
    "AsyncCallbackManager",
    "AsyncCallbackManagerForChainGroup",
    "tracing_enabled",
    "tracing_v2_enabled",
    "collect_runs",
    "atrace_as_chain_group",
    "trace_as_chain_group",
    "handle_event",
    "ahandle_event",
    "env_var_is_set",
    "Callbacks",
    "get_openai_callback",
    "wandb_tracing_enabled",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
@@ -1,75 +0,0 @@
"""Test LLM chain."""
from tempfile import TemporaryDirectory
from typing import Dict, List, Union
from unittest.mock import patch

import pytest
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate

from langchain.chains.llm import LLMChain
from tests.unit_tests.llms.fake_llm import FakeLLM


class FakeOutputParser(BaseOutputParser):
    """Fake output parser class for testing."""

    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        """Parse by splitting."""
        return text.split()


@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Fake LLM chain for testing purposes."""
    prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:")
    return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")


@patch(
    "langchain_community.llms.loading.get_type_to_cls_dict",
    lambda: {"fake": lambda: FakeLLM},
)
def test_serialization(fake_llm_chain: LLMChain) -> None:
    """Test serialization."""
    from langchain.chains.loading import load_chain

    with TemporaryDirectory() as temp_dir:
        file = temp_dir + "/llm.json"
        fake_llm_chain.save(file)
        loaded_chain = load_chain(file)
        assert loaded_chain == fake_llm_chain


def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Test error is raised if inputs are missing."""
    with pytest.raises(ValueError):
        fake_llm_chain({"foo": "bar"})


def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """Test valid call of LLM chain."""
    output = fake_llm_chain({"bar": "baz"})
    assert output == {"bar": "baz", "text1": "foo"}

    # Test with stop words.
    output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
    # Response should be `bar` now.
    assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}


def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """Test predict method works."""
    output = fake_llm_chain.predict(bar="baz")
    assert output == "foo"


def test_predict_and_parse() -> None:
    """Test parsing ability."""
    prompt = PromptTemplate(
        input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
    )
    llm = FakeLLM(queries={"foo": "foo bar"})
    chain = LLMChain(prompt=prompt, llm=llm)
    output = chain.predict_and_parse(foo="foo")
    assert output == ["foo", "bar"]
@@ -1,57 +0,0 @@
import importlib
import pkgutil

from langchain_core.load.mapping import SERIALIZABLE_MAPPING


def import_all_modules(package_name: str) -> dict:
    package = importlib.import_module(package_name)
    classes: dict = {}

    for attribute_name in dir(package):
        attribute = getattr(package, attribute_name)
        if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
            if (
                isinstance(attribute.is_lc_serializable(), bool)  # type: ignore
                and attribute.is_lc_serializable()  # type: ignore
            ):
                key = tuple(attribute.lc_id())  # type: ignore
                value = tuple(attribute.__module__.split(".") + [attribute.__name__])
                if key in classes and classes[key] != value:
                    raise ValueError
                classes[key] = value
    if hasattr(package, "__path__"):
        for loader, module_name, is_pkg in pkgutil.walk_packages(
            package.__path__, package_name + "."
        ):
            if module_name not in (
                "langchain.chains.llm_bash",
                "langchain.chains.llm_symbolic_math",
                "langchain.tools.python",
                "langchain.vectorstores._pgvector_data_models",
                # TODO: why does this error?
                "langchain.agents.agent_toolkits.openapi.planner",
            ):
                importlib.import_module(module_name)
                new_classes = import_all_modules(module_name)
                for k, v in new_classes.items():
                    if k in classes and classes[k] != v:
                        raise ValueError
                    classes[k] = v
    return classes


def test_serializable_mapping() -> None:
    serializable_modules = import_all_modules("langchain")
    missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules)
    assert missing == set()
    extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
    assert extra == set()

    for k, import_path in serializable_modules.items():
        import_dir, import_obj = import_path[:-1], import_path[-1]
        # Import module
        mod = importlib.import_module(".".join(import_dir))
        # Import class
        cls = getattr(mod, import_obj)
        assert list(k) == cls.lc_id()
@@ -1,113 +0,0 @@
"""A unit test meant to catch accidental introduction of non-optional dependencies."""
from pathlib import Path
from typing import Any, Dict, Mapping

import pytest
import toml

HERE = Path(__file__).parent

PYPROJECT_TOML = HERE / "../../pyproject.toml"


@pytest.fixture()
def poetry_conf() -> Dict[str, Any]:
    """Load the pyproject.toml file."""
    with open(PYPROJECT_TOML) as f:
        return toml.load(f)["tool"]["poetry"]


def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
    """A test that checks if a new non-optional dependency is being introduced.

    If this test is triggered, it means that a contributor is trying to introduce a new
    required dependency. This should be avoided in most situations.
    """
    # Get the dependencies from the [tool.poetry.dependencies] section
    dependencies = poetry_conf["dependencies"]

    is_required = {
        package_name: isinstance(requirements, str)
        or not requirements.get("optional", False)
        for package_name, requirements in dependencies.items()
    }
    required_dependencies = [
        package_name for package_name, required in is_required.items() if required
    ]

    assert sorted(required_dependencies) == sorted(
        [
            "PyYAML",
            "SQLAlchemy",
            "aiohttp",
            "async-timeout",
            "dataclasses-json",
            "jsonpatch",
            "langchain-core",
            "langsmith",
            "numpy",
            "pydantic",
            "python",
            "requests",
            "tenacity",
            "langchain-community",
        ]
    )

    unrequired_dependencies = [
        package_name for package_name, required in is_required.items() if not required
    ]
    in_extras = [dep for group in poetry_conf["extras"].values() for dep in group]
    assert set(unrequired_dependencies) == set(in_extras)


def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
    """Check if someone is attempting to add additional test dependencies.

    Only dependencies associated with test running infrastructure should be added
    to the test group; e.g., pytest, pytest-cov etc.

    Examples of dependencies that should NOT be included: boto3, azure, postgres, etc.
    """

    test_group_deps = sorted(poetry_conf["group"]["test"]["dependencies"])

    assert test_group_deps == sorted(
        [
            "duckdb-engine",
            "freezegun",
            "langchain-core",
            "lark",
            "pandas",
            "pytest",
            "pytest-asyncio",
            "pytest-cov",
            "pytest-dotenv",
            "pytest-mock",
            "pytest-socket",
            "pytest-watcher",
            "responses",
            "syrupy",
            "requests-mock",
        ]
    )


def test_imports() -> None:
    """Test that you can import all top level things okay."""
    from langchain_core.prompts import BasePromptTemplate  # noqa: F401

    from langchain.agents import OpenAIFunctionsAgent  # noqa: F401
    from langchain.callbacks import OpenAICallbackHandler  # noqa: F401
    from langchain.chains import LLMChain  # noqa: F401
    from langchain.chat_models import ChatOpenAI  # noqa: F401
    from langchain.document_loaders import BSHTMLLoader  # noqa: F401
    from langchain.embeddings import OpenAIEmbeddings  # noqa: F401
    from langchain.llms import OpenAI  # noqa: F401
    from langchain.retrievers import VespaRetriever  # noqa: F401
    from langchain.tools import DuckDuckGoSearchResults  # noqa: F401
    from langchain.utilities import (
        SearchApiAPIWrapper,  # noqa: F401
        SerpAPIWrapper,  # noqa: F401
    )
    from langchain.vectorstores import FAISS  # noqa: F401
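The dependency policy enforced above can be checked locally with the same toml-parsing approach; a small sketch (point it at whichever pyproject.toml you want to inspect):

import toml

with open("pyproject.toml") as f:
    deps = toml.load(f)["tool"]["poetry"]["dependencies"]

# Optional dependencies are tables with `optional = true`; plain version strings are required.
optional = sorted(
    name
    for name, req in deps.items()
    if isinstance(req, dict) and req.get("optional", False)
)
print(optional)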
@@ -1,313 +0,0 @@
#!/bin/bash

cd libs

# cleanup anything existing
git checkout master -- langchain/{langchain,tests}
git checkout master -- core/{langchain_core,tests}
git checkout master -- experimental/{langchain_experimental,tests}
rm -rf community/{langchain_community,tests}

# make new dirs
mkdir -p community/langchain_community
touch community/langchain_community/__init__.py
touch community/langchain_community/py.typed
touch community/README.md
mkdir -p community/tests
touch community/tests/__init__.py
mkdir community/tests/unit_tests
touch community/tests/unit_tests/__init__.py
mkdir community/tests/integration_tests/
touch community/tests/integration_tests/__init__.py
mkdir -p community/langchain_community/utils
touch community/langchain_community/utils/__init__.py
mkdir -p community/tests/unit_tests/utils
touch community/tests/unit_tests/utils/__init__.py
mkdir -p community/langchain_community/indexes
touch community/langchain_community/indexes/__init__.py
mkdir community/tests/unit_tests/indexes
touch community/tests/unit_tests/indexes/__init__.py

# import core stuff from core
cd langchain

git grep -l 'from langchain.pydantic_v1' | xargs sed -i '' 's/from langchain.pydantic_v1/from langchain_core.pydantic_v1/g'
git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g'
git grep -l 'from langchain.chat_models.base' | xargs sed -i '' 's/from langchain.chat_models.base/from langchain_core.language_models.chat_models/g'
git grep -l 'from langchain.llms.base' | xargs sed -i '' 's/from langchain.llms.base/from langchain_core.language_models.llms/g'
git grep -l 'from langchain.embeddings.base' | xargs sed -i '' 's/from langchain.embeddings.base/from langchain_core.embeddings/g'
git grep -l 'from langchain.vectorstores.base' | xargs sed -i '' 's/from langchain.vectorstores.base/from langchain_core.vectorstores/g'
git grep -l 'from langchain.agents.tools' | xargs sed -i '' 's/from langchain.agents.tools/from langchain_core.tools/g'
git grep -l 'from langchain.schema.output' | xargs sed -i '' 's/from langchain.schema.output/from langchain_core.outputs/g'
git grep -l 'from langchain.schema.messages' | xargs sed -i '' 's/from langchain.schema.messages/from langchain_core.messages/g'
git grep -l 'from langchain.schema.embeddings' | xargs sed -i '' 's/from langchain.schema.embeddings/from langchain_core.embeddings/g'

# mv stuff to community
cd ..

mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community

mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
mv langchain/langchain/indexes/base.py community/langchain_community/indexes
mv langchain/langchain/indexes/_sql_record_manager.py community/langchain_community/indexes
mv langchain/langchain/utils/{math,openai,openai_functions}.py community/langchain_community/utils

# mv stuff to core
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py

# mv unit tests to community
mv langchain/tests/unit_tests/chat_loaders community/tests/unit_tests
mv langchain/tests/unit_tests/document_loaders community/tests/unit_tests
mv langchain/tests/unit_tests/docstore community/tests/unit_tests
mv langchain/tests/unit_tests/document_transformers community/tests/unit_tests
mv langchain/tests/unit_tests/embeddings community/tests/unit_tests
mv langchain/tests/unit_tests/graphs community/tests/unit_tests
mv langchain/tests/unit_tests/llms community/tests/unit_tests
mv langchain/tests/unit_tests/chat_models community/tests/unit_tests
mv langchain/tests/unit_tests/memory/chat_message_histories community/tests/unit_tests
mv langchain/tests/unit_tests/storage community/tests/unit_tests
mv langchain/tests/unit_tests/tools community/tests/unit_tests
mv langchain/tests/unit_tests/utilities community/tests/unit_tests
mv langchain/tests/unit_tests/vectorstores community/tests/unit_tests
mv langchain/tests/unit_tests/retrievers community/tests/unit_tests
mv langchain/tests/unit_tests/callbacks community/tests/unit_tests
mv langchain/tests/unit_tests/indexes/test_sql_record_manager.py community/tests/unit_tests/indexes
mv langchain/tests/unit_tests/utils/test_math.py community/tests/unit_tests/utils

# cp some test helpers back to langchain
mkdir -p langchain/tests/unit_tests/llms
cp {community,langchain}/tests/unit_tests/llms/fake_llm.py
cp {community,langchain}/tests/unit_tests/llms/fake_chat_model.py
mkdir -p langchain/tests/unit_tests/callbacks
cp {community,langchain}/tests/unit_tests/callbacks/fake_callback_handler.py

# mv unit tests to core
mv langchain/tests/unit_tests/utils/test_json_schema.py core/tests/unit_tests/utils
mv langchain/tests/unit_tests/utils/test_html.py core/tests/unit_tests/utils

# mv integration tests to community
mv langchain/tests/integration_tests/document_loaders community/tests/integration_tests
mv langchain/tests/integration_tests/embeddings community/tests/integration_tests
mv langchain/tests/integration_tests/graphs community/tests/integration_tests
mv langchain/tests/integration_tests/llms community/tests/integration_tests
mv langchain/tests/integration_tests/chat_models community/tests/integration_tests
mv langchain/tests/integration_tests/memory/chat_message_histories community/tests/integration_tests
mv langchain/tests/integration_tests/storage community/tests/integration_tests
mv langchain/tests/integration_tests/tools community/tests/integration_tests
mv langchain/tests/integration_tests/utilities community/tests/integration_tests
mv langchain/tests/integration_tests/vectorstores community/tests/integration_tests
mv langchain/tests/integration_tests/retrievers community/tests/integration_tests
mv langchain/tests/integration_tests/adapters community/tests/integration_tests
mv langchain/tests/integration_tests/callbacks community/tests/integration_tests
mv langchain/tests/integration_tests/{test_kuzu,test_nebulagraph}.py community/tests/integration_tests/graphs
touch community/tests/integration_tests/{chat_message_histories,tools}/__init__.py

# import new core stuff from core (everywhere)
git grep -l 'from langchain.utils.json_schema' | xargs sed -i '' 's/from langchain.utils.json_schema/from langchain_core.utils.json_schema/g'
git grep -l 'from langchain.utils.html' | xargs sed -i '' 's/from langchain.utils.html/from langchain_core.utils.html/g'
git grep -l 'from langchain.utils.strings' | xargs sed -i '' 's/from langchain.utils.strings/from langchain_core.utils.strings/g'
git grep -l 'from langchain.utils.env' | xargs sed -i '' 's/from langchain.utils.env/from langchain_core.utils.env/g'

git add community
cd community

# import core stuff from core
git grep -l 'from langchain.pydantic_v1' | xargs sed -i '' 's/from langchain.pydantic_v1/from langchain_core.pydantic_v1/g'
git grep -l 'from langchain.callbacks.base' | xargs sed -i '' 's/from langchain.callbacks.base/from langchain_core.callbacks/g'
git grep -l 'from langchain.callbacks.stdout' | xargs sed -i '' 's/from langchain.callbacks.stdout/from langchain_core.callbacks/g'
git grep -l 'from langchain.callbacks.streaming_stdout' | xargs sed -i '' 's/from langchain.callbacks.streaming_stdout/from langchain_core.callbacks/g'
git grep -l 'from langchain.callbacks.manager' | xargs sed -i '' 's/from langchain.callbacks.manager/from langchain_core.callbacks/g'
git grep -l 'from langchain.callbacks.tracers.base' | xargs sed -i '' 's/from langchain.callbacks.tracers.base/from langchain_core.tracers/g'
git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g'
git grep -l 'from langchain.agents.tools' | xargs sed -i '' 's/from langchain.agents.tools/from langchain_core.tools/g'
git grep -l 'from langchain.schema.output' | xargs sed -i '' 's/from langchain.schema.output/from langchain_core.outputs/g'
git grep -l 'from langchain.schema.messages' | xargs sed -i '' 's/from langchain.schema.messages/from langchain_core.messages/g'
git grep -l 'from langchain.schema import BaseRetriever' | xargs sed -i '' 's/from langchain.schema\ import\ BaseRetriever/from langchain_core.retrievers import BaseRetriever/g'
git grep -l 'from langchain.schema import Document' | xargs sed -i '' 's/from langchain.schema\ import\ Document/from langchain_core.documents import Document/g'

# import openai stuff from openai
git grep -l 'from langchain.utils.math' | xargs sed -i '' 's/from langchain.utils.math/from langchain_community.utils.math/g'
git grep -l 'from langchain.utils.openai_functions' | xargs sed -i '' 's/from langchain.utils.openai_functions/from langchain_community.utils.openai_functions/g'
git grep -l 'from langchain.utils.openai' | xargs sed -i '' 's/from langchain.utils.openai/from langchain_community.utils.openai/g'
git grep -l 'from langchain.utils' | xargs sed -i '' 's/from langchain.utils/from langchain_core.utils/g'
git grep -l 'from langchain\.' | xargs sed -i '' 's/from langchain\./from langchain_community./g'
git grep -l 'from langchain_community.memory.chat_message_histories' | xargs sed -i '' 's/from langchain_community.memory.chat_message_histories/from langchain_community.chat_message_histories/g'
git grep -l 'from langchain_community.agents.agent_toolkits' | xargs sed -i '' 's/from langchain_community.agents.agent_toolkits/from langchain_community.agent_toolkits/g'

sed -i '' 's/from\ langchain.chat_models\ import\ ChatOpenAI/from langchain_openai.chat_models import ChatOpenAI/g' langchain_community/chat_models/promptlayer_openai.py

git grep -l 'from langchain_community\.text_splitter' | xargs sed -i '' 's/from langchain_community\.text_splitter/from langchain.text_splitter/g'
git grep -l 'from langchain_community\.chains' | xargs sed -i '' 's/from langchain_community\.chains/from langchain.chains/g'
git grep -l 'from langchain_community\.agents' | xargs sed -i '' 's/from langchain_community\.agents/from langchain.agents/g'
git grep -l 'from langchain_community\.memory' | xargs sed -i '' 's/from langchain_community\.memory/from langchain.memory/g'
git grep -l 'langchain\.__version__' | xargs sed -i '' 's/langchain\.__version__/langchain_community.__version__/g'
git grep -l 'langchain\.document_loaders' | xargs sed -i '' 's/langchain\.document_loaders/langchain_community.document_loaders/g'
git grep -l 'langchain\.callbacks' | xargs sed -i '' 's/langchain\.callbacks/langchain_community.callbacks/g'
git grep -l 'langchain\.tools' | xargs sed -i '' 's/langchain\.tools/langchain_community.tools/g'
git grep -l 'langchain\.llms' | xargs sed -i '' 's/langchain\.llms/langchain_community.llms/g'
git grep -l 'import langchain$' | xargs sed -i '' 's/import\ langchain$/import\ langchain_community/g'
git grep -l 'from\ langchain\ ' | xargs sed -i '' 's/from\ langchain\ /from\ langchain_community\ /g'
git grep -l 'langchain_core.language_models.llmsten' | xargs sed -i '' 's/langchain_core.language_models.llmsten/langchain_community.llms.baseten/g'

# update all moved langchain files to re-export classes and functions
cd ../langchain
git checkout master -- langchain

python ../../.scripts/community_split/update_imports.py langchain/chat_loaders langchain_community.chat_loaders
python ../../.scripts/community_split/update_imports.py langchain/callbacks langchain_community.callbacks
python ../../.scripts/community_split/update_imports.py langchain/document_loaders langchain_community.document_loaders
python ../../.scripts/community_split/update_imports.py langchain/docstore langchain_community.docstore
python ../../.scripts/community_split/update_imports.py langchain/document_transformers langchain_community.document_transformers
python ../../.scripts/community_split/update_imports.py langchain/embeddings langchain_community.embeddings
python ../../.scripts/community_split/update_imports.py langchain/graphs langchain_community.graphs
python ../../.scripts/community_split/update_imports.py langchain/llms langchain_community.llms
python ../../.scripts/community_split/update_imports.py langchain/chat_models langchain_community.chat_models
python ../../.scripts/community_split/update_imports.py langchain/memory/chat_message_histories langchain_community.chat_message_histories
python ../../.scripts/community_split/update_imports.py langchain/storage langchain_community.storage
python ../../.scripts/community_split/update_imports.py langchain/tools langchain_community.tools
python ../../.scripts/community_split/update_imports.py langchain/utilities langchain_community.utilities
python ../../.scripts/community_split/update_imports.py langchain/vectorstores langchain_community.vectorstores
python ../../.scripts/community_split/update_imports.py langchain/retrievers langchain_community.retrievers
python ../../.scripts/community_split/update_imports.py langchain/adapters langchain_community.adapters
python ../../.scripts/community_split/update_imports.py langchain/agents/agent_toolkits langchain_community.agent_toolkits
python ../../.scripts/community_split/update_imports.py langchain/cache.py langchain_community.cache
python ../../.scripts/community_split/update_imports.py langchain/utils/math.py langchain_community.utils.math
python ../../.scripts/community_split/update_imports.py langchain/utils/json_schema.py langchain_core.utils.json_schema
python ../../.scripts/community_split/update_imports.py langchain/utils/html.py langchain_core.utils.html
python ../../.scripts/community_split/update_imports.py langchain/utils/env.py langchain_core.utils.env
python ../../.scripts/community_split/update_imports.py langchain/utils/strings.py langchain_core.utils.strings
python ../../.scripts/community_split/update_imports.py langchain/utils/openai.py langchain_community.utils.openai
python ../../.scripts/community_split/update_imports.py langchain/utils/openai_functions.py langchain_community.utils.openai_functions

# update core and openai imports
git grep -l 'from langchain.llms.base ' | xargs sed -i '' 's/from langchain.llms.base /from langchain_core.language_models.llms /g'
git grep -l 'from langchain.chat_models.base ' | xargs sed -i '' 's/from langchain.chat_models.base /from langchain_core.language_models.chat_models /g'
git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tools.base/from langchain_core.tools/g'

git grep -l 'langchain_core.language_models.llmsten' | xargs sed -i '' 's/langchain_core.language_models.llmsten/langchain_community.llms.baseten/g'

cd ..

mv community/langchain_community/utilities/loading.py langchain/langchain/utilities
mv community/langchain_community/utilities/asyncio.py langchain/langchain/utilities

#git add partners
git add core

# rm files from community that just export core classes
rm community/langchain_community/{chat_models,llms,tools,embeddings,vectorstores,callbacks}/base.py
rm community/tests/unit_tests/{chat_models,llms,tools,callbacks}/test_base.py
rm community/tests/unit_tests/callbacks/test_manager.py
rm community/langchain_community/callbacks/{stdout,streaming_stdout}.py
rm community/langchain_community/callbacks/tracers/{base,evaluation,langchain,langchain_v1,log_stream,root_listeners,run_collector,schemas,stdout}.py

# keep export tests in langchain
git checkout master -- langchain/tests/unit_tests/{chat_models,llms,tools,callbacks,document_loaders}/test_base.py
git checkout master -- langchain/tests/unit_tests/{callbacks,docstore,document_loaders,document_transformers,embeddings,graphs,llms,chat_models,storage,tools,utilities,vectorstores}/test_imports.py
|
||||
git checkout master -- langchain/tests/unit_tests/callbacks/test_manager.py
|
||||
git checkout master -- langchain/tests/unit_tests/document_loaders/blob_loaders/test_public_api.py
|
||||
git checkout master -- langchain/tests/unit_tests/document_loaders/parsers/test_public_api.py
|
||||
git checkout master -- langchain/tests/unit_tests/vectorstores/test_public_api.py
|
||||
git checkout master -- langchain/tests/unit_tests/schema
|
||||
|
||||
# keep some non-integration stuff in langchain. rm from community and add back to langchain
|
||||
rm community/langchain_community/retrievers/{multi_query,multi_vector,contextual_compression,ensemble,merger_retriever,parent_document_retriever,re_phraser,web_research,time_weighted_retriever}.py
|
||||
rm -r community/langchain_community/retrievers/{self_query,document_compressors}
|
||||
rm community/tests/unit_tests/retrievers/test_{ensemble,multi_query,multi_vector,parent_document,time_weighted_retriever,web_research}.py
|
||||
rm community/tests/integration_tests/retrievers/test_{contextual_compression,merger_retriever}.py
|
||||
rm -r community/tests/unit_tests/retrievers/{self_query,document_compressors}
|
||||
rm -r community/tests/integration_tests/retrievers/document_compressors
|
||||
|
||||
rm community/langchain_community/agent_toolkits/{pandas,python,spark}/__init__.py
|
||||
rm community/langchain_community/tools/python/__init__.py
|
||||
|
||||
rm -r community/langchain_community/agent_toolkits/conversational_retrieval/
|
||||
rm -r community/langchain_community/agent_toolkits/vectorstore/
|
||||
rm community/langchain_community/callbacks/tracers/logging.py
|
||||
rm community/langchain_community/callbacks/{file,streaming_aiter_final_only,streaming_aiter,streaming_stdout_final_only}.py
|
||||
rm community/langchain_community/embeddings/cache.py
|
||||
rm community/langchain_community/storage/{encoder_backed,file_system,in_memory,_lc_store}.py
|
||||
rm community/langchain_community/tools/retriever.py
|
||||
rm community/tests/unit_tests/callbacks/tracers/test_logging.py
|
||||
rm community/tests/unit_tests/embeddings/test_caching.py
|
||||
rm community/tests/unit_tests/storage/test_{filesystem,in_memory,lc_store}.py
|
||||
|
||||
git checkout master -- langchain/langchain/retrievers/{multi_query,multi_vector,self_query/base,contextual_compression,ensemble,merger_retriever,parent_document_retriever,re_phraser,web_research,time_weighted_retriever}.py
|
||||
git checkout master -- langchain/langchain/retrievers/{self_query,document_compressors}
|
||||
git checkout master -- langchain/tests/unit_tests/retrievers/test_{ensemble,multi_query,multi_vector,parent_document,time_weighted_retriever,web_research}.py
|
||||
git checkout master -- langchain/tests/integration_tests/retrievers/test_{contextual_compression,merger_retriever}.py
|
||||
git checkout master -- langchain/tests/unit_tests/retrievers/{self_query,document_compressors}
|
||||
git checkout master -- langchain/tests/integration_tests/retrievers/document_compressors
|
||||
touch langchain/tests/unit_tests/{llms,chat_models,tools,callbacks,runnables,document_loaders,docstore,document_transformers,embeddings,graphs,storage,utilities,vectorstores,retrievers}/__init__.py
|
||||
touch langchain/tests/unit_tests/document_loaders/{blob_loaders,parsers}/__init__.py
|
||||
mv {community,langchain}/tests/unit_tests/retrievers/sequential_retriever.py
|
||||
|
||||
git checkout master -- langchain/langchain/agents/agent_toolkits/conversational_retrieval/
|
||||
git checkout master -- langchain/langchain/agents/agent_toolkits/vectorstore/
|
||||
git checkout master -- langchain/langchain/callbacks/tracers/logging.py
|
||||
git checkout master -- langchain/langchain/callbacks/{file,streaming_aiter_final_only,streaming_aiter,streaming_stdout_final_only}.py
|
||||
git checkout master -- langchain/langchain/embeddings/cache.py
|
||||
git checkout master -- langchain/langchain/storage/{encoder_backed,file_system,in_memory,_lc_store}.py
|
||||
git checkout master -- langchain/langchain/tools/retriever.py
|
||||
git checkout master -- langchain/tests/unit_tests/callbacks/tracers/{test_logging,__init__}.py
|
||||
git checkout master -- langchain/tests/unit_tests/embeddings/{__init__,test_caching}.py
|
||||
git checkout master -- langchain/tests/unit_tests/storage/test_{filesystem,in_memory,lc_store}.py
|
||||
git checkout master -- langchain/tests/unit_tests/storage/__init__.py
|
||||
|
||||
# cp lint scripts
|
||||
cp -r core/scripts community
|
||||
|
||||
# cp test helpers
|
||||
cp -r langchain/tests/integration_tests/examples community/tests
|
||||
cp -r langchain/tests/integration_tests/examples community/tests/integration_tests
|
||||
cp -r langchain/tests/unit_tests/examples community/tests/unit_tests
|
||||
cp langchain/tests/unit_tests/conftest.py community/tests/unit_tests
|
||||
cp community/tests/integration_tests/vectorstores/fake_embeddings.py langchain/tests/integration_tests/cache/
|
||||
cp langchain/tests/integration_tests/test_compile.py community/tests/integration_tests
|
||||
|
||||
# cp manually changed files
|
||||
cp -r ../.scripts/community_split/libs/* .
|
||||
|
||||
# mv some tests to integrations
|
||||
mv community/tests/{unit_tests,integration_tests}/document_loaders/test_telegram.py
|
||||
mv community/tests/{unit_tests,integration_tests}/document_loaders/parsers/test_docai.py
|
||||
mv community/tests/{unit_tests,integration_tests}/chat_message_histories/test_streamlit.py
|
||||
|
||||
# fix some final tests
|
||||
git grep -l 'integration_tests\.vectorstores\.fake_embeddings' langchain/tests | xargs sed -i '' 's/integration_tests\.vectorstores\.fake_embeddings/integration_tests.cache.fake_embeddings/g'
|
||||
touch community/langchain_community/agent_toolkits/amadeus/__init__.py
|
||||
|
||||
# format
|
||||
cd core
|
||||
make format
|
||||
cd ../langchain
|
||||
make format
|
||||
cd ../experimental
|
||||
make format
|
||||
cd ../community
|
||||
make format
|
||||
|
||||
cd ..
|
||||
sed -E -i '' '1 s/(.*)/\1\ \ \#\ noqa\:\ E501/g' langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py
|
||||
sed -E -i '' 's/import\ importlib$/import importlib.util/g' experimental/langchain_experimental/prompts/load.py
|
||||
git add .
|
||||
@@ -1,85 +0,0 @@
|
||||
import ast
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ImportTransformer(ast.NodeTransformer):
|
||||
def __init__(self, public_items, module_name):
|
||||
self.public_items = public_items
|
||||
self.module_name = module_name
|
||||
|
||||
def visit_Module(self, node):
|
||||
imports = [
|
||||
ast.ImportFrom(
|
||||
module=self.module_name,
|
||||
names=[ast.alias(name=item, asname=None)],
|
||||
level=0,
|
||||
)
|
||||
for item in self.public_items
|
||||
]
|
||||
all_assignment = ast.Assign(
|
||||
targets=[ast.Name(id="__all__", ctx=ast.Store())],
|
||||
value=ast.List(
|
||||
elts=[ast.Str(s=item) for item in self.public_items], ctx=ast.Load()
|
||||
),
|
||||
)
|
||||
node.body = imports + [all_assignment]
|
||||
return node
|
||||
|
||||
|
||||
def find_public_classes_and_methods(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
node = ast.parse(file.read(), filename=file_path)
|
||||
|
||||
public_items = []
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.ClassDef) or isinstance(item, ast.FunctionDef):
|
||||
public_items.append(item.name)
|
||||
if (
|
||||
isinstance(item, ast.Assign)
|
||||
and hasattr(item.targets[0], "id")
|
||||
and item.targets[0].id not in ("__all__", "logger")
|
||||
):
|
||||
public_items.append(item.targets[0].id)
|
||||
|
||||
return public_items or None
|
||||
|
||||
|
||||
def process_file(file_path, module_name):
|
||||
public_items = find_public_classes_and_methods(file_path)
|
||||
if public_items is None:
|
||||
return
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
contents = file.read()
|
||||
tree = ast.parse(contents, filename=file_path)
|
||||
|
||||
tree = ImportTransformer(public_items, module_name).visit(tree)
|
||||
tree = ast.fix_missing_locations(tree)
|
||||
|
||||
with open(file_path, "w") as file:
|
||||
file.write(ast.unparse(tree))
|
||||
|
||||
|
||||
def process_directory(directory_path, base_module_name):
|
||||
if Path(directory_path).is_file():
|
||||
process_file(directory_path, base_module_name)
|
||||
else:
|
||||
for root, dirs, files in os.walk(directory_path):
|
||||
for filename in files:
|
||||
if filename.endswith(".py") and not filename.startswith("_"):
|
||||
file_path = os.path.join(root, filename)
|
||||
relative_path = os.path.relpath(file_path, directory_path)
|
||||
module_name = f"{base_module_name}.{os.path.splitext(relative_path)[0].replace(os.sep, '.')}"
|
||||
process_file(file_path, module_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage: python script_name.py <directory_path> <base_module_name>")
|
||||
sys.exit(1)
|
||||
|
||||
directory_path = sys.argv[1]
|
||||
base_module_name = sys.argv[2]
|
||||
process_directory(directory_path, base_module_name)
|
||||
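For context, the deleted `update_imports.py` script above rewrites each moved module so that it only re-exports its public names from their new package location. A hypothetical before/after for a module `langchain/llms/foo.py` defining a class `Foo` (module and class names are illustrative):

```python
# After running:
#   python update_imports.py langchain/llms langchain_community.llms
# the module body is replaced by re-export imports plus an __all__ list:
from langchain_community.llms.foo import Foo

__all__ = ["Foo"]
```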
@@ -3,8 +3,7 @@
|
||||
⚡ Building applications with LLMs through composability ⚡
|
||||
|
||||
[](https://github.com/langchain-ai/langchain/releases)
|
||||
[](https://github.com/langchain-ai/langchain/actions/workflows/langchain_ci.yml)
|
||||
[](https://github.com/langchain-ai/langchain/actions/workflows/langchain_experimental_ci.yml)
|
||||
[](https://github.com/langchain-ai/langchain/actions/workflows/check_diffs.yml)
|
||||
[](https://pepy.tech/project/langchain)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://twitter.com/langchainai)
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
"1. Create an access token via the Developer Playground for your workspace. [Detailed instructions](https://help.docugami.com/home/docugami-api).\n",
|
||||
"1. Add your documents (PDF \\[scanned or digital\\], DOC or DOCX) to Docugami for processing. There are two ways to do this:\n",
|
||||
" 1. Use the simple Docugami web experience. [Detailed instructions](https://help.docugami.com/home/adding-documents).\n",
|
||||
" 1. Use the [Docugami API](https://api-docs.docugami.com), specifically the [documents](https://api-docs.docugami.com/#tag/documents/operation/upload-document) endpoint. Code samples are available for [python](../upload_file/) and [JavaScript](../../js/upload-file/) or you can use the [docugami](https://pypi.org/project/docugami/) python library.\n",
|
||||
" 1. Use the [Docugami API](https://api-docs.docugami.com), specifically the [documents](https://api-docs.docugami.com/#tag/documents/operation/upload-document) endpoint. You can also use the [docugami python library](https://pypi.org/project/docugami/) as a convenient wrapper.\n",
|
||||
"\n",
|
||||
"Once your documents are in Docugami, they are processed and organized into sets of similar documents, e.g. NDAs, Lease Agreements, and Service Agreements. Docugami is not limited to any particular types of documents, and the clusters created depend on your particular documents. You can [change the docset assignments](https://help.docugami.com/home/working-with-the-doc-sets-view) later if you wish. You can monitor file status in the simple Docugami webapp, or use a [webhook](https://api-docs.docugami.com/#tag/webhooks) to be informed when your documents are done processing.\n",
|
||||
"\n",
|
||||
@@ -916,6 +916,20 @@
|
||||
"source": [
|
||||
"llama2_chain.invoke(\"What was the learning rate for LLaMA2?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94826165",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Docugami KG-RAG Template\n",
|
||||
"\n",
|
||||
"Docugami also provides a [langchain template](https://github.com/docugami/langchain-template-docugami-kg-rag) that you can integrate into your langchain projects.\n",
|
||||
"\n",
|
||||
"Here's a walkthrough of how you can do this.\n",
|
||||
"\n",
|
||||
"[](https://www.youtube.com/watch?v=xOHOmL1NFMg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -42,7 +42,7 @@
|
||||
"* We will use Open Clip multi-modal embeddings.\n",
|
||||
"* We will use [Chroma](https://www.trychroma.com/) with support for multi-modal.\n",
|
||||
"\n",
|
||||
"A seperate cookbook highlights `Options 2 and 3` [here](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb).\n",
|
||||
"A separate cookbook highlights `Options 2 and 3` [here](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
@@ -321,7 +321,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Requires:\n",
|
||||
"# pip install langchain docarray\n",
|
||||
"# pip install langchain docarray tiktoken\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
|
||||
@@ -69,7 +69,7 @@
|
||||
"\n",
|
||||
"[`Ollama`](https://ollama.ai/) is one way to easily run inference on macOS.\n",
|
||||
" \n",
|
||||
"The instructions [here](docs/integrations/llms/ollama) provide details, which we summarize:\n",
|
||||
"The instructions [here](https://github.com/jmorganca/ollama?tab=readme-ov-file#ollama) provide details, which we summarize:\n",
|
||||
" \n",
|
||||
"* [Download and run](https://ollama.ai/download) the app\n",
|
||||
"* From command line, fetch a model from this [list of options](https://github.com/jmorganca/ollama): e.g., `ollama pull llama2`\n",
|
||||
@@ -197,10 +197,10 @@
|
||||
"\n",
|
||||
"### Ollama\n",
|
||||
"\n",
|
||||
"With [Ollama](docs/integrations/llms/ollama), fetch a model via `ollama pull <model family>:<tag>`:\n",
|
||||
"With [Ollama](https://github.com/jmorganca/ollama), fetch a model via `ollama pull <model family>:<tag>`:\n",
|
||||
"\n",
|
||||
"* E.g., for Llama-7b: `ollama pull llama2` will download the most basic version of the model (e.g., smallest # parameters and 4 bit quantization)\n",
|
||||
"* We can also specify a particular version from the [model list](https://github.com/jmorganca/ollama), e.g., `ollama pull llama2:13b`\n",
|
||||
"* We can also specify a particular version from the [model list](https://github.com/jmorganca/ollama?tab=readme-ov-file#model-library), e.g., `ollama pull llama2:13b`\n",
|
||||
"* See the full set of parameters on the [API reference page](https://api.python.langchain.com/en/latest/llms/langchain.llms.ollama.Ollama.html)"
|
||||
]
|
||||
},
|
||||
@@ -608,7 +608,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.1"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
323
docs/docs/integrations/chat/google_generative_ai.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -533,7 +533,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -440,7 +440,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -11,28 +11,57 @@ pip install cohere
|
||||
|
||||
Get a [Cohere api key](https://dashboard.cohere.ai/) and set it as an environment variable (`COHERE_API_KEY`)
|
||||
|
||||
## Cohere langchain integrations
|
||||
|
||||
## LLM
|
||||
|API|description|Endpoint docs|Import|Example usage|
|
||||
|---|---|---|---|---|
|
||||
|Chat|Build chat bots|[chat](https://docs.cohere.com/reference/chat)|`from langchain.chat_models import ChatCohere`|[cohere.ipynb](/docs/docs/integrations/chat/cohere.ipynb)|
|
||||
|LLM|Generate text|[generate](https://docs.cohere.com/reference/generate)|`from langchain.llms import Cohere`|[cohere.ipynb](/docs/docs/integrations/llms/cohere.ipynb)|
|
||||
|RAG Retriever|Connect to external data sources|[chat + rag](https://docs.cohere.com/reference/chat)|`from langchain.retrievers import CohereRagRetriever`|[cohere.ipynb](/docs/docs/integrations/retrievers/cohere.ipynb)|
|
||||
|Text Embedding|Embed strings to vectors|[embed](https://docs.cohere.com/reference/embed)|`from langchain.embeddings import CohereEmbeddings`|[cohere.ipynb](/docs/docs/integrations/text_embedding/cohere.ipynb)|
|
||||
|Rerank Retriever|Rank strings based on relevance|[rerank](https://docs.cohere.com/reference/rerank)|`from langchain.retrievers.document_compressors import CohereRerank`|[cohere.ipynb](/docs/docs/integrations/retrievers/cohere-reranker.ipynb)|
|
||||
|
||||
## Quick copy examples
|
||||
|
||||
### Chat
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatCohere
|
||||
from langchain.schema import HumanMessage
|
||||
chat = ChatCohere()
|
||||
messages = [HumanMessage(content="knock knock")]
|
||||
print(chat(messages))
|
||||
```
|
||||
|
||||
### LLM
|
||||
|
||||
There exists an Cohere LLM wrapper, which you can access with
|
||||
See a [usage example](/docs/integrations/llms/cohere).
|
||||
|
||||
```python
|
||||
from langchain.llms import Cohere
|
||||
|
||||
llm = Cohere(model="command")
|
||||
print(llm.invoke("Come up with a pet name"))
|
||||
```
|
||||
|
||||
## Text Embedding Model
|
||||
|
||||
There exists a Cohere Embedding model, which you can access with
|
||||
```python
|
||||
from langchain.embeddings import CohereEmbeddings
|
||||
```
|
||||
For a more detailed walkthrough of this, see [this notebook](/docs/integrations/text_embedding/cohere)
|
||||
|
||||
## Retriever
|
||||
|
||||
See a [usage example](/docs/integrations/retrievers/cohere-reranker).
|
||||
### RAG Retriever
|
||||
|
||||
```python
|
||||
from langchain.retrievers.document_compressors import CohereRerank
|
||||
from langchain.chat_models import ChatCohere
|
||||
from langchain.retrievers import CohereRagRetriever
|
||||
from langchain.schema.document import Document
|
||||
|
||||
rag = CohereRagRetriever(llm=ChatCohere())
|
||||
print(rag.get_relevant_documents("What is cohere ai?"))
|
||||
```
|
||||
|
||||
### Text Embedding
|
||||
|
||||
```python
|
||||
from langchain.embeddings import CohereEmbeddings

embeddings = CohereEmbeddings()
print(embeddings.embed_query("This is a test document."))
|
||||
```
|
||||
|
||||
@@ -79,3 +79,9 @@ Databricks as an LLM provider
|
||||
|
||||
The notebook [Wrap Databricks endpoints as LLMs](/docs/integrations/llms/databricks#wrapping-a-serving-endpoint-custom-model) demonstrates how to serve a custom model that has been registered by MLflow as a Databricks endpoint.
|
||||
It supports two types of endpoints: the serving endpoint, which is recommended for both production and development, and the cluster driver proxy app, which is recommended for interactive development.
|
||||
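A minimal sketch of wrapping a serving endpoint as an LLM (the endpoint name is a placeholder for your own Databricks serving endpoint):

```python
from langchain.llms import Databricks

# Placeholder endpoint name -- replace with your own serving endpoint
llm = Databricks(endpoint_name="my-llm-endpoint")
print(llm("What is Databricks?"))
```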
|
||||
|
||||
Databricks Vector Search
|
||||
------------------------
|
||||
|
||||
Databricks Vector Search is a serverless similarity search engine that allows you to store a vector representation of your data, including metadata, in a vector database. With Vector Search, you can create auto-updating vector search indexes from Delta tables managed by Unity Catalog and query them with a simple API to return the most similar vectors. See the notebook [Databricks Vector Search](/docs/integrations/vectorstores/databricks_vector_search) for instructions to use it with LangChain.
|
||||
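For a quick feel of the API, a sketch mirroring the Delta Sync index usage from that notebook (the index name is a placeholder):

```python
from langchain.vectorstores import DatabricksVectorSearch

# Placeholder Delta Sync index name -- replace with your own index
dvs = DatabricksVectorSearch("catalog_name.schema_name.delta_sync_index")
docs = dvs.similarity_search("What is Databricks Vector Search?")
```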
|
||||
55
docs/docs/integrations/providers/ollama.mdx
Normal file
@@ -0,0 +1,55 @@
|
||||
# Ollama
|
||||
|
||||
>[Ollama](https://ollama.ai/) allows you to run open-source large language models,
|
||||
> such as LLaMA2, locally.
|
||||
>
|
||||
>`Ollama` bundles model weights, configuration, and data into a single package, defined by a Modelfile.
|
||||
>It optimizes setup and configuration details, including GPU usage.
|
||||
>For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).
|
||||
|
||||
See [this guide](https://python.langchain.com/docs/guides/local_llms#quickstart) for more details
|
||||
on how to use `Ollama` with LangChain.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Follow [these instructions](https://github.com/jmorganca/ollama?tab=readme-ov-file#ollama)
|
||||
to set up and run a local Ollama instance.
|
||||
No API key is required; by default the integrations connect to a local Ollama server at `http://localhost:11434`.
|
||||
|
||||
|
||||
## LLM
|
||||
|
||||
```python
|
||||
from langchain.llms import Ollama
|
||||
```
|
||||
|
||||
See the notebook example [here](/docs/integrations/llms/ollama).
|
||||
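A minimal usage sketch (assumes a local Ollama server is running and the `llama2` model has been pulled):

```python
from langchain.llms import Ollama

llm = Ollama(model="llama2")
print(llm.invoke("Why is the sky blue?"))
```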
|
||||
## Chat Models
|
||||
|
||||
### Chat Ollama
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOllama
|
||||
```
|
||||
|
||||
See the notebook example [here](/docs/integrations/chat/ollama).
|
||||
|
||||
### Ollama functions
|
||||
|
||||
```python
|
||||
from langchain_experimental.llms.ollama_functions import OllamaFunctions
|
||||
```
|
||||
|
||||
See the notebook example [here](/docs/integrations/chat/ollama_functions).
|
||||
|
||||
## Embedding models
|
||||
|
||||
```python
|
||||
from langchain.embeddings import OllamaEmbeddings
|
||||
```
|
||||
|
||||
See the notebook example [here](/docs/integrations/text_embedding/ollama).
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"ABS_PATH = os.path.dirname(os.path.abspath(__file__))\n",
|
||||
"DB_DIR = os.path.join(ABS_PATH, \"db\")\n",
|
||||
"\n",
|
||||
"# Instantiate 2 diff cromadb indexs, each one with a diff embedding.\n",
|
||||
"# Instantiate 2 diff chromadb indexes, each one with a diff embedding.\n",
|
||||
"client_settings = chromadb.config.Settings(\n",
|
||||
" is_persistent=True,\n",
|
||||
" persist_directory=DB_DIR,\n",
|
||||
@@ -68,7 +68,7 @@
|
||||
" search_type=\"mmr\", search_kwargs={\"k\": 5, \"include_metadata\": True}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# The Lord of the Retrievers will hold the ouput of boths retrievers and can be used as any other\n",
|
||||
"# The Lord of the Retrievers will hold the output of both retrievers and can be used as any other\n",
|
||||
"# retriever on different types of chains.\n",
|
||||
"lotr = MergerRetriever(retrievers=[retriever_all, retriever_multi_qa])"
|
||||
]
|
||||
@@ -145,7 +145,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Re-order results to avoid performance degradation.\n",
|
||||
"No matter the architecture of your model, there is a sustancial performance degradation when you include 10+ retrieved documents.\n",
|
||||
"No matter the architecture of your model, there is a substantial performance degradation when you include 10+ retrieved documents.\n",
|
||||
"In brief: When models must access relevant information in the middle of long contexts, then tend to ignore the provided documents.\n",
|
||||
"See: https://arxiv.org/abs//2307.03172"
|
||||
]
|
||||
@@ -157,7 +157,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# You can use an additional document transformer to reorder documents after removing redundance.\n",
|
||||
"# You can use an additional document transformer to reorder documents after removing redundancy.\n",
|
||||
"from langchain.document_transformers import LongContextReorder\n",
|
||||
"\n",
|
||||
"filter = EmbeddingsRedundantFilter(embeddings=filter_embeddings)\n",
|
||||
|
||||
@@ -215,7 +215,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "579f0677-aa06-4ad8-a816-3520c8d6923c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -50,7 +50,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -62,45 +62,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c26fca9f-cfdb-45e5-a0bd-f677ff8b9d92",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Enter your HF API Key:\n",
|
||||
"\n",
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"huggingfacehub_api_token = getpass(\"Enter your HF API Key:\\n\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"id": "f9a92970-16f4-458c-b186-2a83e9f7d840",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = HuggingFaceHubEmbeddings(\n",
|
||||
" model=\"http://localhost:8080\", huggingfacehub_api_token=huggingfacehub_api_token\n",
|
||||
")"
|
||||
"embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 4,
|
||||
"id": "42105438-9fee-460a-9c52-b7c595722758",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -112,7 +86,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"id": "20167762-0988-4205-bbd4-1f20fd9dd247",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -124,7 +98,7 @@
|
||||
"[0.018113142, 0.00302585, -0.049911194]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -136,7 +110,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 6,
|
||||
"id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -145,6 +119,29 @@
|
||||
"source": [
|
||||
"doc_result = embeddings.embed_documents([text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "6fba8be9-fabf-4972-8334-aa56ed9893e1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[0.018113142, 0.00302585, -0.049911194]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"doc_result[0][:3]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -0,0 +1,385 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "5a8c5767-adfe-4b9d-a665-a898756d7a6c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Databricks Vector Search\n",
|
||||
"\n",
|
||||
"Databricks Vector Search is a serverless similarity search engine that allows you to store a vector representation of your data, including metadata, in a vector database. With Vector Search, you can create auto-updating vector search indexes from Delta tables managed by Unity Catalog and query them with a simple API to return the most similar vectors.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use LangChain with Databricks Vector Search."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "746cfacd-fb30-48fd-96a5-bbcc0d15ae49",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Install `databricks-vectorsearch` and related Python packages used in this notebook."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9258a3e7-e050-4390-9d3f-9adff1460dab",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install langchain-core databricks-vectorsearch openai tiktoken"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "f4f09d6d-002d-4cb0-a664-0a83bd2a13da",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Use `OpenAIEmbeddings` for the embeddings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "f11b902d-a772-45e0-bbd9-526218b717cc",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "59b568f3-8db2-427e-9a4a-1df6fa7a1739",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Split documents and get embeddings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "b28e1c7b-eae4-4be8-abbd-8433c7557dc2",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"emb_dim = len(embeddings.embed_query(\"hello\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "e8fcdda1-208a-45c9-816e-ff0d2c8f59d6",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Setup Databricks Vector Search client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9b87fff1-99e5-4d9f-aba3-d21a7ccc498e",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from databricks.vector_search.client import VectorSearchClient\n",
|
||||
"\n",
|
||||
"vsc = VectorSearchClient()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a Vector Search Endpoint\n",
|
||||
"This endpoint is used to create and access vector search indexes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vsc.create_endpoint(name=\"vector_search_demo_endpoint\", endpoint_type=\"STANDARD\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "81090f87-3efd-4c1e-9f58-8d6adba7553d",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Create Direct Vector Access Index\n",
|
||||
"Direct Vector Access Index supports direct read and write of embedding vectors and metadata through a REST API or an SDK. For this index, you manage embedding vectors and index updates yourself."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9389ec6b-5885-411f-a26e-1a4b03651f5c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vector_search_endpoint_name = \"vector_search_demo_endpoint\"\n",
|
||||
"index_name = \"ml.llm.demo_index\"\n",
|
||||
"\n",
|
||||
"index = vsc.create_direct_access_index(\n",
|
||||
" endpoint_name=vector_search_endpoint_name,\n",
|
||||
" index_name=index_name,\n",
|
||||
" primary_key=\"id\",\n",
|
||||
" embedding_dimension=emb_dim,\n",
|
||||
" embedding_vector_column=\"text_vector\",\n",
|
||||
" schema={\n",
|
||||
" \"id\": \"string\",\n",
|
||||
" \"text\": \"string\",\n",
|
||||
" \"text_vector\": \"array<float>\",\n",
|
||||
" \"source\": \"string\",\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"index.describe()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "047a14c9-2f06-4f74-883d-815b2c69786c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.vectorstores import DatabricksVectorSearch\n",
|
||||
"\n",
|
||||
"dvs = DatabricksVectorSearch(\n",
|
||||
" index, text_column=\"text\", embedding=embeddings, columns=[\"source\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "951bd581-2ced-497f-9c70-4fda902fd3a1",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Add docs to the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "1e85f235-901f-4cf5-845f-5dbf4ce42078",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dvs.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "8bea6f0a-b305-455a-acba-99cc8c9350b5",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Similarity search"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "25c5a044-a61a-4929-9e65-a0f0462925df",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"dvs.similarity_search(query)\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "46e3f41b-dac2-4bed-91cb-a3914c25d275",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Work with Delta Sync Index\n",
|
||||
"\n",
|
||||
"You can also use `DatabricksVectorSearch` to search in a Delta Sync Index. Delta Sync Index automatically syncs from a Delta table. You don't need to call `add_text`/`add_documents` manually. See [Databricks documentation page](https://docs.databricks.com/en/generative-ai/vector-search.html#delta-sync-index-with-managed-embeddings) for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "0c1f448e-77ca-41ce-887c-15948e866a0e",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dvs_delta_sync = DatabricksVectorSearch(\"catalog_name.schema_name.delta_sync_index\")\n",
|
||||
"dvs_delta_sync.similarity_search(query)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+notebook": {
|
||||
"dashboards": [],
|
||||
"language": "python",
|
||||
"notebookMetadata": {
|
||||
"pythonIndentUnit": 2
|
||||
},
|
||||
"notebookName": "databricks_vector_search",
|
||||
"widgets": {}
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
441
docs/docs/integrations/vectorstores/yellowbrick.ipynb
Normal file
@@ -0,0 +1,441 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7e80d338-091b-421c-ac66-5950b14944b2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Yellowbrick\n",
|
||||
"\n",
|
||||
"[Yellowbrick](https://yellowbrick.com/yellowbrick-data-warehouse/) is an elastic, massively parallel processing (MPP) SQL database that runs in the cloud and on-premises, using kubernetes for scale, resilience and cloud portability. Yellowbrick is designed to address the largest and most complex business-critical data warehousing use cases. The efficiency at scale that Yellowbrick provides also enables it to be used as a high performance and scalable vector database to store and search vectors with SQL. \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9291d9e5-d404-405f-8307-87d80d0233f2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using Yellowbrick as the vector store for ChatGpt\n",
|
||||
"\n",
|
||||
"This tutorial demonstrates how to create a simple chatbot backed by ChatGpt that uses Yellowbrick as a vector store to support Retrieval Augmented Generation (RAG). What you'll need:\n",
|
||||
"\n",
|
||||
"1. An account on the [Yellowbrick sandbox](https://cloudlabs.yellowbrick.com/)\n",
|
||||
"2. An api key from [OpenAI](https://platform.openai.com/)\n",
|
||||
"\n",
|
||||
"The tutorial is divided into five parts. First we'll use langchain to create a baseline chatbot to interact with ChatGpt without a vector store. Second, we'll create an embeddings table in Yellowbrick that will represent the vector store. Third, we'll load a series of documents (the Administration chapter of the Yellowbrick Manual). Fourth, we'll create the vector representation of those documents and store in a Yellowbrick table. Lastly, we'll send the same queries to the improved chatbox to see the results.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "924d1c25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install all needed libraries\n",
|
||||
"%pip install langchain\n",
|
||||
"%pip install openai\n",
|
||||
"%pip install psycopg2-binary\n",
|
||||
"%pip install tiktoken"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5928e9c7-7666-4282-9cb4-00d919228ce0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup: Enter the information used to connect to Yellowbrick and OpenAI API\n",
|
||||
"\n",
|
||||
"Our chatbot integrates with ChatGpt via the langchain library, so you'll need an API key from OpenAI first:\n",
|
||||
"\n",
|
||||
"To get an api key for OpenAI:\n",
|
||||
"1. Register at https://platform.openai.com/\n",
|
||||
"2. Add a payment method - You're unlikely to go over free quota\n",
|
||||
"3. Create an API key\n",
|
||||
"\n",
|
||||
"You'll also need your Username, Password, and Database name from the welcome email when you sign up for the Yellowbrick Sandbox Account.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aaf215bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following should be modified to include the information for your Yellowbrick database and OpenAPI Key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a4393d8d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Modify these values to match your Yellowbrick Sandbox and OpenAI API Key\n",
|
||||
"YBUSER = \"[SANDBOX USER]\"\n",
|
||||
"YBPASSWORD = \"[SANDBOX PASSWORD]\"\n",
|
||||
"YBDATABASE = \"[SANDBOX_DATABASE]\"\n",
|
||||
"YBHOST = \"trialsandbox.sandbox.aws.yellowbrickcloud.com\"\n",
|
||||
"\n",
|
||||
"OPENAI_API_KEY = \"[OPENAI API KEY]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c186f99b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import libraries and setup keys / login info\n",
|
||||
"import os\n",
|
||||
"import pathlib\n",
|
||||
"import re\n",
|
||||
"import sys\n",
|
||||
"import urllib.parse as urlparse\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"import psycopg2\n",
|
||||
"from IPython.display import Markdown, display\n",
|
||||
"from langchain.chains import LLMChain, RetrievalQAWithSourcesChain\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.docstore.document import Document\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import Yellowbrick\n",
|
||||
"\n",
|
||||
"# Establish connection parameters to Yellowbrick. If you've signed up for Sandbox, fill in the information from your welcome mail here:\n",
|
||||
"yellowbrick_connection_string = (\n",
|
||||
" f\"postgres://{urlparse.quote(YBUSER)}:{YBPASSWORD}@{YBHOST}:5432/{YBDATABASE}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"YB_DOC_DATABASE = \"sample_data\"\n",
|
||||
"YB_DOC_TABLE = \"yellowbrick_documentation\"\n",
|
||||
"embedding_table = \"my_embeddings\"\n",
|
||||
"\n",
|
||||
"# API Key for OpenAI. Signup at https://platform.openai.com\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
|
||||
"\n",
|
||||
"from langchain.prompts.chat import (\n",
|
||||
" ChatPromptTemplate,\n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
" SystemMessagePromptTemplate,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e955b19b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 1: Creating a baseline chatbot backed by ChatGpt without a Vector Store\n",
|
||||
"\n",
|
||||
"We will use langchain to query ChatGPT. As there is no Vector Store, ChatGPT will have no context in which to answer the question.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "538f8b96-1b54-4f2f-9239-dfb5cc7fd259",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set up the chat model and specific prompt\n",
|
||||
"system_template = \"\"\"If you don't know the answer, Make up your best guess.\"\"\"\n",
|
||||
"messages = [\n",
|
||||
" SystemMessagePromptTemplate.from_template(system_template),\n",
|
||||
" HumanMessagePromptTemplate.from_template(\"{question}\"),\n",
|
||||
"]\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(messages)\n",
|
||||
"\n",
|
||||
"chain_type_kwargs = {\"prompt\": prompt}\n",
|
||||
"llm = ChatOpenAI(\n",
|
||||
" model_name=\"gpt-3.5-turbo\", # Modify model_name if you have access to GPT-4\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=256,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = LLMChain(\n",
|
||||
" llm=llm,\n",
|
||||
" prompt=prompt,\n",
|
||||
" verbose=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def print_result_simple(query):\n",
|
||||
" result = chain(query)\n",
|
||||
" output_text = f\"\"\"### Question:\n",
|
||||
" {query}\n",
|
||||
" ### Answer: \n",
|
||||
" {result['text']}\n",
|
||||
" \"\"\"\n",
|
||||
" display(Markdown(output_text))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use the chain to query\n",
|
||||
"print_result_simple(\"How many databases can be in a Yellowbrick Instance?\")\n",
|
||||
"\n",
|
||||
"print_result_simple(\"What's an easy way to add users in bulk to Yellowbrick?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "798c7aa6-5904-4860-b4a9-896fe7681a45",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 2: Connect to Yellowbrick and create the embedding tables\n",
|
||||
"\n",
|
||||
"To load your document embeddings into Yellowbrick, you should create your own table for storing them in. Note that the \n",
|
||||
"Yellowbrick database that the table is in has to be UTF-8 encoded. \n",
|
||||
"\n",
|
||||
"Create a table in a UTF-8 database with the following schema, providing a table name of your choice:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e72daf30-6160-4ff3-921f-c4c9da329991",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Establish a connection to the Yellowbrick database\n",
|
||||
"try:\n",
|
||||
" conn = psycopg2.connect(yellowbrick_connection_string)\n",
|
||||
"except psycopg2.Error as e:\n",
|
||||
" print(f\"Error connecting to the database: {e}\")\n",
|
||||
" exit(1)\n",
|
||||
"\n",
|
||||
"# Create a cursor object using the connection\n",
|
||||
"cursor = conn.cursor()\n",
|
||||
"\n",
|
||||
"# Define the SQL statement to create a table\n",
|
||||
"create_table_query = f\"\"\"\n",
|
||||
"CREATE TABLE if not exists {embedding_table} (\n",
|
||||
" id uuid,\n",
|
||||
" embedding_id integer,\n",
|
||||
" text character varying(60000),\n",
|
||||
" metadata character varying(1024),\n",
|
||||
" embedding double precision\n",
|
||||
")\n",
|
||||
"DISTRIBUTE ON (id);\n",
|
||||
"truncate table {embedding_table};\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Execute the SQL query to create a table\n",
|
||||
"try:\n",
|
||||
" cursor.execute(create_table_query)\n",
|
||||
" print(f\"Table '{embedding_table}' created successfully!\")\n",
|
||||
"except psycopg2.Error as e:\n",
|
||||
" print(f\"Error creating table: {e}\")\n",
|
||||
" conn.rollback()\n",
|
||||
"\n",
|
||||
"# Commit changes and close the cursor and connection\n",
|
||||
"conn.commit()\n",
|
||||
"cursor.close()\n",
|
||||
"conn.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8690ac3d-a775-4b0c-9499-9825885f3c82",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 3: Extract the documents to index from an existing table in Yellowbrick\n",
|
||||
"Extract document paths and contents from an existing Yellowbrick table. We'll use these documents to create embeddings from in the next step.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "60ab85bb-7901-44cf-b149-10fcde2ab91d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"yellowbrick_doc_connection_string = (\n",
|
||||
" f\"postgres://{urlparse.quote(YBUSER)}:{YBPASSWORD}@{YBHOST}:5432/{YB_DOC_DATABASE}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Establish a connection to the Yellowbrick database\n",
|
||||
"conn = psycopg2.connect(yellowbrick_doc_connection_string)\n",
|
||||
"\n",
|
||||
"# Create a cursor object\n",
|
||||
"cursor = conn.cursor()\n",
|
||||
"\n",
|
||||
"# Query to select all documents from the table\n",
|
||||
"query = f\"SELECT path, document FROM {YB_DOC_TABLE}\"\n",
|
||||
"\n",
|
||||
"# Execute the query\n",
|
||||
"cursor.execute(query)\n",
|
||||
"\n",
|
||||
"# Fetch all documents\n",
|
||||
"yellowbrick_documents = cursor.fetchall()\n",
|
||||
"\n",
|
||||
"print(f\"Extracted {len(yellowbrick_documents)} documents successfully!\")\n",
|
||||
"\n",
|
||||
"# Close the cursor and connection\n",
|
||||
"cursor.close()\n",
|
||||
"conn.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b62b4150-2aa3-453e-a4db-81a2f8a11e70",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 4: Load the Yellowbrick Vector Store with Documents\n",
|
||||
"Go through documents, split them into digestable chunks, create the embedding and insert into the Yellowbrick table. This takes around 5 minutes.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de914b10-850e-4c5b-a09b-c6a14006637c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Split documents into chunks for conversion to embeddings\n",
|
||||
"DOCUMENT_BASE_URL = \"https://docs.yellowbrick.com/6.7.1/\" # Actual URL\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"separator = \"\\n## \" # This separator assumes Markdown docs from the repo uses ### as logical main header most of the time\n",
|
||||
"chunk_size_limit = 2000\n",
|
||||
"max_chunk_overlap = 200\n",
|
||||
"\n",
|
||||
"documents = [\n",
|
||||
" Document(\n",
|
||||
" page_content=document[1],\n",
|
||||
" metadata={\"source\": DOCUMENT_BASE_URL + document[0].replace(\".md\", \".html\")},\n",
|
||||
" )\n",
|
||||
" for document in yellowbrick_documents\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"text_splitter = RecursiveCharacterTextSplitter(\n",
|
||||
" chunk_size=chunk_size_limit,\n",
|
||||
" chunk_overlap=max_chunk_overlap,\n",
|
||||
" separators=[separator, \"\\nn\", \"\\n\", \",\", \" \", \"\"],\n",
|
||||
")\n",
|
||||
"split_docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"docs_text = [doc.page_content for doc in split_docs]\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"vector_store = Yellowbrick.from_documents(\n",
|
||||
" documents=split_docs,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" connection_string=yellowbrick_connection_string,\n",
|
||||
" table=embedding_table,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Created vector store with {len(documents)} documents\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beee89f5-0f1e-4c6e-91a9-44c10762d466",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 5: Creating a chatbot that uses Yellowbrick as the vector store\n",
|
||||
"\n",
|
||||
"Next, we add Yellowbrick as a vector store. The vector store has been populated with embeddings representing the administrative chapter of the Yellowbrick product documentation.\n",
|
||||
"\n",
|
||||
"We'll send the same queries as above to see the impoved responses.\n"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": null,
"id": "7daa9d4f-7804-4cfa-9873-415998d5e0f5",
"metadata": {},
"outputs": [],
"source": [
"system_template = \"\"\"Use the following pieces of context to answer the user's question.\n",
"Take note of the sources and include them in the answer in the format: \"SOURCES: source1 source2\", use \"SOURCES\" in capital letters regardless of the number of sources.\n",
"If you don't know the answer, just say that \"I don't know\", don't try to make up an answer.\n",
"----------------\n",
"{summaries}\"\"\"\n",
"messages = [\n",
"    SystemMessagePromptTemplate.from_template(system_template),\n",
"    HumanMessagePromptTemplate.from_template(\"{question}\"),\n",
"]\n",
"prompt = ChatPromptTemplate.from_messages(messages)\n",
"\n",
"vector_store = Yellowbrick(\n",
"    OpenAIEmbeddings(),\n",
"    yellowbrick_connection_string,\n",
"    embedding_table,  # Change the table name to reflect your embeddings\n",
")\n",
"\n",
"chain_type_kwargs = {\"prompt\": prompt}\n",
"llm = ChatOpenAI(\n",
"    model_name=\"gpt-3.5-turbo\",  # Modify model_name if you have access to GPT-4\n",
"    temperature=0,\n",
"    max_tokens=256,\n",
")\n",
"chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
"    llm=llm,\n",
"    chain_type=\"stuff\",\n",
"    retriever=vector_store.as_retriever(search_kwargs={\"k\": 5}),\n",
"    return_source_documents=True,\n",
"    chain_type_kwargs=chain_type_kwargs,\n",
")\n",
"\n",
"\n",
"def print_result_sources(query):\n",
|
||||
" result = chain(query)\n",
|
||||
" output_text = f\"\"\"### Question: \n",
|
||||
" {query}\n",
|
||||
" ### Answer: \n",
|
||||
" {result['answer']}\n",
|
||||
" ### Sources: \n",
|
||||
" {result['sources']}\n",
|
||||
" ### All relevant sources:\n",
|
||||
" {', '.join(list(set([doc.metadata['source'] for doc in result['source_documents']])))}\n",
|
||||
" \"\"\"\n",
|
||||
" display(Markdown(output_text))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use the chain to query\n",
|
||||
"\n",
|
||||
"print_result_sources(\"How many databases can be in a Yellowbrick Instance?\")\n",
|
||||
"\n",
|
||||
"print_result_sources(\"Whats an easy way to add users in bulk to Yellowbrick?\")"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"id": "697c8a38",
"metadata": {},
"source": [
"## Next Steps:\n",
"\n",
"This code can be modified to ask different questions. You can also load your own documents into the vector store; LangChain's document loaders can parse a wide variety of file formats (including HTML, PDF, and more).\n",
"\n",
"You can also modify this to use Hugging Face embedding models and Meta's Llama 2 LLM for a completely private chatbot experience (a sketch of that swap follows the notebook below)."
]
},
],
"metadata": {
"kernelspec": {
"display_name": "langchain_venv",
"language": "python",
"name": "langchain_venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
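The Next Steps cell above suggests swapping in Hugging Face embeddings and a local Llama 2 model but does not show how. Below is a minimal sketch of that swap, not part of this commit: it assumes the optional sentence-transformers and llama-cpp-python packages are installed, local Llama 2 weights at the hypothetical path ./llama-2-7b-chat.Q4_K_M.gguf, and the yellowbrick_connection_string and embedding_table values defined earlier in the notebook. Your own files could likewise be loaded with one of LangChain's document loaders (for example PyPDFLoader) before splitting and embedding.

# Minimal sketch (assumption, not from this commit): a fully local variant of the chatbot above.
# Requires the optional sentence-transformers and llama-cpp-python packages, plus local
# Llama 2 weights at the hypothetical path ./llama-2-7b-chat.Q4_K_M.gguf.
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.vectorstores import Yellowbrick

# Local embedding model in place of OpenAIEmbeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Vectors from a different embedding model are not interchangeable with the OpenAI ones
# already stored, so point this at a freshly populated table (re-run from_documents with
# these embeddings first).
vector_store = Yellowbrick(
    embeddings,
    yellowbrick_connection_string,
    embedding_table,
)

# Local Llama 2 in place of ChatOpenAI
llm = LlamaCpp(
    model_path="./llama-2-7b-chat.Q4_K_M.gguf",
    temperature=0,
    max_tokens=256,
    n_ctx=4096,
)

chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
    return_source_documents=True,
)

The rest of the notebook (the print_result_sources helper and the queries) works unchanged with this chain, since RetrievalQAWithSourcesChain only needs an LLM and a retriever.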
@@ -85,8 +85,8 @@
}
],
"source": [
-"# If we pass in a model explicitly, we need to make sure it supports the OpenAI function-calling API.\n",
-"llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n",
+"# For better results in OpenAI function-calling API, it is recommended to explicitly pass the latest model.\n",
+"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
BIN docs/static/img/langchain_stack.png vendored
Binary file not shown.
Before: 174 KiB | After: 820 KiB
891 docs/static/svg/langchain_stack.svg vendored
File diff suppressed because one or more lines are too long
Before: 449 KiB | After: 531 KiB
@@ -4,16 +4,22 @@ import typer
from typing_extensions import Annotated

from langchain_cli.namespaces import app as app_namespace
+from langchain_cli.namespaces import integration as integration_namespace
from langchain_cli.namespaces import template as template_namespace
from langchain_cli.utils.packages import get_langserve_export, get_package_root

-__version__ = "0.0.19"
+__version__ = "0.0.20rc0"

app = typer.Typer(no_args_is_help=True, add_completion=False)
app.add_typer(
    template_namespace.package_cli, name="template", help=template_namespace.__doc__
)
app.add_typer(app_namespace.app_cli, name="app", help=app_namespace.__doc__)
+app.add_typer(
+    integration_namespace.integration_cli,
+    name="integration",
+    help=integration_namespace.__doc__,
+)


def version_callback(show_version: bool) -> None:
1 libs/cli/langchain_cli/integration_template/.gitignore vendored Normal file
@@ -0,0 +1 @@
__pycache__
21 libs/cli/langchain_cli/integration_template/LICENSE Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Some files were not shown because too many files have changed in this diff