Compare commits: 11 commits (mirror of https://github.com/hwchase17/langchain.git)
@@ -839,6 +839,127 @@
|
||||
"source": [
|
||||
"agent.run(\"whats 2**.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "f1da459d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Handling Tool Errors \n",
|
||||
"When a tool encounters an error and the exception is not caught, the agent will stop executing. If you want the agent to continue execution, you can raise a `ToolException` and set `handle_tool_error` accordingly. \n",
|
||||
"\n",
|
||||
"When `ToolException` is thrown, the agent will not stop working, but will handle the exception according to the `handle_tool_error` variable of the tool, and the processing result will be returned to the agent as observation, and printed in red.\n",
|
||||
"\n",
|
||||
"You can set `handle_tool_error` to `True`, set it a unified string value, or set it as a function. If it's set as a function, the function should take a `ToolException` as a parameter and return a `str` value.\n",
|
||||
"\n",
|
||||
"Please note that only raising a `ToolException` won't be effective. You need to first set the `handle_tool_error` of the tool because its default value is `False`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "ad16fbcf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema import ToolException\n",
|
||||
"\n",
|
||||
"from langchain import SerpAPIWrapper\n",
|
||||
"from langchain.agents import AgentType, initialize_agent\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.tools import Tool\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"\n",
|
||||
"def _handle_error(error:ToolException) -> str:\n",
|
||||
" return \"The following errors occurred during tool execution:\" + error.args[0]+ \"Please try another tool.\"\n",
|
||||
"def search_tool1(s: str):raise ToolException(\"The search tool1 is not available.\")\n",
|
||||
"def search_tool2(s: str):raise ToolException(\"The search tool2 is not available.\")\n",
|
||||
"search_tool3 = SerpAPIWrapper()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "c05aa75b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"description=\"useful for when you need to answer questions about current events.You should give priority to using it.\"\n",
|
||||
"tools = [\n",
|
||||
" Tool.from_function(\n",
|
||||
" func=search_tool1,\n",
|
||||
" name=\"Search_tool1\",\n",
|
||||
" description=description,\n",
|
||||
" handle_tool_error=True,\n",
|
||||
" ),\n",
|
||||
" Tool.from_function(\n",
|
||||
" func=search_tool2,\n",
|
||||
" name=\"Search_tool2\",\n",
|
||||
" description=description,\n",
|
||||
" handle_tool_error=_handle_error,\n",
|
||||
" ),\n",
|
||||
" Tool.from_function(\n",
|
||||
" func=search_tool3.run,\n",
|
||||
" name=\"Search_tool3\",\n",
|
||||
" description=\"useful for when you need to answer questions about current events\",\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" ChatOpenAI(temperature=0),\n",
|
||||
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
" verbose=True,\n",
|
||||
")\n"
|
||||
]
|
||||
},
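The cells above demonstrate `handle_tool_error=True` and a custom handler function. The prose also mentions setting a single string value; the following is a minimal sketch of that variant (the tool name and message here are illustrative and not part of the original notebook):

```python
# Sketch: handle_tool_error as a fixed string.
# Whenever the wrapped function raises a ToolException, the agent receives
# this string as the observation instead of stopping.
string_handled_tool = Tool.from_function(
    func=search_tool1,
    name="Search_tool1_string_handler",
    description=description,
    handle_tool_error="This tool is unavailable, please try another tool.",
)
```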
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "cff8b4b5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI should use Search_tool1 to find recent news articles about Leo DiCaprio's personal life.\n",
|
||||
"Action: Search_tool1\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[31;1m\u001b[1;3mThe search tool1 is not available.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool2 instead.\n",
|
||||
"Action: Search_tool2\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[31;1m\u001b[1;3mThe following errors occurred during tool execution:The search tool2 is not available.Please try another tool.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool3 as a last resort.\n",
|
||||
"Action: Search_tool3\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[38;5;200m\u001b[1;3mLeonardo DiCaprio and Gigi Hadid were recently spotted at a pre-Oscars party, sparking interest once again in their rumored romance. The Revenant actor and the model first made headlines when they were spotted together at a New York Fashion Week afterparty in September 2022.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mBased on the information from Search_tool3, it seems that Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\n",
|
||||
"Final Answer: Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Who is Leo DiCaprio's girlfriend?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -857,7 +978,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
"version": "3.11.3"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
@@ -113,7 +113,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 5,
|
||||
"id": "af803fee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -316,6 +316,64 @@
|
||||
"result['answer']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "11a76453",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using a different model for condensing the question\n",
|
||||
"\n",
|
||||
"This chain has two steps. First, it condenses the current question and the chat history into a standalone question. This is neccessary to create a standanlone vector to use for retrieval. After that, it does retrieval and then answers the question using retrieval augmented generation with a separate model. Part of the power of the declarative nature of LangChain is that you can easily use a separate language model for each call. This can be useful to use a cheaper and faster model for the simpler task of condensing the question, and then a more expensive model for answering the question. Here is an example of doing so."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "8d4ede9e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "04a23e23",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qa = ConversationalRetrievalChain.from_llm(\n",
|
||||
" ChatOpenAI(temperature=0, model=\"gpt-4\"),\n",
|
||||
" vectorstore.as_retriever(),\n",
|
||||
" condense_question_llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo'),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "b1223752",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_history = []\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"result = qa({\"question\": query, \"chat_history\": chat_history})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdce4e28",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_history = [(query, result[\"answer\"])]\n",
|
||||
"query = \"Did he mention who she suceeded\"\n",
|
||||
"result = qa({\"question\": query, \"chat_history\": chat_history})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0eaadf0f",
|
||||
|
||||
@@ -130,6 +130,7 @@ We need access tokens and sometime other parameters to get access to these datas
|
||||
./document_loaders/examples/notion.ipynb
|
||||
./document_loaders/examples/obsidian.ipynb
|
||||
./document_loaders/examples/psychic.ipynb
|
||||
./document_loaders/examples/pyspark_dataframe.ipynb
|
||||
./document_loaders/examples/readthedocs_documentation.ipynb
|
||||
./document_loaders/examples/reddit.ipynb
|
||||
./document_loaders/examples/roam.ipynb
|
||||
|
||||
docs/modules/indexes/document_loaders/examples/github.ipynb (new file, 261 lines)
@@ -0,0 +1,261 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# GitHub\n",
|
||||
"\n",
|
||||
"This notebooks shows how you can load issues and pull requests (PRs) for a given repository on [GitHub](https://github.com/). We will use the LangChain Python repository as an example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup access token"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To access the GitHub API, you need a personal access token - you can set up yours here: https://github.com/settings/tokens?type=beta. You can either set this token as the environment variable ``GITHUB_PERSONAL_ACCESS_TOKEN`` and it will be automatically pulled in, or you can pass it in directly at initializaiton as the ``access_token`` named parameter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If you haven't set your access token as an environment variable, pass it in here.\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"ACCESS_TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Issues and PRs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import GitHubIssuesLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = GitHubIssuesLoader(\n",
|
||||
" repo=\"hwchase17/langchain\",\n",
|
||||
" access_token=ACCESS_TOKEN, # delete/comment out this argument if you've set the access token as an env var.\n",
|
||||
" creator=\"UmerHA\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's load all issues and PRs created by \"UmerHA\".\n",
|
||||
"\n",
|
||||
"Here's a list of all filters you can use:\n",
|
||||
"- include_prs\n",
|
||||
"- milestone\n",
|
||||
"- state\n",
|
||||
"- assignee\n",
|
||||
"- creator\n",
|
||||
"- mentioned\n",
|
||||
"- labels\n",
|
||||
"- sort\n",
|
||||
"- direction\n",
|
||||
"- since\n",
|
||||
"\n",
|
||||
"For more info, see https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#list-repository-issues."
|
||||
]
|
||||
},
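Several of these filters can be combined; below is a sketch (the label, date, and sort values are made up for illustration, and `ACCESS_TOKEN` is the token collected earlier):

```python
# Sketch: combining several GitHubIssuesLoader filters.
loader = GitHubIssuesLoader(
    repo="hwchase17/langchain",
    access_token=ACCESS_TOKEN,  # or rely on GITHUB_PERSONAL_ACCESS_TOKEN
    state="closed",
    labels=["enhancement"],
    since="2023-05-01T00:00:00Z",
    sort="created",
    direction="desc",
)
docs = loader.load()
```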
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"# Creates GitHubLoader (#5257)\r\n",
|
||||
"\r\n",
|
||||
"GitHubLoader is a DocumentLoader that loads issues and PRs from GitHub.\r\n",
|
||||
"\r\n",
|
||||
"Fixes #5257\r\n",
|
||||
"\r\n",
|
||||
"Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:\r\n",
|
||||
"DataLoaders\r\n",
|
||||
"- @eyurtsev\r\n",
|
||||
"\n",
|
||||
"{'url': 'https://github.com/hwchase17/langchain/pull/5408', 'title': 'DocumentLoader for GitHub', 'creator': 'UmerHA', 'created_at': '2023-05-29T14:50:53Z', 'comments': 0, 'state': 'open', 'labels': ['enhancement', 'lgtm', 'doc loader'], 'assignee': None, 'milestone': None, 'locked': False, 'number': 5408, 'is_pull_request': True}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(docs[0].page_content)\n",
|
||||
"print(docs[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Only load issues"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"By default, the GitHub API returns considers pull requests to also be issues. To only get 'pure' issues (i.e., no pull requests), use `include_prs=False`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = GitHubIssuesLoader(\n",
|
||||
" repo=\"hwchase17/langchain\",\n",
|
||||
" access_token=ACCESS_TOKEN, # delete/comment out this argument if you've set the access token as an env var.\n",
|
||||
" creator=\"UmerHA\",\n",
|
||||
" include_prs=False,\n",
|
||||
")\n",
|
||||
"docs = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"### System Info\n",
|
||||
"\n",
|
||||
"LangChain version = 0.0.167\r\n",
|
||||
"Python version = 3.11.0\r\n",
|
||||
"System = Windows 11 (using Jupyter)\n",
|
||||
"\n",
|
||||
"### Who can help?\n",
|
||||
"\n",
|
||||
"- @hwchase17\r\n",
|
||||
"- @agola11\r\n",
|
||||
"- @UmerHA (I have a fix ready, will submit a PR)\n",
|
||||
"\n",
|
||||
"### Information\n",
|
||||
"\n",
|
||||
"- [ ] The official example notebooks/scripts\n",
|
||||
"- [X] My own modified scripts\n",
|
||||
"\n",
|
||||
"### Related Components\n",
|
||||
"\n",
|
||||
"- [X] LLMs/Chat Models\n",
|
||||
"- [ ] Embedding Models\n",
|
||||
"- [X] Prompts / Prompt Templates / Prompt Selectors\n",
|
||||
"- [ ] Output Parsers\n",
|
||||
"- [ ] Document Loaders\n",
|
||||
"- [ ] Vector Stores / Retrievers\n",
|
||||
"- [ ] Memory\n",
|
||||
"- [ ] Agents / Agent Executors\n",
|
||||
"- [ ] Tools / Toolkits\n",
|
||||
"- [ ] Chains\n",
|
||||
"- [ ] Callbacks/Tracing\n",
|
||||
"- [ ] Async\n",
|
||||
"\n",
|
||||
"### Reproduction\n",
|
||||
"\n",
|
||||
"```\r\n",
|
||||
"import os\r\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"...\"\r\n",
|
||||
"\r\n",
|
||||
"from langchain.chains import LLMChain\r\n",
|
||||
"from langchain.chat_models import ChatOpenAI\r\n",
|
||||
"from langchain.prompts import PromptTemplate\r\n",
|
||||
"from langchain.prompts.chat import ChatPromptTemplate\r\n",
|
||||
"from langchain.schema import messages_from_dict\r\n",
|
||||
"\r\n",
|
||||
"role_strings = [\r\n",
|
||||
" (\"system\", \"you are a bird expert\"), \r\n",
|
||||
" (\"human\", \"which bird has a point beak?\")\r\n",
|
||||
"]\r\n",
|
||||
"prompt = ChatPromptTemplate.from_role_strings(role_strings)\r\n",
|
||||
"chain = LLMChain(llm=ChatOpenAI(), prompt=prompt)\r\n",
|
||||
"chain.run({})\r\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Expected behavior\n",
|
||||
"\n",
|
||||
"Chain should run\n",
|
||||
"{'url': 'https://github.com/hwchase17/langchain/issues/5027', 'title': \"ChatOpenAI models don't work with prompts created via ChatPromptTemplate.from_role_strings\", 'creator': 'UmerHA', 'created_at': '2023-05-20T10:39:18Z', 'comments': 1, 'state': 'open', 'labels': [], 'assignee': None, 'milestone': None, 'locked': False, 'number': 5027, 'is_pull_request': False}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(docs[0].page_content)\n",
|
||||
"print(docs[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PySpack DataFrame Loader\n",
|
||||
"\n",
|
||||
"This shows how to load data from a PySpark DataFrame"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install pyspark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pyspark.sql import SparkSession"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"spark = SparkSession.builder.getOrCreate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import PySparkDataFrameLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader.load()"
|
||||
]
|
||||
}
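For DataFrames too large to materialize at once, the loader also exposes a `lazy_load` generator (see the `pyspark_dataframe.py` source later in this diff); a brief sketch of streaming documents instead of building the full list:

```python
# Sketch: iterate documents one at a time rather than calling load().
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)
    break  # inspect only the first document
```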
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
docs/modules/indexes/document_loaders/examples/trello.ipynb (new file, 184 lines)
@@ -0,0 +1,184 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Trello\n",
|
||||
"\n",
|
||||
">[Trello](https://www.atlassian.com/software/trello) is a web-based project management and collaboration tool that allows individuals and teams to organize and track their tasks and projects. It provides a visual interface known as a \"board\" where users can create lists and cards to represent their tasks and activities.\n",
|
||||
"\n",
|
||||
"The TrelloLoader allows you to load cards from a Trello board and is implemented on top of [py-trello](https://pypi.org/project/py-trello/)\n",
|
||||
"\n",
|
||||
"This currently supports `api_key/token` only.\n",
|
||||
"\n",
|
||||
"1. Credentials generation: https://trello.com/power-ups/admin/\n",
|
||||
"\n",
|
||||
"2. Click in the manual token generation link to get the token.\n",
|
||||
"\n",
|
||||
"To specify the API key and token you can either set the environment variables ``TRELLO_API_KEY`` and ``TRELLO_TOKEN`` or you can pass ``api_key`` and ``token`` directly into the `from_credentials` convenience constructor method.\n",
|
||||
"\n",
|
||||
"This loader allows you to provide the board name to pull in the corresponding cards into Document objects.\n",
|
||||
"\n",
|
||||
"Notice that the board \"name\" is also called \"title\" in oficial documentation:\n",
|
||||
"\n",
|
||||
"https://support.atlassian.com/trello/docs/changing-a-boards-title-and-description/\n",
|
||||
"\n",
|
||||
"You can also specify several load parameters to include / remove different fields both from the document page_content properties and metadata.\n",
|
||||
"\n",
|
||||
"## Features\n",
|
||||
"- Load cards from a Trello board.\n",
|
||||
"- Filter cards based on their status (open or closed).\n",
|
||||
"- Include card names, comments, and checklists in the loaded documents.\n",
|
||||
"- Customize the additional metadata fields to include in the document.\n",
|
||||
"\n",
|
||||
"By default all card fields are included for the full text page_content and metadata accordinly.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install py-trello beautifulsoup4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"········\n",
|
||||
"········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# If you have already set the API key and token using environment variables,\n",
|
||||
"# you can skip this cell and comment out the `api_key` and `token` named arguments\n",
|
||||
"# in the initialization steps below.\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"API_KEY = getpass()\n",
|
||||
"TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Review Tech partner pages\n",
|
||||
"Comments:\n",
|
||||
"{'title': 'Review Tech partner pages', 'id': '6475357890dc8d17f73f2dcc', 'url': 'https://trello.com/c/b0OTZwkZ/1-review-tech-partner-pages', 'labels': ['Demand Marketing'], 'list': 'Done', 'closed': False, 'due_date': ''}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TrelloLoader\n",
|
||||
"\n",
|
||||
"# Get the open cards from \"Awesome Board\"\n",
|
||||
"loader = TrelloLoader.from_credentials(\n",
|
||||
" \"Awesome Board\",\n",
|
||||
" api_key=API_KEY,\n",
|
||||
" token=TOKEN,\n",
|
||||
" card_filter=\"open\",\n",
|
||||
" )\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"print(documents[0].page_content)\n",
|
||||
"print(documents[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Review Tech partner pages\n",
|
||||
"Comments:\n",
|
||||
"{'title': 'Review Tech partner pages', 'id': '6475357890dc8d17f73f2dcc', 'url': 'https://trello.com/c/b0OTZwkZ/1-review-tech-partner-pages', 'list': 'Done'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get all the cards from \"Awesome Board\" but only include the\n",
|
||||
"# card list(column) as extra metadata.\n",
|
||||
"loader = TrelloLoader.from_credentials(\n",
|
||||
" \"Awesome Board\",\n",
|
||||
" api_key=API_KEY,\n",
|
||||
" token=TOKEN,\n",
|
||||
" extra_metadata=(\"list\"),\n",
|
||||
")\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"print(documents[0].page_content)\n",
|
||||
"print(documents[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Get the cards from \"Another Board\" and exclude the card name,\n",
|
||||
"# checklist and comments from the Document page_content text.\n",
|
||||
"loader = TrelloLoader.from_credentials(\n",
|
||||
" \"test\",\n",
|
||||
" api_key=API_KEY,\n",
|
||||
" token=TOKEN,\n",
|
||||
" include_card_name= False,\n",
|
||||
" include_checklist= False,\n",
|
||||
" include_comments= False,\n",
|
||||
")\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"print(\"Document: \" + documents[0].page_content)\n",
|
||||
"print(documents[0].metadata)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
docs/modules/indexes/text_splitters/examples/code_splitter.ipynb (new file, 158 lines)
@@ -0,0 +1,158 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# CodeTextSplitter\n",
|
||||
"\n",
|
||||
"CodeTextSplitter allows you to split your code with multiple language support. Import enum `Language` and specify the language. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import (\n",
|
||||
" CodeTextSplitter,\n",
|
||||
" Language,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Choose a language to use"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"python_splitter = CodeTextSplitter(\n",
|
||||
" language=Language.PYTHON, chunk_size=16, chunk_overlap=0\n",
|
||||
")\n",
|
||||
"js_splitter = CodeTextSplitter(\n",
|
||||
" language=Language.JS, chunk_size=16, chunk_overlap=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Split the code"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='def', metadata={}),\n",
|
||||
" Document(page_content='hello_world():', metadata={}),\n",
|
||||
" Document(page_content='print(\"Hello,', metadata={}),\n",
|
||||
" Document(page_content='World!\")', metadata={}),\n",
|
||||
" Document(page_content='# Call the', metadata={}),\n",
|
||||
" Document(page_content='function', metadata={}),\n",
|
||||
" Document(page_content='hello_world()', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"PYTHON_CODE = \"\"\"\n",
|
||||
"def hello_world():\n",
|
||||
" print(\"Hello, World!\")\n",
|
||||
"\n",
|
||||
"# Call the function\n",
|
||||
"hello_world()\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"python_docs = python_splitter.create_documents([PYTHON_CODE])\n",
|
||||
"python_docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='function', metadata={}),\n",
|
||||
" Document(page_content='helloWorld() {', metadata={}),\n",
|
||||
" Document(page_content='console.log(\"He', metadata={}),\n",
|
||||
" Document(page_content='llo,', metadata={}),\n",
|
||||
" Document(page_content='World!\");', metadata={}),\n",
|
||||
" Document(page_content='}', metadata={}),\n",
|
||||
" Document(page_content='// Call the', metadata={}),\n",
|
||||
" Document(page_content='function', metadata={}),\n",
|
||||
" Document(page_content='helloWorld();', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"JS_CODE = \"\"\"\n",
|
||||
"function helloWorld() {\n",
|
||||
" console.log(\"Hello, World!\");\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"// Call the function\n",
|
||||
"helloWorld();\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"js_docs = js_splitter.create_documents([JS_CODE])\n",
|
||||
"js_docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "langchain",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -42,17 +42,17 @@
|
||||
" \n",
|
||||
"def foo():\n",
|
||||
"\n",
|
||||
"def testing_func():\n",
|
||||
"def testing_func_with_long_name():\n",
|
||||
"\n",
|
||||
"def bar():\n",
|
||||
"\"\"\"\n",
|
||||
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)"
|
||||
"python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6cdc55f3",
|
||||
"id": "8cc33770",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -62,15 +62,16 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "8cc33770",
|
||||
"id": "f5f70775",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n",
|
||||
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n",
|
||||
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]"
|
||||
"[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
|
||||
" Document(page_content='def foo():', metadata={}),\n",
|
||||
" Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
|
||||
" Document(page_content='def bar():', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
@@ -82,33 +83,10 @@
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "de625e08-c440-489d-beed-020b6c53bf69",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"python_splitter.split_text(python_text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed",
|
||||
"id": "6e096d42",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -130,7 +108,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.9.1"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "683953b3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MongoDB Atlas Vector Search\n",
|
||||
"\n",
|
||||
">[MongoDB Atlas](https://www.mongodb.com/docs/atlas/) is a document database managed in the cloud. It also enables Lucene and its vector search feature.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use the functionality related to the `MongoDB Atlas Vector Search` feature where you can store your embeddings in MongoDB documents and create a Lucene vector index to perform a KNN search.\n",
|
||||
"\n",
|
||||
"It uses the [knnBeta Operator](https://www.mongodb.com/docs/atlas/atlas-search/knn-beta) available in MongoDB Atlas Search. This feature is in early access and available only for evaluation purposes, to validate functionality, and to gather feedback from a small closed group of early access users. It is not recommended for production deployments as we may introduce breaking changes.\n",
|
||||
"\n",
|
||||
"To use MongoDB Atlas, you must have first deployed a cluster. Free clusters are available. \n",
|
||||
"Here is the MongoDB Atlas [quick start](https://www.mongodb.com/docs/atlas/getting-started/)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b4c41cad-08ef-4f72-a545-2151e4598efe",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install pymongo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"MONGODB_ATLAS_URI = os.environ['MONGODB_ATLAS_URI']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "320af802-9271-46ee-948f-d2453933d44b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key. Make sure the environment variable `OPENAI_API_KEY` is set up before proceeding."
|
||||
]
|
||||
},
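One way to set it, assuming you want to enter the key interactively (any mechanism that sets the environment variable works):

```python
import os
from getpass import getpass

# Prompt for the key and export it so OpenAIEmbeddings can pick it up.
os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")
```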
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f3ecc42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, let's create a Lucene vector index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Search index.\n",
|
||||
"You can name the index `langchain_demo` and create the index on the namespace `lanchain_db.langchain_col`. Finally, write the following definition in the JSON editor:\n",
|
||||
"\n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"mappings\": {\n",
|
||||
" \"dynamic\": true,\n",
|
||||
" \"fields\": {\n",
|
||||
" \"embedding\": {\n",
|
||||
" \"dimensions\": 1536,\n",
|
||||
" \"similarity\": \"cosine\",\n",
|
||||
" \"type\": \"knnVector\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "aac9563e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
||||
"from langchain.document_loaders import TextLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"loader = TextLoader('../../../state_of_the_union.txt')\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6e104aee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pymongo import MongoClient\n",
|
||||
"\n",
|
||||
"# initialize MongoDB python client\n",
|
||||
"client = MongoClient(MONGODB_ATLAS_CONNECTION_STRING)\n",
|
||||
"\n",
|
||||
"db_name = \"lanchain_db\"\n",
|
||||
"collection_name = \"langchain_col\"\n",
|
||||
"namespace = f\"{db_name}.{collection_name}\"\n",
|
||||
"index_name = \"langchain_demo\"\n",
|
||||
"\n",
|
||||
"# insert the documents in MongoDB Atlas with their embedding\n",
|
||||
"docsearch = MongoDBAtlasVectorSearch.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" client=client,\n",
|
||||
" namespace=namespace,\n",
|
||||
" index_name=index_name\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# perform a similarity search between the embedding of the query and the embeddings of the documents\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = docsearch.similarity_search(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9c608226",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -941,7 +941,7 @@ class AgentExecutor(Chain):
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
|
||||
)
|
||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||
# Let's start tracking the number of iterations and time elapsed
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.langchain import get_endpoint, get_headers
|
||||
from langchain.callbacks.tracers.langchain import get_headers
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
@@ -20,6 +21,10 @@ from langchain.schema import get_buffer_string
|
||||
from langchain.utils import raise_for_status_with_text
|
||||
|
||||
|
||||
def _get_endpoint() -> str:
|
||||
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
|
||||
|
||||
class LangChainTracerV1(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
@@ -27,7 +32,7 @@ class LangChainTracerV1(BaseTracer):
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self.session: Optional[TracerSessionV1] = None
|
||||
self._endpoint = get_endpoint()
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = get_headers()
|
||||
|
||||
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
|
||||
@@ -195,6 +195,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
|
||||
chain_type: str = "stuff",
|
||||
verbose: bool = False,
|
||||
condense_question_llm: Optional[BaseLanguageModel] = None,
|
||||
combine_docs_chain_kwargs: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseConversationalRetrievalChain:
|
||||
@@ -206,8 +207,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
verbose=verbose,
|
||||
**combine_docs_chain_kwargs,
|
||||
)
|
||||
|
||||
_llm = condense_question_llm or llm
|
||||
condense_question_chain = LLMChain(
|
||||
llm=llm, prompt=condense_question_prompt, verbose=verbose
|
||||
llm=_llm, prompt=condense_question_prompt, verbose=verbose
|
||||
)
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
|
||||
@@ -37,6 +37,7 @@ from langchain.document_loaders.gcs_directory import GCSDirectoryLoader
|
||||
from langchain.document_loaders.gcs_file import GCSFileLoader
|
||||
from langchain.document_loaders.git import GitLoader
|
||||
from langchain.document_loaders.gitbook import GitbookLoader
|
||||
from langchain.document_loaders.github import GitHubIssuesLoader
|
||||
from langchain.document_loaders.googledrive import GoogleDriveLoader
|
||||
from langchain.document_loaders.gutenberg import GutenbergLoader
|
||||
from langchain.document_loaders.hn import HNLoader
|
||||
@@ -73,6 +74,7 @@ from langchain.document_loaders.pdf import (
|
||||
)
|
||||
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
|
||||
from langchain.document_loaders.psychic import PsychicLoader
|
||||
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
||||
from langchain.document_loaders.python import PythonLoader
|
||||
from langchain.document_loaders.readthedocs import ReadTheDocsLoader
|
||||
from langchain.document_loaders.reddit import RedditPostsLoader
|
||||
@@ -92,6 +94,7 @@ from langchain.document_loaders.telegram import (
|
||||
from langchain.document_loaders.text import TextLoader
|
||||
from langchain.document_loaders.tomarkdown import ToMarkdownLoader
|
||||
from langchain.document_loaders.toml import TomlLoader
|
||||
from langchain.document_loaders.trello import TrelloLoader
|
||||
from langchain.document_loaders.twitter import TwitterTweetLoader
|
||||
from langchain.document_loaders.unstructured import (
|
||||
UnstructuredAPIFileIOLoader,
|
||||
@@ -152,6 +155,7 @@ __all__ = [
|
||||
"GCSDirectoryLoader",
|
||||
"GCSFileLoader",
|
||||
"GitLoader",
|
||||
"GitHubIssuesLoader",
|
||||
"GitbookLoader",
|
||||
"GoogleApiClient",
|
||||
"GoogleApiYoutubeLoader",
|
||||
@@ -185,6 +189,7 @@ __all__ = [
|
||||
"PyPDFDirectoryLoader",
|
||||
"PyPDFLoader",
|
||||
"PyPDFium2Loader",
|
||||
"PySparkDataFrameLoader",
|
||||
"PythonLoader",
|
||||
"ReadTheDocsLoader",
|
||||
"RedditPostsLoader",
|
||||
@@ -201,6 +206,7 @@ __all__ = [
|
||||
"StripeLoader",
|
||||
"TextLoader",
|
||||
"TomlLoader",
|
||||
"TrelloLoader",
|
||||
"TwitterTweetLoader",
|
||||
"UnstructuredAPIFileIOLoader",
|
||||
"UnstructuredAPIFileLoader",
|
||||
|
||||
langchain/document_loaders/github.py (new file, 182 lines)
@@ -0,0 +1,182 @@
|
||||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class BaseGitHubLoader(BaseLoader, BaseModel, ABC):
|
||||
"""Load issues of a GitHub repository."""
|
||||
|
||||
repo: str
|
||||
"""Name of repository"""
|
||||
access_token: str
|
||||
"""Personal access token - see https://github.com/settings/tokens?type=beta"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that access token exists in environment."""
|
||||
values["access_token"] = get_from_dict_or_env(
|
||||
values, "access_token", "GITHUB_PERSONAL_ACCESS_TOKEN"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
|
||||
|
||||
class GitHubIssuesLoader(BaseGitHubLoader):
|
||||
include_prs: bool = True
|
||||
"""If True include Pull Requests in results, otherwise ignore them."""
|
||||
milestone: Union[int, Literal["*", "none"], None] = None
|
||||
"""If integer is passed, it should be a milestone's number field.
|
||||
If the string '*' is passed, issues with any milestone are accepted.
|
||||
If the string 'none' is passed, issues without milestones are returned.
|
||||
"""
|
||||
state: Optional[Literal["open", "closed", "all"]] = None
|
||||
"""Filter on issue state. Can be one of: 'open', 'closed', 'all'."""
|
||||
assignee: Optional[str] = None
|
||||
"""Filter on assigned user. Pass 'none' for no user and '*' for any user."""
|
||||
creator: Optional[str] = None
|
||||
"""Filter on the user that created the issue."""
|
||||
mentioned: Optional[str] = None
|
||||
"""Filter on a user that's mentioned in the issue."""
|
||||
labels: Optional[List[str]] = None
|
||||
"""Label names to filter one. Example: bug,ui,@high."""
|
||||
sort: Optional[Literal["created", "updated", "comments"]] = None
|
||||
"""What to sort results by. Can be one of: 'created', 'updated', 'comments'.
|
||||
Default is 'created'."""
|
||||
direction: Optional[Literal["asc", "desc"]] = None
|
||||
"""The direction to sort the results by. Can be one of: 'asc', 'desc'."""
|
||||
since: Optional[str] = None
|
||||
"""Only show notifications updated after the given time.
|
||||
This is a timestamp in ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ."""
|
||||
|
||||
@validator("since")
|
||||
def validate_since(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Invalid value for 'since'. Expected a date string in "
|
||||
f"YYYY-MM-DDTHH:MM:SSZ format. Received: {v}"
|
||||
)
|
||||
return v
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""
|
||||
Get issues of a GitHub repository.
|
||||
|
||||
Returns:
|
||||
A list of Documents with attributes:
|
||||
- page_content
|
||||
- metadata
|
||||
- url
|
||||
- title
|
||||
- creator
|
||||
- created_at
|
||||
- last_update_time
|
||||
- closed_time
|
||||
- number of comments
|
||||
- state
|
||||
- labels
|
||||
- assignee
|
||||
- assignees
|
||||
- milestone
|
||||
- locked
|
||||
- number
|
||||
- is_pull_request
|
||||
"""
|
||||
url: Optional[str] = self.url
|
||||
while url:
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
issues = response.json()
|
||||
for issue in issues:
|
||||
doc = self.parse_issue(issue)
|
||||
if not self.include_prs and doc.metadata["is_pull_request"]:
|
||||
continue
|
||||
yield doc
|
||||
if response.links and response.links.get("next"):
|
||||
url = response.links["next"]["url"]
|
||||
else:
|
||||
url = None
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Get issues of a GitHub repository.
|
||||
|
||||
Returns:
|
||||
A list of Documents with attributes:
|
||||
- page_content
|
||||
- metadata
|
||||
- url
|
||||
- title
|
||||
- creator
|
||||
- created_at
|
||||
- last_update_time
|
||||
- closed_time
|
||||
- number of comments
|
||||
- state
|
||||
- labels
|
||||
- assignee
|
||||
- assignees
|
||||
- milestone
|
||||
- locked
|
||||
- number
|
||||
- is_pull_request
|
||||
"""
|
||||
return list(self.lazy_load())
|
||||
|
||||
def parse_issue(self, issue: dict) -> Document:
|
||||
"""Create Document objects from a list of GitHub issues."""
|
||||
metadata = {
|
||||
"url": issue["html_url"],
|
||||
"title": issue["title"],
|
||||
"creator": issue["user"]["login"],
|
||||
"created_at": issue["created_at"],
|
||||
"comments": issue["comments"],
|
||||
"state": issue["state"],
|
||||
"labels": [label["name"] for label in issue["labels"]],
|
||||
"assignee": issue["assignee"]["login"] if issue["assignee"] else None,
|
||||
"milestone": issue["milestone"]["title"] if issue["milestone"] else None,
|
||||
"locked": issue["locked"],
|
||||
"number": issue["number"],
|
||||
"is_pull_request": "pull_request" in issue,
|
||||
}
|
||||
content = issue["body"] if issue["body"] is not None else ""
|
||||
return Document(page_content=content, metadata=metadata)
|
||||
|
||||
@property
|
||||
def query_params(self) -> str:
|
||||
labels = ",".join(self.labels) if self.labels else self.labels
|
||||
query_params_dict = {
|
||||
"milestone": self.milestone,
|
||||
"state": self.state,
|
||||
"assignee": self.assignee,
|
||||
"creator": self.creator,
|
||||
"mentioned": self.mentioned,
|
||||
"labels": labels,
|
||||
"sort": self.sort,
|
||||
"direction": self.direction,
|
||||
"since": self.since,
|
||||
}
|
||||
query_params_list = [
|
||||
f"{k}={v}" for k, v in query_params_dict.items() if v is not None
|
||||
]
|
||||
query_params = "&".join(query_params_list)
|
||||
return query_params
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"https://api.github.com/repos/{self.repo}/issues?{self.query_params}"
|
||||
langchain/document_loaders/pyspark_dataframe.py (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
"""Load from a Spark Dataframe object"""
|
||||
import itertools
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
|
||||
class PySparkDataFrameLoader(BaseLoader):
|
||||
"""Load PySpark DataFrames"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spark_session: Optional["SparkSession"] = None,
|
||||
df: Optional[Any] = None,
|
||||
page_content_column: str = "text",
|
||||
fraction_of_memory: float = 0.1,
|
||||
):
|
||||
"""Initialize with a Spark DataFrame object."""
|
||||
try:
|
||||
from pyspark.sql import DataFrame, SparkSession
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"pyspark is not installed. "
|
||||
"Please install it with `pip install pyspark`"
|
||||
)
|
||||
|
||||
self.spark = (
|
||||
spark_session if spark_session else SparkSession.builder.getOrCreate()
|
||||
)
|
||||
|
||||
if not isinstance(df, DataFrame):
|
||||
raise ValueError(
|
||||
f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
|
||||
)
|
||||
self.df = df
|
||||
self.page_content_column = page_content_column
|
||||
self.fraction_of_memory = fraction_of_memory
|
||||
self.num_rows, self.max_num_rows = self.get_num_rows()
|
||||
self.rdd_df = self.df.rdd.map(list)
|
||||
self.column_names = self.df.columns
|
||||
|
||||
def get_num_rows(self) -> Tuple[int, int]:
|
||||
"""Gets the amount of "feasible" rows for the DataFrame"""
|
||||
row = self.df.limit(1).collect()[0]
|
||||
estimated_row_size = sys.getsizeof(row)
|
||||
mem_info = psutil.virtual_memory()
|
||||
available_memory = mem_info.available
|
||||
max_num_rows = int(
|
||||
(available_memory / estimated_row_size) * self.fraction_of_memory
|
||||
)
|
||||
return min(max_num_rows, self.df.count()), max_num_rows
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""A lazy loader for document content."""
|
||||
for row in self.rdd_df.toLocalIterator():
|
||||
metadata = {self.column_names[i]: row[i] for i in range(len(row))}
|
||||
text = metadata[self.page_content_column]
|
||||
metadata.pop(self.page_content_column)
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from the dataframe."""
|
||||
if self.df.count() > self.max_num_rows:
|
||||
logger.warning(
|
||||
f"The number of DataFrame rows is {self.df.count()}, "
|
||||
f"but we will only include the amount "
|
||||
f"of rows that can reasonably fit in memory: {self.num_rows}."
|
||||
)
|
||||
lazy_load_iterator = self.lazy_load()
|
||||
return list(itertools.islice(lazy_load_iterator, self.num_rows))
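To make the memory heuristic in `get_num_rows` concrete, assume roughly 8 GB of available memory, an estimated row size of about 1 KB, and the default `fraction_of_memory=0.1` (illustrative numbers only):

```python
# Sketch of the row-cap arithmetic used by get_num_rows.
available_memory = 8 * 1024**3   # ~8 GB available, as reported by psutil
estimated_row_size = 1024        # ~1 KB per collected Row
fraction_of_memory = 0.1
max_num_rows = int((available_memory / estimated_row_size) * fraction_of_memory)
print(max_num_rows)  # 838860; the loader then keeps min(max_num_rows, df.count())
```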
|
||||
langchain/document_loaders/trello.py (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
"""Loader that loads cards from Trello"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trello import Board, Card, TrelloClient
|
||||
|
||||
|
||||
class TrelloLoader(BaseLoader):
|
||||
"""Trello loader. Reads all cards from a Trello board."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: TrelloClient,
|
||||
board_name: str,
|
||||
*,
|
||||
include_card_name: bool = True,
|
||||
include_comments: bool = True,
|
||||
include_checklist: bool = True,
|
||||
card_filter: Literal["closed", "open", "all"] = "all",
|
||||
extra_metadata: Tuple[str, ...] = ("due_date", "labels", "list", "closed"),
|
||||
):
|
||||
"""Initialize Trello loader.
|
||||
|
||||
Args:
|
||||
client: Trello API client.
|
||||
board_name: The name of the Trello board.
|
||||
include_card_name: Whether to include the name of the card in the document.
|
||||
include_comments: Whether to include the comments on the card in the
|
||||
document.
|
||||
include_checklist: Whether to include the checklist on the card in the
|
||||
document.
|
||||
card_filter: Filter on card status. Valid values are "closed", "open",
|
||||
"all".
|
||||
extra_metadata: List of additional metadata fields to include as document
|
||||
metadata. Valid values are "due_date", "labels", "list", "closed".
|
||||
|
||||
"""
|
||||
self.client = client
|
||||
self.board_name = board_name
|
||||
self.include_card_name = include_card_name
|
||||
self.include_comments = include_comments
|
||||
self.include_checklist = include_checklist
|
||||
self.extra_metadata = extra_metadata
|
||||
self.card_filter = card_filter
|
||||
|
||||
@classmethod
|
||||
def from_credentials(
|
||||
cls,
|
||||
board_name: str,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> TrelloLoader:
|
||||
"""Convenience constructor that builds TrelloClient init param for you.
|
||||
|
||||
Args:
|
||||
board_name: The name of the Trello board.
|
||||
api_key: Trello API key. Can also be specified as environment variable
|
||||
TRELLO_API_KEY.
|
||||
token: Trello token. Can also be specified as environment variable
|
||||
TRELLO_TOKEN.
|
||||
include_card_name: Whether to include the name of the card in the document.
|
||||
include_comments: Whether to include the comments on the card in the
|
||||
document.
|
||||
include_checklist: Whether to include the checklist on the card in the
|
||||
document.
|
||||
card_filter: Filter on card status. Valid values are "closed", "open",
|
||||
"all".
|
||||
extra_metadata: List of additional metadata fields to include as document
|
||||
metadata. Valid values are "due_date", "labels", "list", "closed".
|
||||
"""
|
||||
|
||||
try:
|
||||
from trello import TrelloClient # type: ignore
|
||||
except ImportError as ex:
|
||||
raise ImportError(
|
||||
"Could not import trello python package. "
|
||||
"Please install it with `pip install py-trello`."
|
||||
) from ex
|
||||
api_key = api_key or get_from_env("api_key", "TRELLO_API_KEY")
|
||||
token = token or get_from_env("token", "TRELLO_TOKEN")
|
||||
client = TrelloClient(api_key=api_key, token=token)
|
||||
return cls(client, board_name, **kwargs)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Loads all cards from the specified Trello board.
|
||||
|
||||
You can filter the cards, metadata and text included by using the optional
|
||||
parameters.
|
||||
|
||||
Returns:
|
||||
A list of documents, one for each card in the board.
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa: F401
|
||||
except ImportError as ex:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found, please run"
|
||||
" `pip install beautifulsoup4`"
|
||||
) from ex
|
||||
|
||||
board = self._get_board()
|
||||
# Create a dictionary with the list IDs as keys and the list names as values
|
||||
list_dict = {list_item.id: list_item.name for list_item in board.list_lists()}
|
||||
# Get Cards on the board
|
||||
cards = board.get_cards(card_filter=self.card_filter)
|
||||
return [self._card_to_doc(card, list_dict) for card in cards]
|
||||
|
||||
def _get_board(self) -> Board:
|
||||
# Find the first board with a matching name
|
||||
board = next(
|
||||
(b for b in self.client.list_boards() if b.name == self.board_name), None
|
||||
)
|
||||
if not board:
|
||||
raise ValueError(f"Board `{self.board_name}` not found.")
|
||||
return board
|
||||
|
||||
def _card_to_doc(self, card: Card, list_dict: dict) -> Document:
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
|
||||
text_content = ""
|
||||
if self.include_card_name:
|
||||
text_content = card.name + "\n"
|
||||
if card.description.strip():
|
||||
text_content += BeautifulSoup(card.description, "lxml").get_text()
|
||||
if self.include_checklist:
|
||||
# Get all the checklist items on the card
|
||||
for checklist in card.checklists:
|
||||
if checklist.items:
|
||||
items = [
|
||||
f"{item['name']}:{item['state']}" for item in checklist.items
|
||||
]
|
||||
text_content += f"\n{checklist.name}\n" + "\n".join(items)
|
||||
|
||||
if self.include_comments:
|
||||
# Get all the comments on the card
|
||||
comments = [
|
||||
BeautifulSoup(comment["data"]["text"], "lxml").get_text()
|
||||
for comment in card.comments
|
||||
]
|
||||
text_content += "Comments:" + "\n".join(comments)
|
||||
|
||||
# Default metadata fields
|
||||
metadata = {
|
||||
"title": card.name,
|
||||
"id": card.id,
|
||||
"url": card.url,
|
||||
}
|
||||
|
||||
# Extra metadata fields. Card object is not subscriptable.
|
||||
if "labels" in self.extra_metadata:
|
||||
metadata["labels"] = [label.name for label in card.labels]
|
||||
if "list" in self.extra_metadata:
|
||||
if card.list_id in list_dict:
|
||||
metadata["list"] = list_dict[card.list_id]
|
||||
if "closed" in self.extra_metadata:
|
||||
metadata["closed"] = card.closed
|
||||
if "due_date" in self.extra_metadata:
|
||||
metadata["due_date"] = card.due_date
|
||||
|
||||
return Document(page_content=text_content, metadata=metadata)
|
||||
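A minimal usage sketch of the new TrelloLoader (the board name and credentials are placeholders; py-trello and beautifulsoup4 must be installed, and TRELLO_API_KEY / TRELLO_TOKEN may instead be supplied via the environment):

from langchain.document_loaders.trello import TrelloLoader

loader = TrelloLoader.from_credentials(
    "My Board",                          # hypothetical board name
    api_key="YOUR_TRELLO_API_KEY",
    token="YOUR_TRELLO_TOKEN",
    card_filter="open",                  # only open cards
    extra_metadata=("list", "labels"),
)
docs = loader.load()                     # one Document per card on the board
print(docs[0].metadata)                  # title, id, url plus the extra fields requested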
@@ -3,7 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
@@ -27,6 +29,23 @@ logger = logging.getLogger(__name__)
|
||||
TS = TypeVar("TS", bound="TextSplitter")
|
||||
|
||||
|
||||
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
if keep_separator:
|
||||
# The parentheses in the pattern keep the delimiters in the result.
|
||||
_splits = re.split(f"({separator})", text)
|
||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||
if len(_splits) % 2 == 0:
|
||||
splits += _splits[-1:]
|
||||
splits = [_splits[0]] + splits
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
|
||||
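As a quick illustration of the new helper, here is a self-contained sketch (the helper is restated so the snippet runs on its own, and the expected values in the comments were traced by hand from the logic above):

import re
from typing import List


def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
    # Restated copy of the helper above.
    if separator:
        if keep_separator:
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = text.split(separator)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]


print(_split_text("a\nb\nc", "\n", keep_separator=False))  # ['a', 'b', 'c']
print(_split_text("a\nb\nc", "\n", keep_separator=True))   # ['a', '\nb', '\nc']
print(_split_text("abc", "", keep_separator=False))        # ['a', 'b', 'c']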
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
|
||||
@@ -35,8 +54,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
keep_separator: bool = False,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
"""Create a new TextSplitter.
|
||||
|
||||
Args:
|
||||
chunk_size: Maximum size of chunks to return
|
||||
chunk_overlap: Overlap in characters between chunks
|
||||
length_function: Function that measures the length of given chunks
|
||||
keep_separator: Whether or not to keep the separator in the chunks
|
||||
"""
|
||||
if chunk_overlap > chunk_size:
|
||||
raise ValueError(
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||
@@ -45,6 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
self._keep_separator = keep_separator
|
||||
|
||||
@abstractmethod
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
@@ -211,11 +239,9 @@ class CharacterTextSplitter(TextSplitter):
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
if self._separator:
|
||||
splits = text.split(self._separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
return self._merge_splits(splits, self._separator)
|
||||
splits = _split_text(text, self._separator, self._keep_separator)
|
||||
_separator = "" if self._keep_separator else self._separator
|
||||
return self._merge_splits(splits, _separator)
|
||||
|
||||
|
||||
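A hedged sketch of the new keep_separator option on CharacterTextSplitter (the text and chunk sizes are illustrative only):

from langchain.text_splitter import CharacterTextSplitter

text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."

# With keep_separator=True the "\n\n" delimiter is kept at the start of each
# following split, so merged chunks preserve the original paragraph breaks.
splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=40,
    chunk_overlap=0,
    keep_separator=True,
)
for chunk in splitter.split_text(text):
    print(repr(chunk))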
class TokenTextSplitter(TextSplitter):
|
||||
@@ -274,45 +300,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
that works.
|
||||
"""
|
||||
|
||||
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
|
||||
def __init__(
|
||||
self,
|
||||
separators: Optional[List[str]] = None,
|
||||
keep_separator: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = self._separators[-1]
|
||||
for _s in self._separators:
|
||||
separator = separators[-1]
|
||||
new_separators = None
|
||||
for i, _s in enumerate(separators):
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if _s in text:
|
||||
separator = _s
|
||||
new_separators = separators[i + 1 :]
|
||||
break
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
|
||||
splits = _split_text(text, separator, self._keep_separator)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_separator = "" if self._keep_separator else separator
|
||||
for s in splits:
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
other_info = self.split_text(s)
|
||||
final_chunks.extend(other_info)
|
||||
if new_separators is None:
|
||||
final_chunks.append(s)
|
||||
else:
|
||||
other_info = self._split_text(s, new_separators)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
return final_chunks
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
|
||||
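And a small sketch of the updated RecursiveCharacterTextSplitter behavior (the text and chunk size are made up): an oversized split is now re-split with the remaining, finer separators instead of restarting from the full list.

from langchain.text_splitter import RecursiveCharacterTextSplitter

text = (
    "A short paragraph.\n\n"
    "A much longer paragraph that will not fit into a single chunk and so "
    "must be split again on smaller separators."
)

# keep_separator defaults to True here, so separators stay attached to the
# chunks that follow them.
splitter = RecursiveCharacterTextSplitter(chunk_size=60, chunk_overlap=0)
for chunk in splitter.split_text(text):
    print(repr(chunk))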
class NLTKTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at sentences using NLTK."""
|
||||
@@ -439,3 +476,314 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
|
||||
"",
|
||||
]
|
||||
super().__init__(separators=separators, **kwargs)
|
||||
|
||||
|
||||
class Language(str, Enum):
|
||||
CPP = "cpp"
|
||||
GO = "go"
|
||||
JAVA = "java"
|
||||
JS = "js"
|
||||
PHP = "php"
|
||||
PROTO = "proto"
|
||||
PYTHON = "python"
|
||||
RST = "rst"
|
||||
RUBY = "ruby"
|
||||
RUST = "rust"
|
||||
SCALA = "scala"
|
||||
SWIFT = "swift"
|
||||
MARKDOWN = "markdown"
|
||||
LATEX = "latex"
|
||||
|
||||
|
||||
class CodeTextSplitter(RecursiveCharacterTextSplitter):
|
||||
def __init__(self, language: Language, **kwargs: Any):
|
||||
"""
|
||||
A generic code text splitter supporting many programming languages.
|
||||
Example:
|
||||
splitter = CodeTextSplitter(
|
||||
language=Language.JAVA
|
||||
)
|
||||
Args:
|
||||
            language: The programming language to use.
|
||||
"""
|
||||
separators = self._get_separators_for_language(language)
|
||||
super().__init__(separators=separators, **kwargs)
|
||||
|
||||
def _get_separators_for_language(self, language: Language) -> List[str]:
|
||||
if language == Language.CPP:
|
||||
return [
|
||||
# Split along class definitions
|
||||
"\nclass ",
|
||||
# Split along function definitions
|
||||
"\nvoid ",
|
||||
"\nint ",
|
||||
"\nfloat ",
|
||||
"\ndouble ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nwhile ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.GO:
|
||||
return [
|
||||
# Split along function definitions
|
||||
"\nfunc ",
|
||||
"\nvar ",
|
||||
"\nconst ",
|
||||
"\ntype ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.JAVA:
|
||||
return [
|
||||
# Split along class definitions
|
||||
"\nclass ",
|
||||
# Split along method definitions
|
||||
"\npublic ",
|
||||
"\nprotected ",
|
||||
"\nprivate ",
|
||||
"\nstatic ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nwhile ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.JS:
|
||||
return [
|
||||
# Split along function definitions
|
||||
"\nfunction ",
|
||||
"\nconst ",
|
||||
"\nlet ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nwhile ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
"\ndefault ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.PHP:
|
||||
return [
|
||||
# Split along function definitions
|
||||
"\nfunction ",
|
||||
# Split along class definitions
|
||||
"\nclass ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nforeach ",
|
||||
"\nwhile ",
|
||||
"\ndo ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.PROTO:
|
||||
return [
|
||||
# Split along message definitions
|
||||
"\nmessage ",
|
||||
# Split along service definitions
|
||||
"\nservice ",
|
||||
# Split along enum definitions
|
||||
"\nenum ",
|
||||
# Split along option definitions
|
||||
"\noption ",
|
||||
# Split along import statements
|
||||
"\nimport ",
|
||||
# Split along syntax declarations
|
||||
"\nsyntax ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.PYTHON:
|
||||
return [
|
||||
# First, try to split along class definitions
|
||||
"\nclass ",
|
||||
"\ndef ",
|
||||
"\n\tdef ",
|
||||
# Now split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.RST:
|
||||
return [
|
||||
# Split along section titles
|
||||
"\n===\n",
|
||||
"\n---\n",
|
||||
"\n***\n",
|
||||
# Split along directive markers
|
||||
"\n.. ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.RUBY:
|
||||
return [
|
||||
# Split along method definitions
|
||||
"\ndef ",
|
||||
"\nclass ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nunless ",
|
||||
"\nwhile ",
|
||||
"\nfor ",
|
||||
"\ndo ",
|
||||
"\nbegin ",
|
||||
"\nrescue ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.RUST:
|
||||
return [
|
||||
# Split along function definitions
|
||||
"\nfn ",
|
||||
"\nconst ",
|
||||
"\nlet ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nwhile ",
|
||||
"\nfor ",
|
||||
"\nloop ",
|
||||
"\nmatch ",
|
||||
"\nconst ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.SCALA:
|
||||
return [
|
||||
# Split along class definitions
|
||||
"\nclass ",
|
||||
"\nobject ",
|
||||
# Split along method definitions
|
||||
"\ndef ",
|
||||
"\nval ",
|
||||
"\nvar ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nwhile ",
|
||||
"\nmatch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.SWIFT:
|
||||
return [
|
||||
# Split along function definitions
|
||||
"\nfunc ",
|
||||
# Split along class definitions
|
||||
"\nclass ",
|
||||
"\nstruct ",
|
||||
"\nenum ",
|
||||
# Split along control flow statements
|
||||
"\nif ",
|
||||
"\nfor ",
|
||||
"\nwhile ",
|
||||
"\ndo ",
|
||||
"\nswitch ",
|
||||
"\ncase ",
|
||||
# Split by the normal type of lines
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.MARKDOWN:
|
||||
return [
|
||||
# First, try to split along Markdown headings (starting with level 2)
|
||||
"\n## ",
|
||||
"\n### ",
|
||||
"\n#### ",
|
||||
"\n##### ",
|
||||
"\n###### ",
|
||||
# Note the alternative syntax for headings (below) is not handled here
|
||||
# Heading level 2
|
||||
# ---------------
|
||||
# End of code block
|
||||
"```\n\n",
|
||||
# Horizontal lines
|
||||
"\n\n***\n\n",
|
||||
"\n\n---\n\n",
|
||||
"\n\n___\n\n",
|
||||
# Note that this splitter doesn't handle horizontal lines defined
|
||||
                # by *three or more* of ***, ---, or ___
|
||||
"\n\n",
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.LATEX:
|
||||
return [
|
||||
# First, try to split along Latex sections
|
||||
"\n\\chapter{",
|
||||
"\n\\section{",
|
||||
"\n\\subsection{",
|
||||
"\n\\subsubsection{",
|
||||
# Now split by environments
|
||||
"\n\\begin{enumerate}",
|
||||
"\n\\begin{itemize}",
|
||||
"\n\\begin{description}",
|
||||
"\n\\begin{list}",
|
||||
"\n\\begin{quote}",
|
||||
"\n\\begin{quotation}",
|
||||
"\n\\begin{verse}",
|
||||
"\n\\begin{verbatim}",
|
||||
                # Now split by math environments
|
||||
"\n\\begin{align}",
|
||||
"$$",
|
||||
"$",
|
||||
# Now split by the normal type of lines
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Language {language} is not supported! "
|
||||
f"Please choose from {list(Language)}"
|
||||
)
|
||||
|
||||
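A brief, hedged sketch of the new CodeTextSplitter in use (the sample source and chunk size are made up for illustration):

from langchain.text_splitter import CodeTextSplitter, Language

python_source = (
    "def add(a, b):\n"
    "    return a + b\n"
    "\n"
    "class Greeter:\n"
    "    def greet(self):\n"
    "        print('hello')\n"
)

# Splits preferentially on class/def boundaries before falling back to blank
# lines, newlines, spaces and single characters.
splitter = CodeTextSplitter(language=Language.PYTHON, chunk_size=100, chunk_overlap=0)
docs = splitter.create_documents([python_source])
for doc in docs:
    print(repr(doc.page_content))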
@@ -112,6 +112,18 @@ def create_schema_from_function(
|
||||
)
|
||||
|
||||
|
||||
class ToolException(Exception):
|
||||
"""An optional exception that tool throws when execution error occurs.
|
||||
|
||||
When this exception is thrown, the agent will not stop working,
|
||||
but will handle the exception according to the handle_tool_error
|
||||
variable of the tool, and the processing result will be returned
|
||||
to the agent as observation, and printed in red on the console.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
@@ -137,6 +149,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Deprecated. Please use callbacks instead."""
|
||||
|
||||
handle_tool_error: Optional[
|
||||
Union[bool, str, Callable[[ToolException], str]]
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@@ -250,11 +267,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
if new_arg_supported
|
||||
else self._run(*tool_args, **tool_kwargs)
|
||||
)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
run_manager.on_tool_end(str(observation), color=color, name=self.name, **kwargs)
|
||||
return observation
|
||||
else:
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
@@ -289,13 +331,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
if new_arg_supported
|
||||
else await self._arun(*tool_args, **tool_kwargs)
|
||||
)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
await run_manager.on_tool_end(
|
||||
str(observation), color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
await run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
else:
|
||||
await run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||
"""Make tool callable."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.vectorstores.lancedb import LanceDB
|
||||
from langchain.vectorstores.milvus import Milvus
|
||||
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
|
||||
from langchain.vectorstores.myscale import MyScale, MyScaleSettings
|
||||
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
|
||||
from langchain.vectorstores.pinecone import Pinecone
|
||||
@@ -38,6 +39,7 @@ __all__ = [
|
||||
"AtlasDB",
|
||||
"DeepLake",
|
||||
"Annoy",
|
||||
"MongoDBAtlasVectorSearch",
|
||||
"MyScale",
|
||||
"MyScaleSettings",
|
||||
"SKLearnVectorStore",
|
||||
|
||||
langchain/vectorstores/mongodb_atlas.py (new file, 270 additions)
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo import MongoClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_INSERT_BATCH_SIZE = 100
|
||||
|
||||
|
||||
class MongoDBAtlasVectorSearch(VectorStore):
|
||||
"""Wrapper around MongoDB Atlas Vector Search.
|
||||
|
||||
To use, you should have both:
|
||||
- the ``pymongo`` python package installed
|
||||
    - a connection string associated with a MongoDB Atlas cluster that has an
      Atlas Search index deployed
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from pymongo import MongoClient
|
||||
|
||||
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
namespace = "<db_name>.<collection_name>"
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: MongoClient,
|
||||
namespace: str,
|
||||
embedding: Embeddings,
|
||||
*,
|
||||
index_name: str = "default",
|
||||
text_key: str = "text",
|
||||
embedding_key: str = "embedding",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
client: MongoDB client.
|
||||
namespace: MongoDB namespace to add the texts to.
|
||||
embedding: Text embedding model to use.
|
||||
text_key: MongoDB field that will contain the text for each
|
||||
document.
|
||||
embedding_key: MongoDB field that will contain the embedding for
|
||||
each document.
|
||||
"""
|
||||
self._client = client
|
||||
db_name, collection_name = namespace.split(".")
|
||||
self._collection = client[db_name][collection_name]
|
||||
self._embedding = embedding
|
||||
self._index_name = index_name
|
||||
self._text_key = text_key
|
||||
self._embedding_key = embedding_key
|
||||
|
||||
@classmethod
|
||||
def from_connection_string(
|
||||
cls,
|
||||
connection_string: str,
|
||||
namespace: str,
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> MongoDBAtlasVectorSearch:
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import pymongo, please install it with "
|
||||
"`pip install pymongo`."
|
||||
)
|
||||
client: MongoClient = MongoClient(connection_string)
|
||||
return cls(client, namespace, embedding, **kwargs)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
|
||||
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
|
||||
texts_batch = []
|
||||
metadatas_batch = []
|
||||
result_ids = []
|
||||
for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
|
||||
texts_batch.append(text)
|
||||
metadatas_batch.append(metadata)
|
||||
if (i + 1) % batch_size == 0:
|
||||
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
|
||||
texts_batch = []
|
||||
metadatas_batch = []
|
||||
if texts_batch:
|
||||
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
|
||||
return result_ids
|
||||
|
||||
def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
|
||||
if not texts:
|
||||
return []
|
||||
# Embed and create the documents
|
||||
embeddings = self._embedding.embed_documents(texts)
|
||||
to_insert = [
|
||||
{self._text_key: t, self._embedding_key: embedding, **m}
|
||||
for t, m, embedding in zip(texts, metadatas, embeddings)
|
||||
]
|
||||
# insert the documents in MongoDB Atlas
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
return insert_result.inserted_ids
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
pre_filter: Optional[dict] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return MongoDB documents most similar to query, along with scores.
|
||||
|
||||
        Use the knnBeta Operator available in MongoDB Atlas Search.
|
||||
This feature is in early access and available only for evaluation purposes, to
|
||||
validate functionality, and to gather feedback from a small closed group of
|
||||
early access users. It is not recommended for production deployments as we
|
||||
may introduce breaking changes.
|
||||
For more: https://www.mongodb.com/docs/atlas/atlas-search/knn-beta
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Optional Number of Documents to return. Defaults to 4.
|
||||
pre_filter: Optional Dictionary of argument(s) to prefilter on document
|
||||
fields.
|
||||
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
|
||||
following the knnBeta search.
|
||||
|
||||
Returns:
|
||||
            List of Documents most similar to the query, with a score for each.
|
||||
"""
|
||||
knn_beta = {
|
||||
"vector": self._embedding.embed_query(query),
|
||||
"path": self._embedding_key,
|
||||
"k": k,
|
||||
}
|
||||
if pre_filter:
|
||||
knn_beta["filter"] = pre_filter
|
||||
pipeline = [
|
||||
{
|
||||
"$search": {
|
||||
"index": self._index_name,
|
||||
"knnBeta": knn_beta,
|
||||
}
|
||||
},
|
||||
{"$project": {"score": {"$meta": "searchScore"}, self._embedding_key: 0}},
|
||||
]
|
||||
if post_filter_pipeline is not None:
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
cursor = self._collection.aggregate(pipeline)
|
||||
docs = []
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
score = res.pop("score")
|
||||
docs.append((Document(page_content=text, metadata=res), score))
|
||||
return docs
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
pre_filter: Optional[dict] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return MongoDB documents most similar to query.
|
||||
|
||||
        Use the knnBeta Operator available in MongoDB Atlas Search.
|
||||
This feature is in early access and available only for evaluation purposes, to
|
||||
validate functionality, and to gather feedback from a small closed group of
|
||||
early access users. It is not recommended for production deployments as we may
|
||||
introduce breaking changes.
|
||||
For more: https://www.mongodb.com/docs/atlas/atlas-search/knn-beta
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Optional Number of Documents to return. Defaults to 4.
|
||||
pre_filter: Optional Dictionary of argument(s) to prefilter on document
|
||||
fields.
|
||||
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
|
||||
following the knnBeta search.
|
||||
|
||||
Returns:
|
||||
            List of Documents most similar to the query.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query,
|
||||
k=k,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
client: Optional[MongoClient] = None,
|
||||
namespace: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> MongoDBAtlasVectorSearch:
|
||||
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
||||
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Adds the documents to a provided MongoDB Atlas Vector Search index
|
||||
(Lucene)
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from pymongo import MongoClient
|
||||
|
||||
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
namespace = "<db_name>.<collection_name>"
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
metadatas=metadatas,
|
||||
client=client,
|
||||
namespace=namespace
|
||||
)
|
||||
"""
|
||||
if not client or not namespace:
|
||||
raise ValueError("Must provide 'client' and 'namespace' named parameters.")
|
||||
vecstore = cls(client, namespace, embedding, **kwargs)
|
||||
vecstore.add_texts(texts, metadatas=metadatas)
|
||||
return vecstore
|
||||
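A hedged sketch of querying the new MongoDB Atlas vector store with a prefilter and a post-filter pipeline (the connection string, namespace and filter values are placeholders for your own Atlas deployment and data):

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
    "mongodb+srv://user:pass@cluster.example.mongodb.net",  # placeholder URI
    "my_db.my_collection",                                  # <db_name>.<collection_name>
    OpenAIEmbeddings(),
    index_name="default",
)
docs = vectorstore.similarity_search(
    "What is a sandwich?",
    k=2,
    pre_filter={"range": {"lte": 1, "path": "year"}},       # knnBeta filter clause (placeholder field)
    post_filter_pipeline=[{"$limit": 1}],                   # extra aggregation stages after the search
)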
poetry.lock (generated, 58 changes)
@@ -6626,6 +6626,35 @@ files = [
|
||||
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py-trello"
|
||||
version = "0.19.0"
|
||||
description = "Python wrapper around the Trello API"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "py-trello-0.19.0.tar.gz", hash = "sha256:f4a8c05db61fad0ef5fa35d62c29806c75d9d2b797358d9cf77275e2cbf23020"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
python-dateutil = "*"
|
||||
pytz = "*"
|
||||
requests = "*"
|
||||
requests-oauthlib = ">=0.4.1"
|
||||
|
||||
[[package]]
|
||||
name = "py4j"
|
||||
version = "0.10.9.7"
|
||||
description = "Enables Python programs to dynamically access arbitrary Java objects"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
|
||||
{file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyaes"
|
||||
version = "1.6.1"
|
||||
@@ -6936,7 +6965,7 @@ tests = ["duckdb", "polars[pandas,pyarrow]", "pytest"]
|
||||
name = "pymongo"
|
||||
version = "4.3.3"
|
||||
description = "Python driver for MongoDB <http://www.mongodb.org>"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -7212,6 +7241,27 @@ files = [
|
||||
{file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyspark"
|
||||
version = "3.4.0"
|
||||
description = "Apache Spark Python API"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pyspark-3.4.0.tar.gz", hash = "sha256:167a23e11854adb37f8602de6fcc3a4f96fd5f1e323b9bb83325f38408c5aafd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
py4j = "0.10.9.7"
|
||||
|
||||
[package.extras]
|
||||
connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
|
||||
ml = ["numpy (>=1.15)"]
|
||||
mllib = ["numpy (>=1.15)"]
|
||||
pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
|
||||
sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytesseract"
|
||||
version = "0.3.10"
|
||||
@@ -10898,12 +10948,12 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
|
||||
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
|
||||
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
qdrant = ["qdrant-client"]
|
||||
@@ -10912,4 +10962,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "5e83a1f4ca8c0d3107363e393485174fd72ce9db93db5dc7c21b2dd37b184e66"
|
||||
content-hash = "937d2f0165f6aa381ea1e26002272a92b189ab18607bd05895e36d23f56978f4"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain"
|
||||
version = "0.0.184"
|
||||
version = "0.0.185"
|
||||
description = "Building applications with LLMs through composability"
|
||||
authors = []
|
||||
license = "MIT"
|
||||
@@ -36,6 +36,7 @@ jinja2 = {version = "^3", optional = true}
|
||||
tiktoken = {version = "^0.3.2", optional = true, python="^3.9"}
|
||||
pinecone-client = {version = "^2", optional = true}
|
||||
pinecone-text = {version = "^0.4.2", optional = true}
|
||||
pymongo = {version = "^4.3.3", optional = true}
|
||||
clickhouse-connect = {version="^0.5.14", optional=true}
|
||||
weaviate-client = {version = "^3", optional = true}
|
||||
google-api-python-client = {version = "2.70.0", optional = true}
|
||||
@@ -97,8 +98,10 @@ scikit-learn = {version = "^1.2.2", optional = true}
|
||||
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
|
||||
azure-ai-vision = {version = "^0.11.1b1", optional = true}
|
||||
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
|
||||
py-trello = {version = "^0.19.0", optional = true}
|
||||
momento = {version = "^1.5.0", optional = true}
|
||||
bibtexparser = {version = "^1.4.0", optional = true}
|
||||
pyspark = {version = "^3.4.0", optional = true}
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
autodoc_pydantic = "^1.8.0"
|
||||
@@ -157,6 +160,7 @@ elasticsearch = {extras = ["async"], version = "^8.6.2"}
|
||||
redis = "^4.5.4"
|
||||
pinecone-client = "^2.2.1"
|
||||
pinecone-text = "^0.4.2"
|
||||
pymongo = "^4.3.3"
|
||||
clickhouse-connect = "^0.5.14"
|
||||
pgvector = "^0.1.6"
|
||||
transformers = "^4.27.4"
|
||||
@@ -172,7 +176,6 @@ gptcache = "^0.1.9"
|
||||
promptlayer = "^0.1.80"
|
||||
tair = "^1.3.3"
|
||||
wikipedia = "^1"
|
||||
pymongo = "^4.3.3"
|
||||
cassandra-driver = "^3.27.0"
|
||||
arxiv = "^1.4"
|
||||
mastodon-py = "^1.8.1"
|
||||
@@ -232,6 +235,7 @@ all = [
|
||||
"jinja2",
|
||||
"pinecone-client",
|
||||
"pinecone-text",
|
||||
"pymongo",
|
||||
"weaviate-client",
|
||||
"redis",
|
||||
"google-api-python-client",
|
||||
@@ -298,7 +302,9 @@ extended_testing = [
|
||||
"gql",
|
||||
"requests_toolbelt",
|
||||
"html2text",
|
||||
"py-trello",
|
||||
"scikit-learn",
|
||||
"pyspark",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -22,4 +22,8 @@ PINECONE_ENVIRONMENT=us-west4-gcp
|
||||
# details here https://learn.microsoft.com/en-us/dotnet/api/azure.identity.defaultazurecredential?view=azure-dotnet
|
||||
POWERBI_DATASET_ID=_powerbi_dataset_id_here
|
||||
POWERBI_TABLE_NAME=_test_table_name_here
|
||||
|
||||
POWERBI_NUMROWS=_num_rows_in_your_test_table
|
||||
|
||||
|
||||
# MongoDB Atlas Vector Search
|
||||
MONGODB_ATLAS_URI=your_mongodb_atlas_connection_string
|
||||
tests/integration_tests/document_loaders/test_github.py (new file, 12 additions)
@@ -0,0 +1,12 @@
|
||||
from langchain.document_loaders.github import GitHubIssuesLoader
|
||||
|
||||
|
||||
def test_issues_load() -> None:
|
||||
title = "DocumentLoader for GitHub"
|
||||
loader = GitHubIssuesLoader(
|
||||
repo="hwchase17/langchain", creator="UmerHA", state="all"
|
||||
)
|
||||
docs = loader.load()
|
||||
titles = [d.metadata["title"] for d in docs]
|
||||
assert title in titles
|
||||
assert all(doc.metadata["creator"] == "UmerHA" for doc in docs)
|
||||
@@ -0,0 +1,38 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
||||
|
||||
|
||||
def test_pyspark_loader_load_valid_data() -> None:
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
# Requires a session to be set up
|
||||
spark = SparkSession.builder.getOrCreate()
|
||||
data = [
|
||||
(random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
|
||||
]
|
||||
df = spark.createDataFrame(data, ["text", "label"])
|
||||
|
||||
expected_docs = [
|
||||
Document(
|
||||
page_content=data[0][0],
|
||||
metadata={"label": data[0][1]},
|
||||
),
|
||||
Document(
|
||||
page_content=data[1][0],
|
||||
metadata={"label": data[1][1]},
|
||||
),
|
||||
Document(
|
||||
page_content=data[2][0],
|
||||
metadata={"label": data[2][1]},
|
||||
),
|
||||
]
|
||||
|
||||
loader = PySparkDataFrameLoader(
|
||||
spark_session=spark, df=df, page_content_column="text"
|
||||
)
|
||||
result = loader.load()
|
||||
|
||||
assert result == expected_docs
|
||||
tests/integration_tests/vectorstores/test_mongodb_atlas.py (new file, 135 additions)
@@ -0,0 +1,135 @@
|
||||
"""Test MongoDB Atlas Vector Search functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo import MongoClient
|
||||
|
||||
INDEX_NAME = "langchain-test-index"
|
||||
NAMESPACE = "langchain_test_db.langchain_test_collection"
|
||||
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
|
||||
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
||||
|
||||
|
||||
def get_test_client() -> Optional[MongoClient]:
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
|
||||
client: MongoClient = MongoClient(CONNECTION_STRING)
|
||||
return client
|
||||
except: # noqa: E722
|
||||
return None
|
||||
|
||||
|
||||
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
|
||||
# connections.
|
||||
TEST_CLIENT = get_test_client()
|
||||
|
||||
|
||||
class TestMongoDBAtlasVectorSearch:
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
        # ensure the test collection is empty
|
||||
assert TEST_CLIENT[DB_NAME][COLLECTION_NAME].count_documents({}) == 0 # type: ignore[index] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls) -> None:
|
||||
# delete all the documents in the collection
|
||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self) -> None:
|
||||
# delete all the documents in the collection
|
||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
||||
|
||||
def test_from_documents(self, embedding_openai: Embeddings) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
documents = [
|
||||
Document(page_content="Dogs are tough.", metadata={"a": 1}),
|
||||
Document(page_content="Cats have fluff.", metadata={"b": 1}),
|
||||
Document(page_content="What is a sandwich?", metadata={"c": 1}),
|
||||
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
|
||||
]
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_documents(
|
||||
documents,
|
||||
embedding_openai,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
output = vectorstore.similarity_search("Sandwich", k=1)
|
||||
assert output[0].page_content == "What is a sandwich?"
|
||||
assert output[0].metadata["c"] == 1
|
||||
|
||||
def test_from_texts(self, embedding_openai: Embeddings) -> None:
|
||||
texts = [
|
||||
"Dogs are tough.",
|
||||
"Cats have fluff.",
|
||||
"What is a sandwich?",
|
||||
"That fence is purple.",
|
||||
]
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
output = vectorstore.similarity_search("Sandwich", k=1)
|
||||
assert output[0].page_content == "What is a sandwich?"
|
||||
|
||||
def test_from_texts_with_metadatas(self, embedding_openai: Embeddings) -> None:
|
||||
texts = [
|
||||
"Dogs are tough.",
|
||||
"Cats have fluff.",
|
||||
"What is a sandwich?",
|
||||
"The fence is purple.",
|
||||
]
|
||||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
metadatas=metadatas,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
output = vectorstore.similarity_search("Sandwich", k=1)
|
||||
assert output[0].page_content == "What is a sandwich?"
|
||||
assert output[0].metadata["c"] == 1
|
||||
|
||||
def test_from_texts_with_metadatas_and_pre_filter(
|
||||
self, embedding_openai: Embeddings
|
||||
) -> None:
|
||||
texts = [
|
||||
"Dogs are tough.",
|
||||
"Cats have fluff.",
|
||||
"What is a sandwich?",
|
||||
"The fence is purple.",
|
||||
]
|
||||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
metadatas=metadatas,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
output = vectorstore.similarity_search(
|
||||
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
|
||||
)
|
||||
assert output == []
|
||||
tests/unit_tests/document_loaders/test_github.py (new file, 114 additions)
@@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.github import GitHubIssuesLoader
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
loader = GitHubIssuesLoader(repo="repo", access_token="access_token")
|
||||
assert loader.repo == "repo"
|
||||
assert loader.access_token == "access_token"
|
||||
assert loader.headers == {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": "Bearer access_token",
|
||||
}
|
||||
|
||||
|
||||
def test_invalid_initialization() -> None:
|
||||
# Invalid parameter
|
||||
with pytest.raises(ValueError):
|
||||
GitHubIssuesLoader(invalid="parameter")
|
||||
|
||||
# Invalid value for valid parameter
|
||||
with pytest.raises(ValueError):
|
||||
GitHubIssuesLoader(state="invalid_state")
|
||||
|
||||
# Invalid type for labels
|
||||
with pytest.raises(ValueError):
|
||||
GitHubIssuesLoader(labels="not_a_list")
|
||||
|
||||
# Invalid date format for since
|
||||
with pytest.raises(ValueError):
|
||||
GitHubIssuesLoader(since="not_a_date")
|
||||
|
||||
|
||||
def test_load(mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"requests.get", return_value=mocker.MagicMock(json=lambda: [], links=None)
|
||||
)
|
||||
loader = GitHubIssuesLoader(repo="repo", access_token="access_token")
|
||||
documents = loader.load()
|
||||
assert documents == []
|
||||
|
||||
|
||||
def test_parse_issue() -> None:
|
||||
issue = {
|
||||
"html_url": "https://github.com/repo/issue/1",
|
||||
"title": "Example Issue 1",
|
||||
"user": {"login": "username1"},
|
||||
"created_at": "2023-01-01T00:00:00Z",
|
||||
"comments": 1,
|
||||
"state": "open",
|
||||
"labels": [{"name": "bug"}],
|
||||
"assignee": {"login": "username2"},
|
||||
"milestone": {"title": "v1.0"},
|
||||
"locked": "False",
|
||||
"number": "1",
|
||||
"body": "This is an example issue 1",
|
||||
}
|
||||
expected_document = Document(
|
||||
page_content=issue["body"], # type: ignore
|
||||
metadata={
|
||||
"url": issue["html_url"],
|
||||
"title": issue["title"],
|
||||
"creator": issue["user"]["login"], # type: ignore
|
||||
"created_at": issue["created_at"],
|
||||
"comments": issue["comments"],
|
||||
"state": issue["state"],
|
||||
"labels": [label["name"] for label in issue["labels"]], # type: ignore
|
||||
"assignee": issue["assignee"]["login"], # type: ignore
|
||||
"milestone": issue["milestone"]["title"], # type: ignore
|
||||
"locked": issue["locked"],
|
||||
"number": issue["number"],
|
||||
"is_pull_request": False,
|
||||
},
|
||||
)
|
||||
loader = GitHubIssuesLoader(repo="repo", access_token="access_token")
|
||||
document = loader.parse_issue(issue)
|
||||
assert document == expected_document
|
||||
|
||||
|
||||
def test_url() -> None:
|
||||
# No parameters
|
||||
loader = GitHubIssuesLoader(repo="repo", access_token="access_token")
|
||||
assert loader.url == "https://api.github.com/repos/repo/issues?"
|
||||
|
||||
# parameters: state, sort
|
||||
loader = GitHubIssuesLoader(
|
||||
repo="repo", access_token="access_token", state="open", sort="created"
|
||||
)
|
||||
assert (
|
||||
loader.url == "https://api.github.com/repos/repo/issues?state=open&sort=created"
|
||||
)
|
||||
|
||||
# parameters: milestone, state, assignee, creator, mentioned, labels, sort,
|
||||
# direction, since
|
||||
loader = GitHubIssuesLoader(
|
||||
repo="repo",
|
||||
access_token="access_token",
|
||||
milestone="*",
|
||||
state="closed",
|
||||
assignee="user1",
|
||||
creator="user2",
|
||||
mentioned="user3",
|
||||
labels=["bug", "ui", "@high"],
|
||||
sort="comments",
|
||||
direction="asc",
|
||||
since="2023-05-26T00:00:00Z",
|
||||
)
|
||||
assert loader.url == (
|
||||
"https://api.github.com/repos/repo/issues?milestone=*&state=closed"
|
||||
"&assignee=user1&creator=user2&mentioned=user3&labels=bug,ui,@high"
|
||||
"&sort=comments&direction=asc&since=2023-05-26T00:00:00Z"
|
||||
)
|
||||
tests/unit_tests/document_loaders/test_trello.py (new file, 341 additions)
@@ -0,0 +1,341 @@
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.document_loaders.trello import TrelloLoader
|
||||
|
||||
|
||||
def list_to_objects(dict_list: list) -> list:
|
||||
"""Helper to convert dict objects."""
|
||||
return [
|
||||
namedtuple("Object", d.keys())(**d) for d in dict_list if isinstance(d, dict)
|
||||
]
|
||||
|
||||
|
||||
def card_list_to_objects(cards: list) -> list:
|
||||
"""Helper to convert dict cards into trello weird mix of objects and dictionaries"""
|
||||
for card in cards:
|
||||
card["checklists"] = list_to_objects(card.get("checklists"))
|
||||
card["labels"] = list_to_objects(card.get("labels"))
|
||||
return list_to_objects(cards)
|
||||
|
||||
|
||||
class MockBoard:
|
||||
"""
|
||||
    Mock of the Trello board object, used by the patched client in the tests below.
|
||||
"""
|
||||
|
||||
def __init__(self, id: str, name: str, cards: list, lists: list):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.cards = cards
|
||||
self.lists = lists
|
||||
|
||||
def get_cards(self, card_filter: Optional[str] = "") -> list:
|
||||
"""We do not need to test the card-filter since is on Trello Client side."""
|
||||
return self.cards
|
||||
|
||||
def list_lists(self) -> list:
|
||||
return self.lists
|
||||
|
||||
|
||||
TRELLO_LISTS = [
|
||||
{
|
||||
"id": "5555cacbc4daa90564b34cf2",
|
||||
"name": "Publishing Considerations",
|
||||
},
|
||||
{
|
||||
"id": "5555059b74c03b3a9e362cd0",
|
||||
"name": "Backlog",
|
||||
},
|
||||
{
|
||||
"id": "555505a3427fd688c1ca5ebd",
|
||||
"name": "Selected for Milestone",
|
||||
},
|
||||
{
|
||||
"id": "555505ba95ff925f9fb1b370",
|
||||
"name": "Blocked",
|
||||
},
|
||||
{
|
||||
"id": "555505a695ff925f9fb1b13d",
|
||||
"name": "In Progress",
|
||||
},
|
||||
{
|
||||
"id": "555505bdfe380c7edc8ca1a3",
|
||||
"name": "Done",
|
||||
},
|
||||
]
|
||||
# Create a mock list of cards.
|
||||
TRELLO_CARDS_QA = [
|
||||
{
|
||||
"id": "12350aca6952888df7975903",
|
||||
"name": "Closed Card Title",
|
||||
"description": "This is the <em>description</em> of Closed Card.",
|
||||
"closed": True,
|
||||
"labels": [],
|
||||
"due_date": "",
|
||||
"url": "https://trello.com/card/12350aca6952888df7975903",
|
||||
"list_id": "555505bdfe380c7edc8ca1a3",
|
||||
"checklists": [
|
||||
{
|
||||
"name": "Checklist 1",
|
||||
"items": [
|
||||
{
|
||||
"name": "Item 1",
|
||||
"state": "pending",
|
||||
},
|
||||
{
|
||||
"name": "Item 2",
|
||||
"state": "completed",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"comments": [
|
||||
{
|
||||
"data": {
|
||||
"text": "This is a comment on a <s>Closed</s> Card.",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "45650aca6952888df7975903",
|
||||
"name": "Card 2",
|
||||
"description": "This is the description of <strong>Card 2</strong>.",
|
||||
"closed": False,
|
||||
"labels": [{"name": "Medium"}, {"name": "Task"}],
|
||||
"due_date": "",
|
||||
"url": "https://trello.com/card/45650aca6952888df7975903",
|
||||
"list_id": "555505a695ff925f9fb1b13d",
|
||||
"checklists": [],
|
||||
"comments": [],
|
||||
},
|
||||
{
|
||||
"id": "55550aca6952888df7975903",
|
||||
"name": "Camera",
|
||||
"description": "<div></div>",
|
||||
"closed": False,
|
||||
"labels": [{"name": "Task"}],
|
||||
"due_date": "",
|
||||
"url": "https://trello.com/card/55550aca6952888df7975903",
|
||||
"list_id": "555505a3427fd688c1ca5ebd",
|
||||
"checklists": [
|
||||
{
|
||||
"name": "Tasks",
|
||||
"items": [
|
||||
{"name": "Zoom", "state": "complete"},
|
||||
{"name": "Follow players", "state": "complete"},
|
||||
{
|
||||
"name": "camera limit to stage size",
|
||||
"state": "complete",
|
||||
},
|
||||
{"name": "Post Processing effects", "state": "complete"},
|
||||
{
|
||||
"name": "Shitch to universal render pipeline",
|
||||
"state": "complete",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"comments": [
|
||||
{
|
||||
"data": {
|
||||
"text": (
|
||||
"to follow group of players use Group Camera feature of "
|
||||
"cinemachine."
|
||||
)
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text": "Use 'Impulse' <s>Cinemachine</s> feature for camera shake."
|
||||
}
|
||||
},
|
||||
{"data": {"text": "depth of field with custom shader."}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trello_client() -> Any:
|
||||
"""Fixture that creates a mock for trello.TrelloClient."""
|
||||
# Create a mock `trello.TrelloClient` object.
|
||||
with patch("trello.TrelloClient") as mock_trello_client:
|
||||
        # Create a mock list of Trello lists (the columns in the UI).
|
||||
|
||||
# The trello client returns a hierarchy mix of objects and dictionaries.
|
||||
list_objs = list_to_objects(TRELLO_LISTS)
|
||||
cards_qa_objs = card_list_to_objects(TRELLO_CARDS_QA)
|
||||
boards = [
|
||||
MockBoard("5555eaafea917522902a2a2c", "Research", [], list_objs),
|
||||
MockBoard("55559f6002dd973ad8cdbfb7", "QA", cards_qa_objs, list_objs),
|
||||
]
|
||||
|
||||
        # Patch the `list_boards()` method of the mock `TrelloClient` object to
        # return the mock list of boards.
|
||||
mock_trello_client.return_value.list_boards.return_value = boards
|
||||
yield mock_trello_client.return_value
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_trello_client")
|
||||
@pytest.mark.requires("trello", "bs4", "lxml")
|
||||
class TestTrelloLoader(unittest.TestCase):
|
||||
def test_empty_board(self) -> None:
|
||||
"""
|
||||
Test loading a board with no cards.
|
||||
"""
|
||||
trello_loader = TrelloLoader.from_credentials(
|
||||
"Research",
|
||||
api_key="API_KEY",
|
||||
token="API_TOKEN",
|
||||
)
|
||||
documents = trello_loader.load()
|
||||
self.assertEqual(len(documents), 0, "Empty board returns an empty list.")
|
||||
|
||||
def test_complete_text_and_metadata(self) -> None:
|
||||
"""
|
||||
        Test loading a board's cards with all metadata.
|
||||
"""
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
trello_loader = TrelloLoader.from_credentials(
|
||||
"QA",
|
||||
api_key="API_KEY",
|
||||
token="API_TOKEN",
|
||||
)
|
||||
documents = trello_loader.load()
|
||||
self.assertEqual(len(documents), len(TRELLO_CARDS_QA), "Card count matches.")
|
||||
|
||||
soup = BeautifulSoup(documents[0].page_content, "html.parser")
|
||||
self.assertTrue(
|
||||
len(soup.find_all()) == 0,
|
||||
"There is not markup in Closed Card document content.",
|
||||
)
|
||||
|
||||
# Check samples of every field type is present in page content.
|
||||
texts = [
|
||||
"Closed Card Title",
|
||||
"This is the description of Closed Card.",
|
||||
"Checklist 1",
|
||||
"Item 1:pending",
|
||||
"This is a comment on a Closed Card.",
|
||||
]
|
||||
for text in texts:
|
||||
self.assertTrue(text in documents[0].page_content)
|
||||
|
||||
# Check all metadata is present in first Card
|
||||
self.assertEqual(
|
||||
documents[0].metadata,
|
||||
{
|
||||
"title": "Closed Card Title",
|
||||
"id": "12350aca6952888df7975903",
|
||||
"url": "https://trello.com/card/12350aca6952888df7975903",
|
||||
"labels": [],
|
||||
"list": "Done",
|
||||
"closed": True,
|
||||
"due_date": "",
|
||||
},
|
||||
"Metadata of Closed Card Matches.",
|
||||
)
|
||||
|
||||
soup = BeautifulSoup(documents[1].page_content, "html.parser")
|
||||
self.assertTrue(
|
||||
len(soup.find_all()) == 0,
|
||||
"There is not markup in Card 2 document content.",
|
||||
)
|
||||
|
||||
# Check samples of every field type is present in page content.
|
||||
texts = [
|
||||
"Card 2",
|
||||
"This is the description of Card 2.",
|
||||
]
|
||||
for text in texts:
|
||||
self.assertTrue(text in documents[1].page_content)
|
||||
|
||||
# Check all metadata is present in second Card
|
||||
self.assertEqual(
|
||||
documents[1].metadata,
|
||||
{
|
||||
"title": "Card 2",
|
||||
"id": "45650aca6952888df7975903",
|
||||
"url": "https://trello.com/card/45650aca6952888df7975903",
|
||||
"labels": ["Medium", "Task"],
|
||||
"list": "In Progress",
|
||||
"closed": False,
|
||||
"due_date": "",
|
||||
},
|
||||
"Metadata of Card 2 Matches.",
|
||||
)
|
||||
|
||||
soup = BeautifulSoup(documents[2].page_content, "html.parser")
|
||||
self.assertTrue(
|
||||
len(soup.find_all()) == 0,
|
||||
"There is not markup in Card 2 document content.",
|
||||
)
|
||||
|
||||
# Check samples of every field type is present in page content.
|
||||
texts = [
|
||||
"Camera",
|
||||
"camera limit to stage size:complete",
|
||||
"Use 'Impulse' Cinemachine feature for camera shake.",
|
||||
]
|
||||
|
||||
for text in texts:
|
||||
self.assertTrue(text in documents[2].page_content, text + " is present.")
|
||||
|
||||
# Check all metadata is present in second Card
|
||||
self.assertEqual(
|
||||
documents[2].metadata,
|
||||
{
|
||||
"title": "Camera",
|
||||
"id": "55550aca6952888df7975903",
|
||||
"url": "https://trello.com/card/55550aca6952888df7975903",
|
||||
"labels": ["Task"],
|
||||
"list": "Selected for Milestone",
|
||||
"closed": False,
|
||||
"due_date": "",
|
||||
},
|
||||
"Metadata of Camera Card matches.",
|
||||
)
|
||||
|
||||
def test_partial_text_and_metadata(self) -> None:
|
||||
"""
|
||||
Test loading a board cards removing some text and metadata.
|
||||
"""
|
||||
trello_loader = TrelloLoader.from_credentials(
|
||||
"QA",
|
||||
api_key="API_KEY",
|
||||
token="API_TOKEN",
|
||||
extra_metadata=("list"),
|
||||
include_card_name=False,
|
||||
include_checklist=False,
|
||||
include_comments=False,
|
||||
)
|
||||
documents = trello_loader.load()
|
||||
|
||||
# Check samples of every field type is present in page content.
|
||||
texts = [
|
||||
"Closed Card Title",
|
||||
"Checklist 1",
|
||||
"Item 1:pending",
|
||||
"This is a comment on a Closed Card.",
|
||||
]
|
||||
for text in texts:
|
||||
self.assertFalse(text in documents[0].page_content)
|
||||
|
||||
# Check all metadata is present in first Card
|
||||
self.assertEqual(
|
||||
documents[0].metadata,
|
||||
{
|
||||
"title": "Closed Card Title",
|
||||
"id": "12350aca6952888df7975903",
|
||||
"url": "https://trello.com/card/12350aca6952888df7975903",
|
||||
"list": "Done",
|
||||
},
|
||||
"Metadata of Closed Card Matches.",
|
||||
)
|
||||
@@ -4,9 +4,25 @@ import pytest
from langchain.docstore.document import Document
from langchain.text_splitter import (
    CharacterTextSplitter,
    CodeTextSplitter,
    Language,
    PythonCodeTextSplitter,
    RecursiveCharacterTextSplitter,
)

FAKE_PYTHON_TEXT = """
class Foo:

    def bar():


def foo():

def testing_func():

def bar():
"""


def test_character_text_splitter() -> None:
    """Test splitting by character count."""
@@ -135,15 +151,16 @@ Bye!\n\n-H."""
        "Okay then",
        "f f f f.",
        "This is a",
        "a weird",
        "weird",
        "text to",
        "write, but",
        "gotta test",
        "the",
        "splittingg",
        "ggg",
        "write,",
        "but gotta",
        "test the",
        "splitting",
        "gggg",
        "some how.",
        "Bye!\n\n-H.",
        "Bye!",
        "-H.",
    ]
    assert output == expected_output
@@ -168,3 +185,328 @@ def test_split_documents() -> None:
        Document(page_content="z", metadata={"source": "1"}),
    ]
    assert splitter.split_documents(docs) == expected_output


def test_python_text_splitter() -> None:
    splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
    splits = splitter.split_text(FAKE_PYTHON_TEXT)
    split_0 = """class Foo:\n\n    def bar():"""
    split_1 = """def foo():"""
    split_2 = """def testing_func():"""
    split_3 = """def bar():"""
    expected_splits = [split_0, split_1, split_2, split_3]
    assert splits == expected_splits


CHUNK_SIZE = 16


def test_python_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
def hello_world():
    print("Hello, World!")

# Call the function
hello_world()
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "def",
        "hello_world():",
        'print("Hello,',
        'World!")',
        "# Call the",
        "function",
        "hello_world()",
    ]


def test_golang_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
package main

import "fmt"

func helloWorld() {
    fmt.Println("Hello, World!")
}

func main() {
    helloWorld()
}
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "package main",
        'import "fmt"',
        "func",
        "helloWorld() {",
        'fmt.Println("He',
        "llo,",
        'World!")',
        "}",
        "func main() {",
        "helloWorld()",
        "}",
    ]


def test_rst_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
Sample Document
===============

Section
-------

This is the content of the section.

Lists
-----

- Item 1
- Item 2
- Item 3
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "Sample Document",
        "===============",
        "Section",
        "-------",
        "This is the",
        "content of the",
        "section.",
        "Lists\n-----",
        "- Item 1",
        "- Item 2",
        "- Item 3",
    ]


def test_proto_file_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
syntax = "proto3";

package example;

message Person {
    string name = 1;
    int32 age = 2;
    repeated string hobbies = 3;
}
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "syntax =",
        '"proto3";',
        "package",
        "example;",
        "message Person",
        "{",
        "string name",
        "= 1;",
        "int32 age =",
        "2;",
        "repeated",
        "string hobbies",
        "= 3;",
        "}",
    ]


def test_javascript_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
function helloWorld() {
  console.log("Hello, World!");
}

// Call the function
helloWorld();
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "function",
        "helloWorld() {",
        'console.log("He',
        "llo,",
        'World!");',
        "}",
        "// Call the",
        "function",
        "helloWorld();",
    ]


def test_java_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
public class HelloWorld {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
    }
}
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "public class",
        "HelloWorld {",
        "public",
        "static void",
        "main(String[]",
        "args) {",
        "System.out.prin",
        'tln("Hello,',
        'World!");',
        "}\n}",
    ]


def test_cpp_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
#include <iostream>

int main() {
    std::cout << "Hello, World!" << std::endl;
    return 0;
}
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "#include",
        "<iostream>",
        "int main() {",
        "std::cout",
        '<< "Hello,',
        'World!" <<',
        "std::endl;",
        "return 0;\n}",
    ]


def test_scala_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
object HelloWorld {
  def main(args: Array[String]): Unit = {
    println("Hello, World!")
  }
}
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "object",
        "HelloWorld {",
        "def",
        "main(args:",
        "Array[String]):",
        "Unit = {",
        'println("Hello,',
        'World!")',
        "}\n}",
    ]


def test_ruby_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
def hello_world
  puts "Hello, World!"
end

hello_world
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "def hello_world",
        'puts "Hello,',
        'World!"',
        "end",
        "hello_world",
    ]


def test_php_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
<?php
function hello_world() {
    echo "Hello, World!";
}

hello_world();
?>
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "<?php",
        "function",
        "hello_world() {",
        "echo",
        '"Hello,',
        'World!";',
        "}",
        "hello_world();",
        "?>",
    ]


def test_swift_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
func helloWorld() {
    print("Hello, World!")
}

helloWorld()
"""
    chunks = splitter.split_text(code)
    assert chunks == [
        "func",
        "helloWorld() {",
        'print("Hello,',
        'World!")',
        "}",
        "helloWorld()",
    ]


def test_rust_code_splitter() -> None:
    splitter = CodeTextSplitter(
        language=Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    code = """
fn main() {
    println!("Hello, World!");
}
"""
    chunks = splitter.split_text(code)
    assert chunks == ["fn main() {", 'println!("Hello', ",", 'World!");', "}"]
@@ -13,7 +13,12 @@ from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
from langchain.tools.base import (
    BaseTool,
    SchemaAnnotationError,
    StructuredTool,
    ToolException,
)


def test_unnamed_decorator() -> None:
@@ -479,3 +484,75 @@ async def test_create_async_tool() -> None:
    assert test_tool.description == "test_description"
    assert test_tool.coroutine is not None
    assert await test_tool.arun("foo") == "foo"


class _FakeExceptionTool(BaseTool):
    name = "exception"
    description = "an exception-throwing tool"
    exception: Exception = ToolException()

    def _run(self) -> str:
        raise self.exception

    async def _arun(self) -> str:
        raise self.exception


def test_exception_handling_bool() -> None:
    _tool = _FakeExceptionTool(handle_tool_error=True)
    expected = "Tool execution error"
    actual = _tool.run({})
    assert expected == actual


def test_exception_handling_str() -> None:
    expected = "foo bar"
    _tool = _FakeExceptionTool(handle_tool_error=expected)
    actual = _tool.run({})
    assert expected == actual


def test_exception_handling_callable() -> None:
    expected = "foo bar"
    handling = lambda _: expected  # noqa: E731
    _tool = _FakeExceptionTool(handle_tool_error=handling)
    actual = _tool.run({})
    assert expected == actual


def test_exception_handling_non_tool_exception() -> None:
    _tool = _FakeExceptionTool(exception=ValueError())
    with pytest.raises(ValueError):
        _tool.run({})


@pytest.mark.asyncio
async def test_async_exception_handling_bool() -> None:
    _tool = _FakeExceptionTool(handle_tool_error=True)
    expected = "Tool execution error"
    actual = await _tool.arun({})
    assert expected == actual


@pytest.mark.asyncio
async def test_async_exception_handling_str() -> None:
    expected = "foo bar"
    _tool = _FakeExceptionTool(handle_tool_error=expected)
    actual = await _tool.arun({})
    assert expected == actual


@pytest.mark.asyncio
async def test_async_exception_handling_callable() -> None:
    expected = "foo bar"
    handling = lambda _: expected  # noqa: E731
    _tool = _FakeExceptionTool(handle_tool_error=handling)
    actual = await _tool.arun({})
    assert expected == actual


@pytest.mark.asyncio
async def test_async_exception_handling_non_tool_exception() -> None:
    _tool = _FakeExceptionTool(exception=ValueError())
    with pytest.raises(ValueError):
        await _tool.arun({})