diff --git a/langchain/__init__.py b/langchain/__init__.py index 5be850781f9..b0cfdb48047 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -38,11 +38,11 @@ from langchain.llms import ( ) from langchain.llms.huggingface_pipeline import HuggingFacePipeline from langchain.prompts import ( - BasePromptTemplate, FewShotPromptTemplate, Prompt, PromptTemplate, ) +from langchain.schema.prompt_template import BasePromptTemplate from langchain.sql_database import SQLDatabase from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 9c6c6b0c22d..63837514762 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -26,13 +26,13 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( AgentAction, AgentFinish, BaseOutputParser, + BasePromptTemplate, OutputParserException, ) from langchain.schema.messages import BaseMessage diff --git a/langchain/agents/agent_toolkits/openapi/planner.py b/langchain/agents/agent_toolkits/openapi/planner.py index bcb79cca759..87cd01ab8d8 100644 --- a/langchain/agents/agent_toolkits/openapi/planner.py +++ b/langchain/agents/agent_toolkits/openapi/planner.py @@ -34,8 +34,8 @@ from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI from langchain.memory import ReadOnlySharedMemory from langchain.prompts import PromptTemplate -from langchain.prompts.base import BasePromptTemplate from langchain.requests import RequestsWrapper +from langchain.schema import BasePromptTemplate from langchain.tools.base import BaseTool from langchain.tools.requests.tool import BaseRequestsTool diff --git a/langchain/agents/agent_toolkits/pandas/base.py b/langchain/agents/agent_toolkits/pandas/base.py index a695dc8eedd..687fc401bd0 100644 --- a/langchain/agents/agent_toolkits/pandas/base.py +++ b/langchain/agents/agent_toolkits/pandas/base.py @@ -19,7 +19,7 @@ from langchain.agents.types import AgentType from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.schema.messages import SystemMessage from langchain.tools.python.tool import PythonAstREPLTool diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py index 4ea1f23e0b7..62415f9385c 100644 --- a/langchain/agents/chat/base.py +++ b/langchain/agents/chat/base.py @@ -14,13 +14,12 @@ from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import AgentAction +from langchain.schema import AgentAction, BasePromptTemplate from langchain.tools.base import BaseTool diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index c399a10a65d..8eb310b30fb 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -16,17 +16,13 @@ from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain.schema import ( - AgentAction, - BaseOutputParser, -) +from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.tools.base import BaseTool diff --git a/langchain/agents/openai_functions_agent/base.py b/langchain/agents/openai_functions_agent/base.py index 619820d4615..748222ce939 100644 --- a/langchain/agents/openai_functions_agent/base.py +++ b/langchain/agents/openai_functions_agent/base.py @@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import ChatOpenAI -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( BaseMessagePromptTemplate, ChatPromptTemplate, @@ -21,6 +20,7 @@ from langchain.prompts.chat import ( from langchain.schema import ( AgentAction, AgentFinish, + BasePromptTemplate, OutputParserException, ) from langchain.schema.messages import ( diff --git a/langchain/agents/openai_functions_multi_agent/base.py b/langchain/agents/openai_functions_multi_agent/base.py index ba7849d02c2..71cf2845a1a 100644 --- a/langchain/agents/openai_functions_multi_agent/base.py +++ b/langchain/agents/openai_functions_multi_agent/base.py @@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import ChatOpenAI -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( BaseMessagePromptTemplate, ChatPromptTemplate, @@ -21,6 +20,7 @@ from langchain.prompts.chat import ( from langchain.schema import ( AgentAction, AgentFinish, + BasePromptTemplate, OutputParserException, ) from langchain.schema.messages import ( diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index 38b01d28c58..e7cbfb6466b 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -13,7 +13,7 @@ from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.tools.base import BaseTool diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index e1662273de2..e63322fc0d6 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.serpapi import SerpAPIWrapper diff --git a/langchain/agents/structured_chat/base.py b/langchain/agents/structured_chat/base.py index 75860183423..57514e63753 100644 --- a/langchain/agents/structured_chat/base.py +++ b/langchain/agents/structured_chat/base.py @@ -11,13 +11,12 @@ from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import AgentAction +from langchain.schema import AgentAction, BasePromptTemplate from langchain.tools import BaseTool HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" diff --git a/langchain/callbacks/tracers/evaluation.py b/langchain/callbacks/tracers/evaluation.py index 52f83b85f91..aefd5dd4f60 100644 --- a/langchain/callbacks/tracers/evaluation.py +++ b/langchain/callbacks/tracers/evaluation.py @@ -1,5 +1,4 @@ """A tracer that runs evaluators over completed runs.""" -import logging from concurrent.futures import Future, ThreadPoolExecutor, wait from typing import Any, Optional, Sequence, Set, Union from uuid import UUID @@ -9,8 +8,6 @@ from langchainplus_sdk import LangChainPlusClient, RunEvaluator from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run -logger = logging.getLogger(__name__) - class EvaluatorCallbackHandler(BaseTracer): """A tracer that runs a run evaluator whenever a run is persisted. @@ -50,7 +47,7 @@ class EvaluatorCallbackHandler(BaseTracer): max_workers: Optional[int] = None, client: Optional[LangChainPlusClient] = None, example_id: Optional[Union[UUID, str]] = None, - **kwargs: Any, + **kwargs: Any ) -> None: super().__init__(**kwargs) self.example_id = ( @@ -63,17 +60,6 @@ class EvaluatorCallbackHandler(BaseTracer): ) self.futures: Set[Future] = set() - def _evaluate_run(self, run: Run, evaluator: RunEvaluator) -> None: - try: - self.client.evaluate_run(run, evaluator) - except Exception as e: - logger.error( - f"Error evaluating run {run.id} with " - f"{evaluator.__class__.__name__}: {e}", - exc_info=True, - ) - raise e - def _persist_run(self, run: Run) -> None: """Run the evaluator on the run. @@ -86,7 +72,9 @@ class EvaluatorCallbackHandler(BaseTracer): run_ = run.copy() run_.reference_example_id = self.example_id for evaluator in self.evaluators: - self.futures.add(self.executor.submit(self._evaluate_run, run_, evaluator)) + self.futures.add( + self.executor.submit(self.client.evaluate_run, run_, evaluator) + ) def wait_for_futures(self) -> None: """Wait for all futures to complete.""" diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 7e199fe42fa..9eb4bc2fcdd 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -13,8 +13,8 @@ from langchain.callbacks.manager import ( from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.prompts import BasePromptTemplate from langchain.requests import TextRequestsWrapper +from langchain.schema import BasePromptTemplate class APIChain(Chain): diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 338ea26a843..92e7838ff65 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -11,7 +11,7 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 4b480090589..fac234719c0 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -13,8 +13,8 @@ from langchain.chains.combine_documents.base import ( ) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BasePromptTemplate def _get_default_document_prompt() -> PromptTemplate: diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index d39ce632c80..e4859ff6cad 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -11,8 +11,8 @@ from langchain.chains.combine_documents.base import ( ) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BasePromptTemplate def _get_default_document_prompt() -> PromptTemplate: diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index bf342120b07..75eeda79cd5 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -8,7 +8,7 @@ from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.principles import PRINCIPLES from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.llm import LLMChain -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class ConstitutionalChain(Chain): diff --git a/langchain/chains/conversation/base.py b/langchain/chains/conversation/base.py index a42705ea26f..43d72f91e7c 100644 --- a/langchain/chains/conversation/base.py +++ b/langchain/chains/conversation/base.py @@ -6,8 +6,7 @@ from pydantic import Extra, Field, root_validator from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain from langchain.memory.buffer import ConversationBufferMemory -from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseMemory +from langchain.schema import BaseMemory, BasePromptTemplate class ConversationChain(LLMChain): diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index d369fb31a4e..179bb6f202a 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -21,8 +21,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain -from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseRetriever, Document +from langchain.schema import BasePromptTemplate, BaseRetriever, Document from langchain.schema.messages import BaseMessage from langchain.vectorstores.base import VectorStore diff --git a/langchain/chains/flare/base.py b/langchain/chains/flare/base.py index 02f457e8684..c52e91c3ed2 100644 --- a/langchain/chains/flare/base.py +++ b/langchain/chains/flare/base.py @@ -19,8 +19,7 @@ from langchain.chains.flare.prompts import ( ) from langchain.chains.llm import LLMChain from langchain.llms import OpenAI -from langchain.prompts import BasePromptTemplate -from langchain.schema import BaseRetriever, Generation +from langchain.schema import BasePromptTemplate, BaseRetriever, Generation class _ResponseChain(LLMChain): diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index 36cff24d9f1..44dceca8ed4 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -11,7 +11,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class GraphQAChain(Chain): diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index 95756a49ad3..f50195b1a0c 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -12,7 +12,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.neo4j_graph import Neo4jGraph -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate INTERMEDIATE_STEPS_KEY = "intermediate_steps" diff --git a/langchain/chains/graph_qa/kuzu.py b/langchain/chains/graph_qa/kuzu.py index c8de4a268be..d373d79f951 100644 --- a/langchain/chains/graph_qa/kuzu.py +++ b/langchain/chains/graph_qa/kuzu.py @@ -11,7 +11,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.kuzu_graph import KuzuGraph -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class KuzuQAChain(Chain): diff --git a/langchain/chains/graph_qa/nebulagraph.py b/langchain/chains/graph_qa/nebulagraph.py index ab4048fde89..377559e7076 100644 --- a/langchain/chains/graph_qa/nebulagraph.py +++ b/langchain/chains/graph_qa/nebulagraph.py @@ -11,7 +11,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.nebula_graph import NebulaGraph -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class NebulaGraphQAChain(Chain): diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index a8c7f15597e..bfb2244d686 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -17,10 +17,10 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.input import get_colored_text from langchain.load.dump import dumpd -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( BaseLLMOutputParser, + BasePromptTemplate, LLMResult, NoOpOutputParser, PromptValue, diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 468c0ba7507..f2ec384b655 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -12,8 +12,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT -from langchain.prompts.base import BasePromptTemplate -from langchain.schema import OutputParserException +from langchain.schema import BasePromptTemplate, OutputParserException from langchain.utilities.bash import BashProcess logger = logging.getLogger(__name__) diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index b737b57f3fb..bce99676712 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -17,7 +17,7 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class LLMMathChain(Chain): diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 5474fc7974b..732d489e87d 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -17,7 +17,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.text_splitter import TextSplitter diff --git a/langchain/chains/openai_functions/openapi.py b/langchain/chains/openai_functions/openapi.py index 9e8558e7e8b..349ee4ffde8 100644 --- a/langchain/chains/openai_functions/openapi.py +++ b/langchain/chains/openai_functions/openapi.py @@ -7,7 +7,7 @@ import requests from openapi_schema_pydantic import Parameter from requests import Response -from langchain import BasePromptTemplate, LLMChain +from langchain import LLMChain from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -16,6 +16,7 @@ from langchain.chat_models import ChatOpenAI from langchain.input import get_colored_text from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.prompts import ChatPromptTemplate +from langchain.schema import BasePromptTemplate from langchain.tools import APIOperation from langchain.utilities.openapi import OpenAPISpec diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 275680a8bb1..3e3ada74be5 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -15,7 +15,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.utilities import PythonREPL diff --git a/langchain/chains/prompt_selector.py b/langchain/chains/prompt_selector.py index d2660d59155..4112aba33af 100644 --- a/langchain/chains/prompt_selector.py +++ b/langchain/chains/prompt_selector.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from langchain.base_language import BaseLanguageModel from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class BasePromptSelector(BaseModel, ABC): diff --git a/langchain/chains/qa_generation/base.py b/langchain/chains/qa_generation/base.py index 1c0ae6b9784..9cc2383570b 100644 --- a/langchain/chains/qa_generation/base.py +++ b/langchain/chains/qa_generation/base.py @@ -10,7 +10,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index fe2911b9c4c..ba882310b63 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -26,7 +26,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import ( QUESTION_PROMPT, ) from langchain.docstore.document import Document -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class BaseQAWithSourcesChain(Chain, ABC): diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py index 97b76474dfb..57ea76a7631 100644 --- a/langchain/chains/qa_with_sources/loading.py +++ b/langchain/chains/qa_with_sources/loading.py @@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources import ( from langchain.chains.question_answering.map_rerank_prompt import ( PROMPT as MAP_RERANK_PROMPT, ) -from langchain.prompts.base import BasePromptTemplate +from langchain.schema.prompt_template import BasePromptTemplate class LoadingCallable(Protocol): diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 18452c2cdf7..6111ca46a86 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from typing import Any, Callable, List, Optional, Sequence -from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain +from langchain import FewShotPromptTemplate, LLMChain from langchain.base_language import BaseLanguageModel from langchain.chains.query_constructor.ir import ( Comparator, @@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import ( ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.json import parse_and_check_json_markdown -from langchain.schema import BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index b2cb1d52d9f..a6eb09e2325 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -18,7 +18,7 @@ from langchain.chains.question_answering import ( from langchain.chains.question_answering.map_rerank_prompt import ( PROMPT as MAP_RERANK_PROMPT, ) -from langchain.prompts.base import BasePromptTemplate +from langchain.schema.prompt_template import BasePromptTemplate class LoadingCallable(Protocol): diff --git a/langchain/chains/router/llm_router.py b/langchain/chains/router/llm_router.py index cf8392c1dda..27a0e69d551 100644 --- a/langchain/chains/router/llm_router.py +++ b/langchain/chains/router/llm_router.py @@ -13,8 +13,7 @@ from langchain.callbacks.manager import ( from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain from langchain.output_parsers.json import parse_and_check_json_markdown -from langchain.prompts import BasePromptTemplate -from langchain.schema import BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException class LLMRouterChain(RouterChain): diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 843a041bc05..a7f9649cb48 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -11,8 +11,8 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BasePromptTemplate from langchain.sql_database import SQLDatabase from langchain.tools.sql_database.prompt import QUERY_CHECKER diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index 6fc835dd0f9..fa88cea1ffe 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -8,7 +8,7 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate class LoadingCallable(Protocol): diff --git a/langchain/evaluation/criteria/eval_chain.py b/langchain/evaluation/criteria/eval_chain.py index ec2863a27b9..c7d6a1f7c4a 100644 --- a/langchain/evaluation/criteria/eval_chain.py +++ b/langchain/evaluation/criteria/eval_chain.py @@ -8,8 +8,7 @@ from langchain.base_language import BaseLanguageModel from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES -from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseOutputParser +from langchain.schema import BaseOutputParser, BasePromptTemplate _SUPPORTED_CRITERIA = { "conciseness": "Is the submission concise and to the point?", diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py index bba4c298263..4a5fd53a1e2 100644 --- a/langchain/evaluation/qa/eval_chain.py +++ b/langchain/evaluation/qa/eval_chain.py @@ -19,20 +19,14 @@ def _parse_string_eval_output(text: str) -> dict: Returns: Any: The parsed output. """ - splits = text.strip().rsplit("\n", maxsplit=1) - if len(splits) == 1: - verdict = splits[0] - reasoning = None - else: - reasoning, verdict = splits - reasoning = reasoning.strip() + reasoning, verdict = text.strip().rsplit("\n", maxsplit=1) score = ( 1 if verdict.upper() == "CORRECT" else (0 if verdict.upper() == "INCORRECT" else None) ) return { - "reasoning": reasoning, + "reasoning": reasoning.strip(), "value": verdict, "score": score, } diff --git a/langchain/evaluation/run_evaluators/implementations.py b/langchain/evaluation/run_evaluators/implementations.py index d4cf0d5cede..24f50f2dbdb 100644 --- a/langchain/evaluation/run_evaluators/implementations.py +++ b/langchain/evaluation/run_evaluators/implementations.py @@ -23,9 +23,8 @@ from langchain.evaluation.run_evaluators.base import ( RunEvaluatorInputMapper, RunEvaluatorOutputParser, ) -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import OutputParserException +from langchain.schema import BasePromptTemplate, OutputParserException from langchain.tools.base import BaseTool _QA_PROMPTS = { diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 5c06bd197de..1da66f3da37 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -13,7 +13,7 @@ from langchain.memory.prompt import ( ENTITY_SUMMARIZATION_PROMPT, ) from langchain.memory.utils import get_prompt_input_key -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.schema.messages import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index fd071c066d6..bad45eda9cc 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -12,7 +12,7 @@ from langchain.memory.prompt import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, ) from langchain.memory.utils import get_prompt_input_key -from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BasePromptTemplate from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 1fd58196be5..7345cd72759 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -8,9 +8,9 @@ from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( BaseChatMessageHistory, + BasePromptTemplate, ) from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index 166d570f81c..602ff828905 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -5,8 +5,7 @@ from typing import TypeVar from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT -from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException T = TypeVar("T") diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index 7d9a797379c..c0061e4278d 100644 --- a/langchain/output_parsers/retry.py +++ b/langchain/output_parsers/retry.py @@ -4,10 +4,10 @@ from typing import TypeVar from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( BaseOutputParser, + BasePromptTemplate, OutputParserException, PromptValue, ) diff --git a/langchain/prompts/__init__.py b/langchain/prompts/__init__.py index 66f05c10d5b..dbe135605d6 100644 --- a/langchain/prompts/__init__.py +++ b/langchain/prompts/__init__.py @@ -1,5 +1,5 @@ """Prompt template classes.""" -from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate +from langchain.prompts.base import StringPromptTemplate from langchain.prompts.chat import ( AIMessagePromptTemplate, BaseChatPromptTemplate, @@ -20,6 +20,7 @@ from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain.prompts.loading import load_prompt from langchain.prompts.pipeline import PipelinePromptTemplate from langchain.prompts.prompt import Prompt, PromptTemplate +from langchain.schema.prompt_template import BasePromptTemplate __all__ = [ "AIMessagePromptTemplate", diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 6852a50558d..3d1a13e7699 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,18 +1,13 @@ """BasePrompt schema definition.""" from __future__ import annotations -import json -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union - -import yaml -from pydantic import Field, root_validator +from abc import ABC +from typing import Any, Callable, Dict, List, Set from langchain.formatting import formatter -from langchain.load.serializable import Serializable -from langchain.schema import BaseOutputParser, PromptValue +from langchain.schema import BasePromptTemplate from langchain.schema.messages import BaseMessage, HumanMessage +from langchain.schema.prompt import PromptValue def jinja2_formatter(template: str, **kwargs: Any) -> str: @@ -110,133 +105,6 @@ class StringPromptValue(PromptValue): return [HumanMessage(content=self.text)] -class BasePromptTemplate(Serializable, ABC): - """Base class for all prompt templates, returning a prompt.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - output_parser: Optional[BaseOutputParser] = None - """How to parse the output of calling an LLM on this formatted prompt.""" - partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( - default_factory=dict - ) - - @property - def lc_serializable(self) -> bool: - return True - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @abstractmethod - def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" - - @root_validator() - def validate_variable_names(cls, values: Dict) -> Dict: - """Validate variable names do not include restricted names.""" - if "stop" in values["input_variables"]: - raise ValueError( - "Cannot have an input variable named 'stop', as it is used internally," - " please rename." - ) - if "stop" in values["partial_variables"]: - raise ValueError( - "Cannot have an partial variable named 'stop', as it is used " - "internally, please rename." - ) - - overall = set(values["input_variables"]).intersection( - values["partial_variables"] - ) - if overall: - raise ValueError( - f"Found overlapping input and partial variables: {overall}" - ) - return values - - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: - """Return a partial of the prompt template.""" - prompt_dict = self.__dict__.copy() - prompt_dict["input_variables"] = list( - set(self.input_variables).difference(kwargs) - ) - prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} - return type(self)(**prompt_dict) - - def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: - # Get partial params: - partial_kwargs = { - k: v if isinstance(v, str) else v() - for k, v in self.partial_variables.items() - } - return {**partial_kwargs, **kwargs} - - @abstractmethod - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - raise NotImplementedError - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of prompt.""" - prompt_dict = super().dict(**kwargs) - prompt_dict["_type"] = self._prompt_type - return prompt_dict - - def save(self, file_path: Union[Path, str]) -> None: - """Save the prompt. - - Args: - file_path: Path to directory to save prompt to. - - Example: - .. code-block:: python - - prompt.save(file_path="path/prompt.yaml") - """ - if self.partial_variables: - raise ValueError("Cannot save prompt with partial variables.") - # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path - - directory_path = save_path.parent - directory_path.mkdir(parents=True, exist_ok=True) - - # Fetch dictionary to save - prompt_dict = self.dict() - - if save_path.suffix == ".json": - with open(file_path, "w") as f: - json.dump(prompt_dict, f, indent=4) - elif save_path.suffix == ".yaml": - with open(file_path, "w") as f: - yaml.dump(prompt_dict, f, default_flow_style=False) - else: - raise ValueError(f"{save_path} must be json or yaml") - - class StringPromptTemplate(BasePromptTemplate, ABC): """String prompt should expose the format method, returning a prompt.""" diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 59cb44aa336..0264b8596bc 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -8,9 +8,10 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from pydantic import Field, root_validator from langchain.load.serializable import Serializable -from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate +from langchain.prompts.base import StringPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( + BasePromptTemplate, PromptValue, ) from langchain.schema.messages import ( diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 20c8f8d7880..cc7507cc685 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -8,10 +8,9 @@ from typing import Union import yaml from langchain.output_parsers.regex import RegexParser -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseLLMOutputParser, NoOpOutputParser +from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, NoOpOutputParser from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" diff --git a/langchain/prompts/pipeline.py b/langchain/prompts/pipeline.py index a3f7860bc04..41e34b17ed9 100644 --- a/langchain/prompts/pipeline.py +++ b/langchain/prompts/pipeline.py @@ -2,9 +2,8 @@ from typing import Any, Dict, List, Tuple from pydantic import root_validator -from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import BaseChatPromptTemplate -from langchain.schema import PromptValue +from langchain.schema import BasePromptTemplate, PromptValue def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py index ad44158fc81..ae038175a9e 100644 --- a/langchain/retrievers/document_compressors/chain_filter.py +++ b/langchain/retrievers/document_compressors/chain_filter.py @@ -1,7 +1,7 @@ """Filter that uses an LLM to drop documents that aren't relevant to the query.""" from typing import Any, Callable, Dict, Optional, Sequence -from langchain import BasePromptTemplate, LLMChain, PromptTemplate +from langchain import LLMChain, PromptTemplate from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import Callbacks from langchain.output_parsers.boolean import BooleanOutputParser @@ -9,7 +9,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso from langchain.retrievers.document_compressors.chain_filter_prompt import ( prompt_template, ) -from langchain.schema import Document +from langchain.schema import BasePromptTemplate, Document def _get_default_chain_prompt() -> PromptTemplate: diff --git a/langchain/schema/__init__.py b/langchain/schema/__init__.py index 0821f2fb3fa..13983206bfe 100644 --- a/langchain/schema/__init__.py +++ b/langchain/schema/__init__.py @@ -28,6 +28,7 @@ from langchain.schema.output_parser import ( OutputParserException, ) from langchain.schema.prompt import PromptValue +from langchain.schema.prompt_template import BasePromptTemplate from langchain.schema.retriever import BaseRetriever RUN_KEY = "__run" @@ -64,4 +65,5 @@ __all__ = [ "NoOpOutputParser", "BaseOutputParser", "BaseLLMOutputParser", + "BasePromptTemplate", ] diff --git a/langchain/schema/prompt_template.py b/langchain/schema/prompt_template.py new file mode 100644 index 00000000000..6ed048df4aa --- /dev/null +++ b/langchain/schema/prompt_template.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Union + +import yaml +from pydantic import Field, root_validator + +from langchain.load.serializable import Serializable +from langchain.schema import BaseOutputParser, PromptValue + + +class BasePromptTemplate(Serializable, ABC): + """Base class for all prompt templates, returning a prompt.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + output_parser: Optional[BaseOutputParser] = None + """How to parse the output of calling an LLM on this formatted prompt.""" + partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( + default_factory=dict + ) + + @property + def lc_serializable(self) -> bool: + return True + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @abstractmethod + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + + @root_validator() + def validate_variable_names(cls, values: Dict) -> Dict: + """Validate variable names do not include restricted names.""" + if "stop" in values["input_variables"]: + raise ValueError( + "Cannot have an input variable named 'stop', as it is used internally," + " please rename." + ) + if "stop" in values["partial_variables"]: + raise ValueError( + "Cannot have an partial variable named 'stop', as it is used " + "internally, please rename." + ) + + overall = set(values["input_variables"]).intersection( + values["partial_variables"] + ) + if overall: + raise ValueError( + f"Found overlapping input and partial variables: {overall}" + ) + return values + + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: + """Return a partial of the prompt template.""" + prompt_dict = self.__dict__.copy() + prompt_dict["input_variables"] = list( + set(self.input_variables).difference(kwargs) + ) + prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} + return type(self)(**prompt_dict) + + def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: + # Get partial params: + partial_kwargs = { + k: v if isinstance(v, str) else v() + for k, v in self.partial_variables.items() + } + return {**partial_kwargs, **kwargs} + + @abstractmethod + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of prompt.""" + prompt_dict = super().dict(**kwargs) + prompt_dict["_type"] = self._prompt_type + return prompt_dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the prompt. + + Args: + file_path: Path to directory to save prompt to. + + Example: + .. code-block:: python + + prompt.save(file_path="path/prompt.yaml") + """ + if self.partial_variables: + raise ValueError("Cannot save prompt with partial variables.") + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + prompt_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(prompt_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(prompt_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml")