Wfh/ref links (#8454)

This commit is contained in:
William FH
2023-07-29 08:44:32 -07:00
committed by GitHub
parent 13b4f465e2
commit b7c0eb9ecb
23 changed files with 189 additions and 379 deletions

View File

@@ -11,6 +11,7 @@ from langchain.agents.agent_toolkits.gmail.toolkit import GmailToolkit
from langchain.agents.agent_toolkits.jira.toolkit import JiraToolkit
from langchain.agents.agent_toolkits.json.base import create_json_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.multion.base import create_multion_agent
from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit
from langchain.agents.agent_toolkits.office365.toolkit import O365Toolkit
from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent
@@ -63,6 +64,7 @@ __all__ = [
"create_pbi_agent",
"create_pbi_chat_agent",
"create_python_agent",
"create_multion_agent",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_sql_agent",

View File

@@ -457,6 +457,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
messages: Sequence[
Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
@@ -515,7 +516,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
# Automatically infer input variables from messages
input_vars = set()
for _message in _messages:
if isinstance(_message, BaseMessagePromptTemplate):
if isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
@@ -643,12 +646,13 @@ def _create_template_from_message_type(
def _convert_to_message(
message: Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
str,
]
) -> Union[BaseMessage, BaseMessagePromptTemplate]:
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
@@ -665,8 +669,10 @@ def _convert_to_message(
Returns:
an instance of a message or a message template
"""
if isinstance(message, BaseMessagePromptTemplate):
_message: Union[BaseMessage, BaseMessagePromptTemplate] = message
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
_message: Union[
BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
] = message
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):

View File

@@ -7,17 +7,15 @@ from unittest import mock
import pydantic
import pytest
from langchain import OpenAI
from langchain.experimental.cpal.base import (
from langchain_experimental.cpal.base import (
CausalChain,
CPALChain,
InterventionChain,
NarrativeChain,
QueryChain,
)
from langchain.experimental.cpal.constants import Constant
from langchain.experimental.cpal.models import (
from langchain_experimental.cpal.constants import Constant
from langchain_experimental.cpal.models import (
CausalModel,
EntityModel,
EntitySettingModel,
@@ -25,18 +23,20 @@ from langchain.experimental.cpal.models import (
NarrativeModel,
QueryModel,
)
from langchain.experimental.cpal.templates.univariate.causal import (
from langchain_experimental.cpal.templates.univariate.causal import (
template as causal_template,
)
from langchain.experimental.cpal.templates.univariate.intervention import (
from langchain_experimental.cpal.templates.univariate.intervention import (
template as intervention_template,
)
from langchain.experimental.cpal.templates.univariate.narrative import (
from langchain_experimental.cpal.templates.univariate.narrative import (
template as narrative_template,
)
from langchain.experimental.cpal.templates.univariate.query import (
from langchain_experimental.cpal.templates.univariate.query import (
template as query_template,
)
from langchain import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -274,7 +274,7 @@ class TestUnitCPALChain_MathWordProblems(unittest.TestCase):
patch required since `networkx` package is not part of unit test environment
"""
with mock.patch(
"langchain.experimental.cpal.models.NetworkxEntityGraph"
"langchain_experimental.cpal.models.NetworkxEntityGraph"
) as mock_networkx:
graph_instance = mock_networkx.return_value
graph_instance.get_topological_sort.return_value = [