BUGFIX: add prompt imports for backwards compat (#13702)

This commit is contained in:
Bagatur 2023-11-21 23:04:20 -08:00 committed by GitHub
parent 78da34153e
commit 16af282429
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 186 additions and 15 deletions

View File

@ -41,7 +41,7 @@ from langchain_core.prompts.few_shot import (
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.loading import load_prompt from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import ( from langchain_core.prompts.string import (
StringPromptTemplate, StringPromptTemplate,
check_valid_template, check_valid_template,
@ -62,7 +62,6 @@ __all__ = [
"HumanMessagePromptTemplate", "HumanMessagePromptTemplate",
"MessagesPlaceholder", "MessagesPlaceholder",
"PipelinePromptTemplate", "PipelinePromptTemplate",
"Prompt",
"PromptTemplate", "PromptTemplate",
"StringPromptTemplate", "StringPromptTemplate",
"SystemMessagePromptTemplate", "SystemMessagePromptTemplate",

View File

@ -244,7 +244,3 @@ class PromptTemplate(StringPromptTemplate):
partial_variables=_partial_variables, partial_variables=_partial_variables,
**kwargs, **kwargs,
) )
# For backwards compatibility.
Prompt = PromptTemplate

View File

@ -13,7 +13,6 @@ EXPECTED_ALL = [
"HumanMessagePromptTemplate", "HumanMessagePromptTemplate",
"MessagesPlaceholder", "MessagesPlaceholder",
"PipelinePromptTemplate", "PipelinePromptTemplate",
"Prompt",
"PromptTemplate", "PromptTemplate",
"StringPromptTemplate", "StringPromptTemplate",
"SystemMessagePromptTemplate", "SystemMessagePromptTemplate",

View File

@ -44,7 +44,6 @@ from langchain_core.prompts import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
PipelinePromptTemplate, PipelinePromptTemplate,
Prompt,
PromptTemplate, PromptTemplate,
StringPromptTemplate, StringPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
@ -52,6 +51,7 @@ from langchain_core.prompts import (
) )
from langchain.prompts.example_selector import NGramOverlapExampleSelector from langchain.prompts.example_selector import NGramOverlapExampleSelector
from langchain.prompts.prompt import Prompt
__all__ = [ __all__ = [
"AIMessagePromptTemplate", "AIMessagePromptTemplate",
@ -67,11 +67,11 @@ __all__ = [
"MessagesPlaceholder", "MessagesPlaceholder",
"NGramOverlapExampleSelector", "NGramOverlapExampleSelector",
"PipelinePromptTemplate", "PipelinePromptTemplate",
"Prompt",
"PromptTemplate", "PromptTemplate",
"SemanticSimilarityExampleSelector", "SemanticSimilarityExampleSelector",
"StringPromptTemplate", "StringPromptTemplate",
"SystemMessagePromptTemplate", "SystemMessagePromptTemplate",
"load_prompt", "load_prompt",
"FewShotChatMessagePromptTemplate", "FewShotChatMessagePromptTemplate",
"Prompt",
] ]

View File

@ -1,3 +1,4 @@
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts import ( from langchain_core.prompts import (
BasePromptTemplate, BasePromptTemplate,
StringPromptTemplate, StringPromptTemplate,
@ -6,6 +7,7 @@ from langchain_core.prompts import (
jinja2_formatter, jinja2_formatter,
validate_jinja2, validate_jinja2,
) )
from langchain_core.prompts.string import _get_jinja2_variables_from_template
__all__ = [ __all__ = [
"jinja2_formatter", "jinja2_formatter",
@ -14,4 +16,6 @@ __all__ = [
"get_template_variables", "get_template_variables",
"StringPromptTemplate", "StringPromptTemplate",
"BasePromptTemplate", "BasePromptTemplate",
"StringPromptValue",
"_get_jinja2_variables_from_template",
] ]

View File

@ -1,3 +1,4 @@
from langchain_core.prompt_values import ChatPromptValue, ChatPromptValueConcrete
from langchain_core.prompts.chat import ( from langchain_core.prompts.chat import (
AIMessagePromptTemplate, AIMessagePromptTemplate,
BaseChatPromptTemplate, BaseChatPromptTemplate,
@ -8,6 +9,8 @@ from langchain_core.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
_convert_to_message,
_create_template_from_message_type,
) )
__all__ = [ __all__ = [
@ -20,4 +23,8 @@ __all__ = [
"SystemMessagePromptTemplate", "SystemMessagePromptTemplate",
"BaseChatPromptTemplate", "BaseChatPromptTemplate",
"ChatPromptTemplate", "ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"_convert_to_message",
"_create_template_from_message_type",
] ]

View File

@ -1,6 +1,11 @@
from langchain_core.prompts.few_shot import ( from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate, FewShotChatMessagePromptTemplate,
FewShotPromptTemplate, FewShotPromptTemplate,
_FewShotPromptTemplateMixin,
) )
__all__ = ["FewShotPromptTemplate", "FewShotChatMessagePromptTemplate"] __all__ = [
"FewShotPromptTemplate",
"FewShotChatMessagePromptTemplate",
"_FewShotPromptTemplateMixin",
]

View File

@ -1,4 +1,23 @@
from langchain_core.prompts.loading import load_prompt, load_prompt_from_config from langchain_core.prompts.loading import (
_load_examples,
_load_few_shot_prompt,
_load_output_parser,
_load_prompt,
_load_prompt_from_file,
_load_template,
load_prompt,
load_prompt_from_config,
)
from langchain_core.utils.loading import try_load_from_hub from langchain_core.utils.loading import try_load_from_hub
__all__ = ["load_prompt_from_config", "load_prompt", "try_load_from_hub"] __all__ = [
"load_prompt_from_config",
"load_prompt",
"try_load_from_hub",
"_load_examples",
"_load_few_shot_prompt",
"_load_output_parser",
"_load_prompt",
"_load_prompt_from_file",
"_load_template",
]

View File

@ -1,3 +1,3 @@
from langchain_core.prompts.pipeline import PipelinePromptTemplate from langchain_core.prompts.pipeline import PipelinePromptTemplate, _get_inputs
__all__ = ["PipelinePromptTemplate"] __all__ = ["PipelinePromptTemplate", "_get_inputs"]

View File

@ -1,3 +1,6 @@
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
__all__ = ["PromptTemplate"] # For backwards compatibility.
Prompt = PromptTemplate
__all__ = ["PromptTemplate", "Prompt"]

View File

@ -4,6 +4,9 @@ from langchain_core.tools import (
StructuredTool, StructuredTool,
Tool, Tool,
ToolException, ToolException,
_create_subset_model,
_get_filtered_args,
_SchemaConfig,
create_schema_from_function, create_schema_from_function,
tool, tool,
) )
@ -16,4 +19,7 @@ __all__ = [
"Tool", "Tool",
"StructuredTool", "StructuredTool",
"tool", "tool",
"_SchemaConfig",
"_create_subset_model",
"_get_filtered_args",
] ]

View File

@ -0,0 +1 @@
"""Test prompt functionality."""

View File

@ -0,0 +1,16 @@
from langchain.prompts.base import __all__
EXPECTED_ALL = [
"BasePromptTemplate",
"StringPromptTemplate",
"StringPromptValue",
"_get_jinja2_variables_from_template",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",
"validate_jinja2",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,21 @@
from langchain.prompts.chat import __all__
EXPECTED_ALL = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BaseMessagePromptTemplate",
"BaseStringMessagePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"SystemMessagePromptTemplate",
"_convert_to_message",
"_create_template_from_message_type",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,11 @@
from langchain.prompts.few_shot import __all__
EXPECTED_ALL = [
"FewShotChatMessagePromptTemplate",
"FewShotPromptTemplate",
"_FewShotPromptTemplateMixin",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.prompts.few_shot_with_templates import __all__
EXPECTED_ALL = ["FewShotPromptWithTemplates"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,28 @@
from langchain.prompts import __all__
EXPECTED_ALL = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"NGramOverlapExampleSelector",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,17 @@
from langchain.prompts.loading import __all__
EXPECTED_ALL = [
"_load_examples",
"_load_few_shot_prompt",
"_load_output_parser",
"_load_prompt",
"_load_prompt_from_file",
"_load_template",
"load_prompt",
"load_prompt_from_config",
"try_load_from_hub",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.prompts.pipeline import __all__
EXPECTED_ALL = ["PipelinePromptTemplate", "_get_inputs"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.prompts.prompt import __all__
EXPECTED_ALL = ["Prompt", "PromptTemplate"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,18 @@
from langchain.tools.base import __all__
EXPECTED_ALL = [
"BaseTool",
"SchemaAnnotationError",
"StructuredTool",
"Tool",
"ToolException",
"_SchemaConfig",
"_create_subset_model",
"_get_filtered_args",
"create_schema_from_function",
"tool",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)