Wfh/ref links (#8454)

This commit is contained in:
William FH 2023-07-29 08:44:32 -07:00 committed by GitHub
parent 13b4f465e2
commit b7c0eb9ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 189 additions and 379 deletions

File diff suppressed because one or more lines are too long

View File

@ -5,16 +5,22 @@ import logging
import os
import re
from pathlib import Path
import argparse
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Base URL for all class documentation
_BASE_URL = "https://api.python.langchain.com/en/latest/"
_BASE_URL = "https://api.python.langchain.com/en/latest"
# Regular expression to match Python code blocks
code_block_re = re.compile(r"^(```python\n)(.*?)(```\n)", re.DOTALL | re.MULTILINE)
# Regular expression to match langchain import lines
_IMPORT_RE = re.compile(r"(from\s+(langchain\.\w+(\.\w+)*?)\s+import\s+)(\w+)")
_IMPORT_RE = re.compile(
r"from\s+(langchain\.\w+(\.\w+)*?)\s+import\s+"
r"((?:\w+(?:,\s*)?)*" # Match zero or more words separated by a comma+optional ws
r"(?:\s*\(.*?\))?)", # Match optional parentheses block
re.DOTALL, # Match newlines as well
)
_CURRENT_PATH = Path(__file__).parent.absolute()
# Directory where generated markdown files are stored
@ -24,6 +30,10 @@ _JSON_PATH = _CURRENT_PATH.parent / "api_reference" / "guide_imports.json"
def find_files(path):
"""Find all MDX files in the given path"""
# Check if is file first
if os.path.isfile(path):
yield path
return
for root, _, files in os.walk(path):
for file in files:
if file.endswith(".mdx") or file.endswith(".md"):
@ -37,20 +47,33 @@ def get_full_module_name(module_path, class_name):
return inspect.getmodule(class_).__name__
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--docs_dir",
type=str,
default=_DOCS_DIR,
help="Directory where generated markdown files are stored",
)
return parser.parse_args()
def main():
"""Main function"""
args = get_args()
global_imports = {}
for file in find_files(_DOCS_DIR):
for file in find_files(args.docs_dir):
print(f"Adding links for imports in {file}")
# replace_imports now returns the import information rather than writing it to a file
file_imports = replace_imports(file)
if file_imports:
# Use relative file path as key
relative_path = os.path.relpath(file, _DOCS_DIR)
doc_url = f"https://python.langchain.com/docs/{relative_path.replace('.mdx', '').replace('.md', '')}"
relative_path = (
os.path.relpath(file, _DOCS_DIR).replace(".mdx", "").replace(".md", "")
)
doc_url = f"https://python.langchain.com/docs/{relative_path}"
for import_info in file_imports:
doc_title = import_info["title"]
class_name = import_info["imported"]
@ -77,7 +100,8 @@ def _get_doc_title(data: str, file_name: str) -> str:
def replace_imports(file):
"""Replace imports in each Python code block with links to their documentation and append the import info in a comment"""
"""Replace imports in each Python code block with links to their
documentation and append the import info in a comment"""
all_imports = []
with open(file, "r") as f:
data = f.read()
@ -97,37 +121,45 @@ def replace_imports(file):
# Process imports in the code block
imports = []
for import_match in _IMPORT_RE.finditer(code):
class_name = import_match.group(4)
try:
module_path = get_full_module_name(import_match.group(2), class_name)
except AttributeError as e:
logger.warning(f"Could not find module for {class_name}, {e}")
continue
except ImportError as e:
# Some CentOS OpenSSL issues can cause this to fail
logger.warning(f"Failed to load for class {class_name}, {e}")
continue
module = import_match.group(1)
imports_str = (
import_match.group(3).replace("(\n", "").replace("\n)", "")
) # Handle newlines within parentheses
# remove any newline and spaces, then split by comma
imported_classes = [
imp.strip()
for imp in re.split(r",\s*", imports_str.replace("\n", ""))
if imp.strip()
]
for class_name in imported_classes:
try:
module_path = get_full_module_name(module, class_name)
except AttributeError as e:
logger.warning(f"Could not find module for {class_name}, {e}")
continue
except ImportError as e:
logger.warning(f"Failed to load for class {class_name}, {e}")
continue
url = (
_BASE_URL
+ "/"
+ module_path.split(".")[1]
+ "/"
+ module_path
+ "."
+ class_name
+ ".html"
)
url = (
_BASE_URL
+ module_path.split(".")[1]
+ "/"
+ module_path
+ "."
+ class_name
+ ".html"
)
# Add the import information to our list
imports.append(
{
"imported": class_name,
"source": import_match.group(2),
"docs": url,
"title": _DOC_TITLE,
}
)
# Add the import information to our list
imports.append(
{
"imported": class_name,
"source": module,
"docs": url,
"title": _DOC_TITLE,
}
)
if imports:
all_imports.extend(imports)

View File

@ -216,7 +216,7 @@
},
"outputs": [],
"source": [
"from langchain.experimental.llms import JsonFormer\n",
"from langchain_experimental.llms import JsonFormer\n",
"\n",
"json_former = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)"
]

View File

@ -162,7 +162,7 @@
}
],
"source": [
"from langchain.experimental.llms import RELLM\n",
"from langchain_experimental.llms import RELLM\n",
"\n",
"model = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)\n",
"\n",

View File

@ -13,7 +13,7 @@ This page provides instructions on how to use the DataForSEO search APIs within
The DataForSEO utility wraps the API. To import this utility, use:
```python
from langchain.utilities import DataForSeoAPIWrapper
from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper
```
For a detailed walkthrough of this wrapper, see [this notebook](/docs/integrations/tools/dataforseo.ipynb).

View File

@ -177,8 +177,9 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import TransformChain, SQLDatabaseChain, SimpleSequentialChain\n",
"from langchain.sql_database import SQLDatabase"
"from langchain.chains import TransformChain, SimpleSequentialChain\n",
"from langchain.sql_database import SQLDatabase\n",
"from langchain_experimental.sql import SQLDatabaseChain"
]
},
{

View File

@ -15,7 +15,7 @@ pip install rockset
See a [usage example](/docs/integrations/vectorstores/rockset).
```python
from langchain.vectorstores import RocksetDB
from langchain.vectorstores import Rockset
```
## Document Loader

View File

@ -16,5 +16,5 @@ pip install spacy
See a [usage example](/docs/modules/data_connection/document_transformers/text_splitters/split_by_token.html#spacy).
```python
from langchain.llms import SpacyTextSplitter
from langchain.text_splitter import SpacyTextSplitter
```

View File

@ -77,7 +77,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import LocalAIEmbeddings"
"from langchain.embeddings import LocalAIEmbeddings"
]
},
{

View File

@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import DataForSeoAPIWrapper"
"from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper"
]
},
{

View File

@ -124,7 +124,7 @@
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.vectorstores.rocksetdb import RocksetDB\n",
"from langchain.vectorstores import Rockset\n",
"\n",
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"documents = loader.load()\n",
@ -150,7 +150,7 @@
"# Make sure the environment variable OPENAI_API_KEY is set up\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"docsearch = RocksetDB(\n",
"docsearch = Rockset(\n",
" client=rockset_client,\n",
" embeddings=embeddings,\n",
" collection_name=COLLECTION_NAME,\n",
@ -185,7 +185,7 @@
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"output = docsearch.similarity_search_with_relevance_scores(\n",
" query, 4, RocksetDB.DistanceFunction.COSINE_SIM\n",
" query, 4, Rockset.DistanceFunction.COSINE_SIM\n",
")\n",
"print(\"output length:\", len(output))\n",
"for d, dist in output:\n",
@ -221,7 +221,7 @@
"output = docsearch.similarity_search_with_relevance_scores(\n",
" query,\n",
" 4,\n",
" RocksetDB.DistanceFunction.COSINE_SIM,\n",
" Rockset.DistanceFunction.COSINE_SIM,\n",
" where_str=\"{} NOT LIKE '%citizens%'\".format(TEXT_KEY),\n",
")\n",
"print(\"output length:\", len(output))\n",
@ -237,15 +237,16 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0765b822",
"metadata": {},
"source": [
"### 3. [Optional] Drop all inserted documents\n",
"\n",
"In order to delete texts from the Rockset collection, you need to know the unique ID associated with each document inside Rockset. These ids can either be supplied directly by the user while inserting the texts (in the `RocksetDB.add_texts()` function), else Rockset will generate a unique ID or each document. Either way, `Rockset.add_texts()` returns the ids for the inserted documents.\n",
"In order to delete texts from the Rockset collection, you need to know the unique ID associated with each document inside Rockset. These ids can either be supplied directly by the user while inserting the texts (in the `Rockset.add_texts()` function), else Rockset will generate a unique ID or each document. Either way, `Rockset.add_texts()` returns the ids for the inserted documents.\n",
"\n",
"To delete these docs, simply use the `RocksetDB.delete_texts()` function."
"To delete these docs, simply use the `Rockset.delete_texts()` function."
]
},
{

View File

@ -15,10 +15,7 @@
"execution_count": 11,
"id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
"metadata": {
"tags": [],
"vscode": {
"languageId": "python"
}
"tags": []
},
"outputs": [],
"source": [
@ -28,7 +25,7 @@
"\n",
"from pydantic import Extra\n",
"\n",
"from langchain.schema import BaseLanguageModel\n",
"from langchain.schema.language_model import BaseLanguageModel\n",
"from langchain.callbacks.manager import (\n",
" AsyncCallbackManagerForChainRun,\n",
" CallbackManagerForChainRun,\n",
@ -130,11 +127,7 @@
"cell_type": "code",
"execution_count": 12,
"id": "18361f89",
"metadata": {
"vscode": {
"languageId": "python"
}
},
"metadata": {},
"outputs": [
{
"name": "stdout",

View File

@ -35,19 +35,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"id": "91f1ca7f-a748-44c7-a1c6-a89a2d1414ba",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.schema import SystemMessage\n",
"from langchain.prompts import (\n",
" FewShotChatMessagePromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
" AIMessagePromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" ChatPromptTemplate,\n",
")"
]
},
@ -61,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 3,
"id": "0fc5a02a-6249-4e92-95c3-30fff9671e8b",
"metadata": {
"tags": []
@ -84,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"id": "65e72ad1-9060-47d0-91a1-bc130c8b98ac",
"metadata": {
"tags": []
@ -103,9 +100,12 @@
],
"source": [
"# This is a prompt template used to format each individual example.\n",
"example_prompt = HumanMessagePromptTemplate.from_template(\n",
" \"{input}\"\n",
") + AIMessagePromptTemplate.from_template(\"{output}\")\n",
"example_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"human\", \"{input}\"),\n",
" (\"ai\", \"{output}\"),\n",
" ]\n",
")\n",
"few_shot_prompt = FewShotChatMessagePromptTemplate(\n",
" example_prompt=example_prompt,\n",
" examples=examples,\n",
@ -124,23 +124,25 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"id": "9f86d6d9-50de-41b6-b6c7-0f9980cc0187",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"final_prompt = (\n",
" SystemMessagePromptTemplate.from_template(\"You are wonderous wizard of math.\")\n",
" + few_shot_prompt\n",
" + HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"final_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are wonderous wizard of math.\"),\n",
" few_shot_prompt,\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 6,
"id": "97d443b1-6fae-4b36-bede-3ff7306288a3",
"metadata": {
"tags": []
@ -149,10 +151,10 @@
{
"data": {
"text/plain": [
"AIMessage(content=' Triangles do not have a \"square\". A square refers to a shape with 4 equal sides and 4 right angles. Triangles have 3 sides and 3 angles.\\n\\nThe area of a triangle can be calculated using the formula:\\n\\nA = 1/2 * b * h\\n\\nWhere:\\n\\nA is the area \\nb is the base (the length of one of the sides)\\nh is the height (the length from the base to the opposite vertex)\\n\\nSo the area depends on the specific dimensions of the triangle. There is no single \"square of a triangle\". The area can vary greatly between different triangles.', additional_kwargs={}, example=False)"
"AIMessage(content=' Triangles do not have a \"square\". A square refers to a shape with 4 equal sides and 4 right angles. Triangles have 3 sides and 3 angles.\\n\\nThe area of a triangle can be calculated using the formula:\\n\\nA = 1/2 * b * h\\n\\nWhere:\\n\\nA is the area \\nb is the base (the length of one of the sides)\\nh is the height (the length from the base to the opposite vertex)\\n\\nSo the area depends on the specific dimensions of the triangle. There is no single \"square of a triangle\". The area can vary greatly depending on the base and height measurements.', additional_kwargs={}, example=False)"
]
},
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -182,7 +184,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"id": "6f7b5e86-4ca7-4edd-bf2b-9663030b2393",
"metadata": {
"tags": []
@ -204,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"id": "ad66f06a-66fd-4fcc-8166-5d0e3c801e57",
"metadata": {
"tags": []
@ -239,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 9,
"id": "7790303a-f722-452e-8921-b14bdf20bdff",
"metadata": {
"tags": []
@ -252,7 +254,7 @@
" {'input': '2+4', 'output': '6'}]"
]
},
"execution_count": 17,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -279,17 +281,17 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 10,
"id": "253c255e-41d7-45f6-9d88-c7a0ced4b1bd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.schema import SystemMessage\n",
"from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate\n",
"from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate\n",
"\n",
"from langchain.prompts import (\n",
" FewShotChatMessagePromptTemplate,\n",
" ChatPromptTemplate,\n",
")\n",
"\n",
"# Define the few-shot prompt.\n",
"few_shot_prompt = FewShotChatMessagePromptTemplate(\n",
@ -299,9 +301,8 @@
" # Define how each example will be formatted.\n",
" # In this case, each example will become 2 messages:\n",
" # 1 human, and 1 AI\n",
" example_prompt=(\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
" + AIMessagePromptTemplate.from_template(\"{output}\")\n",
" example_prompt=ChatPromptTemplate.from_messages(\n",
" [(\"human\", \"{input}\"), (\"ai\", \"{output}\")]\n",
" ),\n",
")"
]
@ -316,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 11,
"id": "860bf682-c469-40e9-b657-27bfe7026099",
"metadata": {
"tags": []
@ -347,23 +348,25 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 12,
"id": "e731cb45-f0ea-422c-be37-42af2a6cb2c4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"final_prompt = (\n",
" SystemMessagePromptTemplate.from_template(\"You are wonderous wizard of math.\")\n",
" + few_shot_prompt\n",
" + HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"final_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are wonderous wizard of math.\"),\n",
" few_shot_prompt,\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 13,
"id": "e6cc4199-8947-42d7-91f0-375de1e15bd9",
"metadata": {
"tags": []
@ -396,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 14,
"id": "0568cbc6-5354-47f1-ab4d-dfcc616cf583",
"metadata": {
"tags": []
@ -408,7 +411,7 @@
"AIMessage(content=' 3 + 3 = 6', additional_kwargs={}, example=False)"
]
},
"execution_count": 26,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}

View File

@ -101,7 +101,7 @@
},
"outputs": [],
"source": [
"from langchain.experimental.generative_agents import (\n",
"from langchain_experimental.generative_agents import (\n",
" GenerativeAgent,\n",
" GenerativeAgentMemory,\n",
")"

View File

@ -100,7 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.experimental import AutoGPT\n",
"from langchain_experimental.autonomous_agents import AutoGPT\n",
"from langchain.chat_models import ChatOpenAI"
]
},
@ -124,43 +124,47 @@
},
{
"cell_type": "markdown",
"id": "f0f208d9",
"metadata": {
"collapsed": false
},
"source": [
"## Run an example\n",
"\n",
"Here we will make it write a weather report for SF"
],
"metadata": {
"collapsed": false
},
"id": "f0f208d9"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"agent.run([\"write a weather report for SF today\"])"
],
"id": "d119d788",
"metadata": {
"collapsed": false
},
"id": "d119d788"
"outputs": [],
"source": [
"agent.run([\"write a weather report for SF today\"])"
]
},
{
"cell_type": "markdown",
"id": "f13f8322",
"metadata": {
"collapsed": false
},
"source": [
"## Chat History Memory\n",
"\n",
"In addition to the memory that holds the agent immediate steps, we also have a chat history memory. By default, the agent will use 'ChatMessageHistory' and it can be changed. This is useful when you want to use a different type of memory for example 'FileChatHistoryMemory'"
],
"metadata": {
"collapsed": false
},
"id": "f13f8322"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a81f5ad",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain.memory.chat_message_histories import FileChatMessageHistory\n",
@ -173,19 +177,15 @@
" memory=vectorstore.as_retriever(),\n",
" chat_history_memory=FileChatMessageHistory(\"chat_history.txt\"),\n",
")"
],
"metadata": {
"collapsed": false
},
"id": "2a81f5ad"
]
},
{
"cell_type": "markdown",
"source": [],
"id": "b1403008",
"metadata": {
"collapsed": false
},
"id": "b1403008"
"source": []
}
],
"metadata": {
@ -209,4 +209,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@ -39,7 +39,7 @@
"from langchain.vectorstores.base import VectorStore\n",
"from pydantic import BaseModel, Field\n",
"from langchain.chains.base import Chain\n",
"from langchain.experimental import BabyAGI"
"from langchain_experimental.autonomous_agents import BabyAGI"
]
},
{

View File

@ -35,7 +35,7 @@
"from langchain.vectorstores.base import VectorStore\n",
"from pydantic import BaseModel, Field\n",
"from langchain.chains.base import Chain\n",
"from langchain.experimental import BabyAGI"
"from langchain_experimental.autonomous_agents import BabyAGI"
]
},
{

View File

@ -36,7 +36,7 @@
"# General\n",
"import os\n",
"import pandas as pd\n",
"from langchain.experimental.autonomous_agents.autogpt.agent import AutoGPT\n",
"from langchain_experimental.autonomous_agents import AutoGPT\n",
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent\n",

File diff suppressed because one or more lines are too long

View File

@ -3,7 +3,7 @@
```python
from langchain.chat_models import ChatOpenAI
from langchain.experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from langchain.llms import OpenAI
from langchain import SerpAPIWrapper
from langchain.agents.tools import Tool

View File

@ -11,6 +11,7 @@ from langchain.agents.agent_toolkits.gmail.toolkit import GmailToolkit
from langchain.agents.agent_toolkits.jira.toolkit import JiraToolkit
from langchain.agents.agent_toolkits.json.base import create_json_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.multion.base import create_multion_agent
from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit
from langchain.agents.agent_toolkits.office365.toolkit import O365Toolkit
from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent
@ -63,6 +64,7 @@ __all__ = [
"create_pbi_agent",
"create_pbi_chat_agent",
"create_python_agent",
"create_multion_agent",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_sql_agent",

View File

@ -457,6 +457,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
messages: Sequence[
Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
@ -515,7 +516,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
# Automatically infer input variables from messages
input_vars = set()
for _message in _messages:
if isinstance(_message, BaseMessagePromptTemplate):
if isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
@ -643,12 +646,13 @@ def _create_template_from_message_type(
def _convert_to_message(
message: Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
str,
]
) -> Union[BaseMessage, BaseMessagePromptTemplate]:
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
@ -665,8 +669,10 @@ def _convert_to_message(
Returns:
an instance of a message or a message template
"""
if isinstance(message, BaseMessagePromptTemplate):
_message: Union[BaseMessage, BaseMessagePromptTemplate] = message
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
_message: Union[
BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
] = message
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):

View File

@ -7,17 +7,15 @@ from unittest import mock
import pydantic
import pytest
from langchain import OpenAI
from langchain.experimental.cpal.base import (
from langchain_experimental.cpal.base import (
CausalChain,
CPALChain,
InterventionChain,
NarrativeChain,
QueryChain,
)
from langchain.experimental.cpal.constants import Constant
from langchain.experimental.cpal.models import (
from langchain_experimental.cpal.constants import Constant
from langchain_experimental.cpal.models import (
CausalModel,
EntityModel,
EntitySettingModel,
@ -25,18 +23,20 @@ from langchain.experimental.cpal.models import (
NarrativeModel,
QueryModel,
)
from langchain.experimental.cpal.templates.univariate.causal import (
from langchain_experimental.cpal.templates.univariate.causal import (
template as causal_template,
)
from langchain.experimental.cpal.templates.univariate.intervention import (
from langchain_experimental.cpal.templates.univariate.intervention import (
template as intervention_template,
)
from langchain.experimental.cpal.templates.univariate.narrative import (
from langchain_experimental.cpal.templates.univariate.narrative import (
template as narrative_template,
)
from langchain.experimental.cpal.templates.univariate.query import (
from langchain_experimental.cpal.templates.univariate.query import (
template as query_template,
)
from langchain import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@ -274,7 +274,7 @@ class TestUnitCPALChain_MathWordProblems(unittest.TestCase):
patch required since `networkx` package is not part of unit test environment
"""
with mock.patch(
"langchain.experimental.cpal.models.NetworkxEntityGraph"
"langchain_experimental.cpal.models.NetworkxEntityGraph"
) as mock_networkx:
graph_instance = mock_networkx.return_value
graph_instance.get_topological_sort.return_value = [