mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-19 17:45:25 +00:00
move base prompt to schema (#6995)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
200be43da6
commit
60b05511d3
@ -38,11 +38,11 @@ from langchain.llms import (
|
||||
)
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.prompts import (
|
||||
BasePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
Prompt,
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
||||
|
@ -26,13 +26,13 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseOutputParser,
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
@ -34,8 +34,8 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.memory import ReadOnlySharedMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.requests.tool import BaseRequestsTool
|
||||
|
||||
|
@ -19,7 +19,7 @@ from langchain.agents.types import AgentType
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
@ -14,13 +14,12 @@ from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema import AgentAction, BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
@ -16,17 +16,13 @@ from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
@ -21,6 +20,7 @@ from langchain.prompts.chat import (
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
|
@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
@ -21,6 +20,7 @@ from langchain.prompts.chat import (
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
|
@ -13,7 +13,7 @@ from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||
from langchain.utilities.serpapi import SerpAPIWrapper
|
||||
|
@ -11,13 +11,12 @@ from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX,
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema import AgentAction, BasePromptTemplate
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""A tracer that runs evaluators over completed runs."""
|
||||
import logging
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional, Sequence, Set, Union
|
||||
from uuid import UUID
|
||||
@ -9,8 +8,6 @@ from langchainplus_sdk import LangChainPlusClient, RunEvaluator
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluatorCallbackHandler(BaseTracer):
|
||||
"""A tracer that runs a run evaluator whenever a run is persisted.
|
||||
@ -50,7 +47,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
max_workers: Optional[int] = None,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
example_id: Optional[Union[UUID, str]] = None,
|
||||
**kwargs: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.example_id = (
|
||||
@ -63,17 +60,6 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
)
|
||||
self.futures: Set[Future] = set()
|
||||
|
||||
def _evaluate_run(self, run: Run, evaluator: RunEvaluator) -> None:
|
||||
try:
|
||||
self.client.evaluate_run(run, evaluator)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error evaluating run {run.id} with "
|
||||
f"{evaluator.__class__.__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Run the evaluator on the run.
|
||||
|
||||
@ -86,7 +72,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
run_ = run.copy()
|
||||
run_.reference_example_id = self.example_id
|
||||
for evaluator in self.evaluators:
|
||||
self.futures.add(self.executor.submit(self._evaluate_run, run_, evaluator))
|
||||
self.futures.add(
|
||||
self.executor.submit(self.client.evaluate_run, run_, evaluator)
|
||||
)
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for all futures to complete."""
|
||||
|
@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class APIChain(Chain):
|
||||
|
@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
|
@ -13,8 +13,8 @@ from langchain.chains.combine_documents.base import (
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
|
@ -11,8 +11,8 @@ from langchain.chains.combine_documents.base import (
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
|
@ -8,7 +8,7 @@ from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.constitutional_ai.principles import PRINCIPLES
|
||||
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class ConstitutionalChain(Chain):
|
||||
|
@ -6,8 +6,7 @@ from pydantic import Extra, Field, root_validator
|
||||
from langchain.chains.conversation.prompt import PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseMemory
|
||||
from langchain.schema import BaseMemory, BasePromptTemplate
|
||||
|
||||
|
||||
class ConversationChain(LLMChain):
|
||||
|
@ -21,8 +21,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
@ -19,8 +19,7 @@ from langchain.chains.flare.prompts import (
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseRetriever, Generation
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
|
||||
|
||||
|
||||
class _ResponseChain(LLMChain):
|
||||
|
@ -11,7 +11,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class GraphQAChain(Chain):
|
||||
|
@ -12,7 +12,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
@ -11,7 +11,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.kuzu_graph import KuzuGraph
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class KuzuQAChain(Chain):
|
||||
|
@ -11,7 +11,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.nebula_graph import NebulaGraph
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class NebulaGraphQAChain(Chain):
|
||||
|
@ -17,10 +17,10 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLLMOutputParser,
|
||||
BasePromptTemplate,
|
||||
LLMResult,
|
||||
NoOpOutputParser,
|
||||
PromptValue,
|
||||
|
@ -12,8 +12,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import OutputParserException
|
||||
from langchain.schema import BasePromptTemplate, OutputParserException
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class LLMMathChain(Chain):
|
||||
|
@ -17,7 +17,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ import requests
|
||||
from openapi_schema_pydantic import Parameter
|
||||
from requests import Response
|
||||
|
||||
from langchain import BasePromptTemplate, LLMChain
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
@ -16,6 +16,7 @@ from langchain.chat_models import ChatOpenAI
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.tools import APIOperation
|
||||
from langchain.utilities.openapi import OpenAPISpec
|
||||
|
||||
|
@ -15,7 +15,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class BasePromptSelector(BaseModel, ABC):
|
||||
|
@ -10,7 +10,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
QUESTION_PROMPT,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class BaseQAWithSourcesChain(Chain, ABC):
|
||||
|
@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources import (
|
||||
from langchain.chains.question_answering.map_rerank_prompt import (
|
||||
PROMPT as MAP_RERANK_PROMPT,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
|
||||
from langchain import FewShotPromptTemplate, LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import (
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||
|
||||
|
||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
|
@ -18,7 +18,7 @@ from langchain.chains.question_answering import (
|
||||
from langchain.chains.question_answering.map_rerank_prompt import (
|
||||
PROMPT as MAP_RERANK_PROMPT,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
@ -13,8 +13,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||
|
||||
|
||||
class LLMRouterChain(RouterChain):
|
||||
|
@ -11,8 +11,8 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
@ -8,7 +8,7 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
@ -8,8 +8,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate
|
||||
|
||||
_SUPPORTED_CRITERIA = {
|
||||
"conciseness": "Is the submission concise and to the point?",
|
||||
|
@ -19,20 +19,14 @@ def _parse_string_eval_output(text: str) -> dict:
|
||||
Returns:
|
||||
Any: The parsed output.
|
||||
"""
|
||||
splits = text.strip().rsplit("\n", maxsplit=1)
|
||||
if len(splits) == 1:
|
||||
verdict = splits[0]
|
||||
reasoning = None
|
||||
else:
|
||||
reasoning, verdict = splits
|
||||
reasoning = reasoning.strip()
|
||||
reasoning, verdict = text.strip().rsplit("\n", maxsplit=1)
|
||||
score = (
|
||||
1
|
||||
if verdict.upper() == "CORRECT"
|
||||
else (0 if verdict.upper() == "INCORRECT" else None)
|
||||
)
|
||||
return {
|
||||
"reasoning": reasoning,
|
||||
"reasoning": reasoning.strip(),
|
||||
"value": verdict,
|
||||
"score": score,
|
||||
}
|
||||
|
@ -23,9 +23,8 @@ from langchain.evaluation.run_evaluators.base import (
|
||||
RunEvaluatorInputMapper,
|
||||
RunEvaluatorOutputParser,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import OutputParserException
|
||||
from langchain.schema import BasePromptTemplate, OutputParserException
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
_QA_PROMPTS = {
|
||||
|
@ -13,7 +13,7 @@ from langchain.memory.prompt import (
|
||||
ENTITY_SUMMARIZATION_PROMPT,
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -12,7 +12,7 @@ from langchain.memory.prompt import (
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
|
||||
|
||||
|
@ -8,9 +8,9 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BasePromptTemplate,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
|
||||
|
@ -5,8 +5,7 @@ from typing import TypeVar
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -4,10 +4,10 @@ from typing import TypeVar
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseOutputParser,
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
PromptValue,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Prompt template classes."""
|
||||
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
|
||||
from langchain.prompts.base import StringPromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
@ -20,6 +20,7 @@ from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||
from langchain.prompts.loading import load_prompt
|
||||
from langchain.prompts.pipeline import PipelinePromptTemplate
|
||||
from langchain.prompts.prompt import Prompt, PromptTemplate
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
__all__ = [
|
||||
"AIMessagePromptTemplate",
|
||||
|
@ -1,18 +1,13 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import Field, root_validator
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain.formatting import formatter
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import BaseOutputParser, PromptValue
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from langchain.schema.prompt import PromptValue
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
@ -110,133 +105,6 @@ class StringPromptValue(PromptValue):
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
class BasePromptTemplate(Serializable, ABC):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||
"""Return a partial of the prompt template."""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
prompt_dict["input_variables"] = list(
|
||||
set(self.input_variables).difference(kwargs)
|
||||
)
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if isinstance(v, str) else v()
|
||||
for k, v in self.partial_variables.items()
|
||||
}
|
||||
return {**partial_kwargs, **kwargs}
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of prompt."""
|
||||
prompt_dict = super().dict(**kwargs)
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
return prompt_dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the prompt.
|
||||
|
||||
Args:
|
||||
file_path: Path to directory to save prompt to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt.save(file_path="path/prompt.yaml")
|
||||
"""
|
||||
if self.partial_variables:
|
||||
raise ValueError("Cannot save prompt with partial variables.")
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(prompt_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""String prompt should expose the format method, returning a prompt."""
|
||||
|
||||
|
@ -8,9 +8,10 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
|
||||
from langchain.prompts.base import StringPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BasePromptTemplate,
|
||||
PromptValue,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
|
@ -8,10 +8,9 @@ from typing import Union
|
||||
import yaml
|
||||
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseLLMOutputParser, NoOpOutputParser
|
||||
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, NoOpOutputParser
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
|
@ -2,9 +2,8 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain.schema import PromptValue
|
||||
from langchain.schema import BasePromptTemplate, PromptValue
|
||||
|
||||
|
||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
from langchain import BasePromptTemplate, LLMChain, PromptTemplate
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
@ -9,7 +9,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
|
||||
from langchain.retrievers.document_compressors.chain_filter_prompt import (
|
||||
prompt_template,
|
||||
)
|
||||
from langchain.schema import Document
|
||||
from langchain.schema import BasePromptTemplate, Document
|
||||
|
||||
|
||||
def _get_default_chain_prompt() -> PromptTemplate:
|
||||
|
@ -28,6 +28,7 @@ from langchain.schema.output_parser import (
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
RUN_KEY = "__run"
|
||||
@ -64,4 +65,5 @@ __all__ = [
|
||||
"NoOpOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
"BasePromptTemplate",
|
||||
]
|
||||
|
139
langchain/schema/prompt_template.py
Normal file
139
langchain/schema/prompt_template.py
Normal file
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import BaseOutputParser, PromptValue
|
||||
|
||||
|
||||
class BasePromptTemplate(Serializable, ABC):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||
"""Return a partial of the prompt template."""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
prompt_dict["input_variables"] = list(
|
||||
set(self.input_variables).difference(kwargs)
|
||||
)
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if isinstance(v, str) else v()
|
||||
for k, v in self.partial_variables.items()
|
||||
}
|
||||
return {**partial_kwargs, **kwargs}
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of prompt."""
|
||||
prompt_dict = super().dict(**kwargs)
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
return prompt_dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the prompt.
|
||||
|
||||
Args:
|
||||
file_path: Path to directory to save prompt to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt.save(file_path="path/prompt.yaml")
|
||||
"""
|
||||
if self.partial_variables:
|
||||
raise ValueError("Cannot save prompt with partial variables.")
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(prompt_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
Loading…
Reference in New Issue
Block a user