Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-20 01:49:51 +00:00
commit 60b05511d3 (parent 200be43da6)

move base prompt to schema (#6995)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
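In short: BasePromptTemplate moves out of langchain.prompts.base into a new langchain.schema.prompt_template module, is re-exported from langchain.schema and langchain.prompts, and every internal import site is updated. A minimal sketch of what the move means for import paths (both post-commit forms are taken straight from the hunks below):

# Pre-commit location (removed by this commit):
# from langchain.prompts.base import BasePromptTemplate

# Post-commit: the class is defined in langchain.schema.prompt_template ...
from langchain.schema.prompt_template import BasePromptTemplate

# ... and re-exported from langchain.schema
from langchain.schema import BasePromptTemplate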
@@ -38,11 +38,11 @@ from langchain.llms import (
 )
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.prompts import (
-    BasePromptTemplate,
     FewShotPromptTemplate,
     Prompt,
     PromptTemplate,
 )
+from langchain.schema.prompt_template import BasePromptTemplate
 from langchain.sql_database import SQLDatabase
 from langchain.utilities.arxiv import ArxivAPIWrapper
 from langchain.utilities.google_search import GoogleSearchAPIWrapper
@@ -26,13 +26,13 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
     AgentAction,
     AgentFinish,
     BaseOutputParser,
+    BasePromptTemplate,
     OutputParserException,
 )
 from langchain.schema.messages import BaseMessage
@@ -34,8 +34,8 @@ from langchain.chains.llm import LLMChain
 from langchain.llms.openai import OpenAI
 from langchain.memory import ReadOnlySharedMemory
 from langchain.prompts import PromptTemplate
-from langchain.prompts.base import BasePromptTemplate
 from langchain.requests import RequestsWrapper
+from langchain.schema import BasePromptTemplate
 from langchain.tools.base import BaseTool
 from langchain.tools.requests.tool import BaseRequestsTool
 
@@ -19,7 +19,7 @@ from langchain.agents.types import AgentType
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.schema.messages import SystemMessage
 from langchain.tools.python.tool import PythonAstREPLTool
 
@@ -14,13 +14,12 @@ from langchain.agents.utils import validate_tools_single_input
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AgentAction
+from langchain.schema import AgentAction, BasePromptTemplate
 from langchain.tools.base import BaseTool
 
 
@@ -16,17 +16,13 @@ from langchain.agents.utils import validate_tools_single_input
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains import LLMChain
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import (
-    AgentAction,
-    BaseOutputParser,
-)
+from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from langchain.tools.base import BaseTool
 
@@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.chat_models.openai import ChatOpenAI
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
     BaseMessagePromptTemplate,
     ChatPromptTemplate,
@@ -21,6 +20,7 @@ from langchain.prompts.chat import (
 from langchain.schema import (
     AgentAction,
     AgentFinish,
+    BasePromptTemplate,
     OutputParserException,
 )
 from langchain.schema.messages import (
@@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.chat_models.openai import ChatOpenAI
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
     BaseMessagePromptTemplate,
     ChatPromptTemplate,
@@ -21,6 +20,7 @@ from langchain.prompts.chat import (
 from langchain.schema import (
     AgentAction,
     AgentFinish,
+    BasePromptTemplate,
     OutputParserException,
 )
 from langchain.schema.messages import (
@@ -13,7 +13,7 @@ from langchain.agents.utils import validate_tools_single_input
 from langchain.base_language import BaseLanguageModel
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.tools.base import BaseTool
 
 
@@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
 from langchain.agents.utils import validate_tools_single_input
 from langchain.base_language import BaseLanguageModel
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.tools.base import BaseTool
 from langchain.utilities.google_serper import GoogleSerperAPIWrapper
 from langchain.utilities.serpapi import SerpAPIWrapper
@@ -11,13 +11,12 @@ from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX,
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AgentAction
+from langchain.schema import AgentAction, BasePromptTemplate
 from langchain.tools import BaseTool
 
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
@@ -1,5 +1,4 @@
 """A tracer that runs evaluators over completed runs."""
-import logging
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from typing import Any, Optional, Sequence, Set, Union
 from uuid import UUID
@@ -9,8 +8,6 @@ from langchainplus_sdk import LangChainPlusClient, RunEvaluator
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
 
-logger = logging.getLogger(__name__)
-
 
 class EvaluatorCallbackHandler(BaseTracer):
     """A tracer that runs a run evaluator whenever a run is persisted.
@@ -50,7 +47,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         max_workers: Optional[int] = None,
         client: Optional[LangChainPlusClient] = None,
         example_id: Optional[Union[UUID, str]] = None,
-        **kwargs: Any,
+        **kwargs: Any
     ) -> None:
         super().__init__(**kwargs)
         self.example_id = (
@@ -63,17 +60,6 @@ class EvaluatorCallbackHandler(BaseTracer):
         )
         self.futures: Set[Future] = set()
 
-    def _evaluate_run(self, run: Run, evaluator: RunEvaluator) -> None:
-        try:
-            self.client.evaluate_run(run, evaluator)
-        except Exception as e:
-            logger.error(
-                f"Error evaluating run {run.id} with "
-                f"{evaluator.__class__.__name__}: {e}",
-                exc_info=True,
-            )
-            raise e
-
     def _persist_run(self, run: Run) -> None:
         """Run the evaluator on the run.
 
@@ -86,7 +72,9 @@ class EvaluatorCallbackHandler(BaseTracer):
         run_ = run.copy()
         run_.reference_example_id = self.example_id
         for evaluator in self.evaluators:
-            self.futures.add(self.executor.submit(self._evaluate_run, run_, evaluator))
+            self.futures.add(
+                self.executor.submit(self.client.evaluate_run, run_, evaluator)
+            )
 
     def wait_for_futures(self) -> None:
         """Wait for all futures to complete."""
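Net effect of the tracer hunks above: the _evaluate_run wrapper (and the module's logging setup) is gone, and client.evaluate_run is submitted to the thread pool directly, so evaluation errors now surface from the Future instead of being logged and re-raised eagerly. A minimal, self-contained sketch of the new submission path, using a hypothetical stand-in for LangChainPlusClient:

from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Set

class StubClient:
    # stand-in for LangChainPlusClient; evaluate_run is the call the
    # handler now submits to the executor without any wrapper
    def evaluate_run(self, run: str, evaluator: str) -> str:
        return f"evaluated {run} with {evaluator}"

client = StubClient()
executor = ThreadPoolExecutor(max_workers=2)
futures: Set[Future] = set()

# direct submission: exceptions propagate via the Future, not a logger
futures.add(executor.submit(client.evaluate_run, "run-1", "criteria-evaluator"))
wait(futures)  # what wait_for_futures effectively does for the handler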
@@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
 from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.prompts import BasePromptTemplate
 from langchain.requests import TextRequestsWrapper
+from langchain.schema import BasePromptTemplate
 
 
 class APIChain(Chain):
@@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
@@ -13,8 +13,8 @@ from langchain.chains.combine_documents.base import (
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 def _get_default_document_prompt() -> PromptTemplate:
@@ -11,8 +11,8 @@ from langchain.chains.combine_documents.base import (
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 def _get_default_document_prompt() -> PromptTemplate:
@@ -8,7 +8,7 @@ from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.constitutional_ai.principles import PRINCIPLES
 from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
 from langchain.chains.llm import LLMChain
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class ConstitutionalChain(Chain):
@@ -6,8 +6,7 @@ from pydantic import Extra, Field, root_validator
 from langchain.chains.conversation.prompt import PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.memory.buffer import ConversationBufferMemory
-from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseMemory
+from langchain.schema import BaseMemory, BasePromptTemplate
 
 
 class ConversationChain(LLMChain):
@@ -21,8 +21,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseRetriever, Document
+from langchain.schema import BasePromptTemplate, BaseRetriever, Document
 from langchain.schema.messages import BaseMessage
 from langchain.vectorstores.base import VectorStore
 
@@ -19,8 +19,7 @@ from langchain.chains.flare.prompts import (
 )
 from langchain.chains.llm import LLMChain
 from langchain.llms import OpenAI
-from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseRetriever, Generation
+from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
 
 
 class _ResponseChain(LLMChain):
@@ -11,7 +11,7 @@ from langchain.chains.base import Chain
 from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class GraphQAChain(Chain):
@@ -12,7 +12,7 @@ from langchain.chains.base import Chain
 from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.graphs.neo4j_graph import Neo4jGraph
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 INTERMEDIATE_STEPS_KEY = "intermediate_steps"
 
@@ -11,7 +11,7 @@ from langchain.chains.base import Chain
 from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.graphs.kuzu_graph import KuzuGraph
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class KuzuQAChain(Chain):
@@ -11,7 +11,7 @@ from langchain.chains.base import Chain
 from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.graphs.nebula_graph import NebulaGraph
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class NebulaGraphQAChain(Chain):
@@ -17,10 +17,10 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
 from langchain.load.dump import dumpd
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
     BaseLLMOutputParser,
+    BasePromptTemplate,
     LLMResult,
     NoOpOutputParser,
     PromptValue,
@@ -12,8 +12,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_bash.prompt import PROMPT
-from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import OutputParserException
+from langchain.schema import BasePromptTemplate, OutputParserException
 from langchain.utilities.bash import BashProcess
 
 logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class LLMMathChain(Chain):
@@ -17,7 +17,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import TextSplitter
 
 
@@ -7,7 +7,7 @@ import requests
 from openapi_schema_pydantic import Parameter
 from requests import Response
 
-from langchain import BasePromptTemplate, LLMChain
+from langchain import LLMChain
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -16,6 +16,7 @@ from langchain.chat_models import ChatOpenAI
 from langchain.input import get_colored_text
 from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 from langchain.prompts import ChatPromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.tools import APIOperation
 from langchain.utilities.openapi import OpenAPISpec
 
@@ -15,7 +15,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
 from langchain.chains.pal.math_prompt import MATH_PROMPT
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.utilities import PythonREPL
 
 
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 from langchain.base_language import BaseLanguageModel
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import BaseLLM
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class BasePromptSelector(BaseModel, ABC):
@@ -10,7 +10,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
@@ -26,7 +26,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
     QUESTION_PROMPT,
 )
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class BaseQAWithSourcesChain(Chain, ABC):
@@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources import (
 from langchain.chains.question_answering.map_rerank_prompt import (
     PROMPT as MAP_RERANK_PROMPT,
 )
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
@@ -4,7 +4,7 @@ from __future__ import annotations
 import json
 from typing import Any, Callable, List, Optional, Sequence
 
-from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
+from langchain import FewShotPromptTemplate, LLMChain
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.query_constructor.ir import (
     Comparator,
@@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import (
 )
 from langchain.chains.query_constructor.schema import AttributeInfo
 from langchain.output_parsers.json import parse_and_check_json_markdown
-from langchain.schema import BaseOutputParser, OutputParserException
+from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
 
 
 class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
@@ -18,7 +18,7 @@ from langchain.chains.question_answering import (
 from langchain.chains.question_answering.map_rerank_prompt import (
     PROMPT as MAP_RERANK_PROMPT,
 )
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
@@ -13,8 +13,7 @@ from langchain.callbacks.manager import (
 from langchain.chains import LLMChain
 from langchain.chains.router.base import RouterChain
 from langchain.output_parsers.json import parse_and_check_json_markdown
-from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException
+from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
 
 
 class LLMRouterChain(RouterChain):
@@ -11,8 +11,8 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.sql_database import SQLDatabase
 from langchain.tools.sql_database.prompt import QUERY_CHECKER
 
@@ -8,7 +8,7 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
@@ -8,8 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
-from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseOutputParser
+from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
     "conciseness": "Is the submission concise and to the point?",
@@ -19,20 +19,14 @@ def _parse_string_eval_output(text: str) -> dict:
     Returns:
         Any: The parsed output.
     """
-    splits = text.strip().rsplit("\n", maxsplit=1)
-    if len(splits) == 1:
-        verdict = splits[0]
-        reasoning = None
-    else:
-        reasoning, verdict = splits
-        reasoning = reasoning.strip()
+    reasoning, verdict = text.strip().rsplit("\n", maxsplit=1)
     score = (
         1
        if verdict.upper() == "CORRECT"
         else (0 if verdict.upper() == "INCORRECT" else None)
     )
     return {
-        "reasoning": reasoning,
+        "reasoning": reasoning.strip(),
         "value": verdict,
         "score": score,
     }
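The hunk above also tightens _parse_string_eval_output: the old single-line fallback (the len(splits) == 1 branch) is gone, so the function now assumes the verdict always arrives on the final line, and a single-line input would raise a ValueError from tuple unpacking. A quick worked example of the post-commit behavior (the function body is copied from the new version):

def _parse_string_eval_output(text: str) -> dict:
    # reasoning and verdict are split on the last newline
    reasoning, verdict = text.strip().rsplit("\n", maxsplit=1)
    score = (
        1
        if verdict.upper() == "CORRECT"
        else (0 if verdict.upper() == "INCORRECT" else None)
    )
    return {"reasoning": reasoning.strip(), "value": verdict, "score": score}

print(_parse_string_eval_output("The answer matches the reference.\nCORRECT"))
# {'reasoning': 'The answer matches the reference.', 'value': 'CORRECT', 'score': 1}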
@@ -23,9 +23,8 @@ from langchain.evaluation.run_evaluators.base import (
     RunEvaluatorInputMapper,
     RunEvaluatorOutputParser,
 )
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import OutputParserException
+from langchain.schema import BasePromptTemplate, OutputParserException
 from langchain.tools.base import BaseTool
 
 _QA_PROMPTS = {
@@ -13,7 +13,7 @@ from langchain.memory.prompt import (
     ENTITY_SUMMARIZATION_PROMPT,
 )
 from langchain.memory.utils import get_prompt_input_key
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.schema.messages import BaseMessage, get_buffer_string
 
 logger = logging.getLogger(__name__)
@@ -12,7 +12,7 @@ from langchain.memory.prompt import (
     KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
 )
 from langchain.memory.utils import get_prompt_input_key
-from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BasePromptTemplate
 from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
 
 
@@ -8,9 +8,9 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.base import BasePromptTemplate
 from langchain.schema import (
     BaseChatMessageHistory,
+    BasePromptTemplate,
 )
 from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
 
@@ -5,8 +5,7 @@ from typing import TypeVar
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
-from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException
+from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
 
 T = TypeVar("T")
 
@@ -4,10 +4,10 @@ from typing import TypeVar
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
     BaseOutputParser,
+    BasePromptTemplate,
     OutputParserException,
     PromptValue,
 )
@@ -1,5 +1,5 @@
 """Prompt template classes."""
-from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
+from langchain.prompts.base import StringPromptTemplate
 from langchain.prompts.chat import (
     AIMessagePromptTemplate,
     BaseChatPromptTemplate,
@@ -20,6 +20,7 @@ from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
 from langchain.prompts.loading import load_prompt
 from langchain.prompts.pipeline import PipelinePromptTemplate
 from langchain.prompts.prompt import Prompt, PromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate
 
 __all__ = [
     "AIMessagePromptTemplate",
@@ -1,18 +1,13 @@
 """BasePrompt schema definition."""
 from __future__ import annotations
 
-import json
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
-
-import yaml
-from pydantic import Field, root_validator
+from abc import ABC
+from typing import Any, Callable, Dict, List, Set
 
 from langchain.formatting import formatter
-from langchain.load.serializable import Serializable
-from langchain.schema import BaseOutputParser, PromptValue
+from langchain.schema import BasePromptTemplate
 from langchain.schema.messages import BaseMessage, HumanMessage
+from langchain.schema.prompt import PromptValue
 
 
 def jinja2_formatter(template: str, **kwargs: Any) -> str:
@@ -110,133 +105,6 @@ class StringPromptValue(PromptValue):
         return [HumanMessage(content=self.text)]
 
 
-class BasePromptTemplate(Serializable, ABC):
-    """Base class for all prompt templates, returning a prompt."""
-
-    input_variables: List[str]
-    """A list of the names of the variables the prompt template expects."""
-    output_parser: Optional[BaseOutputParser] = None
-    """How to parse the output of calling an LLM on this formatted prompt."""
-    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
-        default_factory=dict
-    )
-
-    @property
-    def lc_serializable(self) -> bool:
-        return True
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @abstractmethod
-    def format_prompt(self, **kwargs: Any) -> PromptValue:
-        """Create Chat Messages."""
-
-    @root_validator()
-    def validate_variable_names(cls, values: Dict) -> Dict:
-        """Validate variable names do not include restricted names."""
-        if "stop" in values["input_variables"]:
-            raise ValueError(
-                "Cannot have an input variable named 'stop', as it is used internally,"
-                " please rename."
-            )
-        if "stop" in values["partial_variables"]:
-            raise ValueError(
-                "Cannot have an partial variable named 'stop', as it is used "
-                "internally, please rename."
-            )
-
-        overall = set(values["input_variables"]).intersection(
-            values["partial_variables"]
-        )
-        if overall:
-            raise ValueError(
-                f"Found overlapping input and partial variables: {overall}"
-            )
-        return values
-
-    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
-        """Return a partial of the prompt template."""
-        prompt_dict = self.__dict__.copy()
-        prompt_dict["input_variables"] = list(
-            set(self.input_variables).difference(kwargs)
-        )
-        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
-        return type(self)(**prompt_dict)
-
-    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
-        # Get partial params:
-        partial_kwargs = {
-            k: v if isinstance(v, str) else v()
-            for k, v in self.partial_variables.items()
-        }
-        return {**partial_kwargs, **kwargs}
-
-    @abstractmethod
-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
-
-        Args:
-            kwargs: Any arguments to be passed to the prompt template.
-
-        Returns:
-            A formatted string.
-
-        Example:
-
-        .. code-block:: python
-
-            prompt.format(variable1="foo")
-        """
-
-    @property
-    def _prompt_type(self) -> str:
-        """Return the prompt type key."""
-        raise NotImplementedError
-
-    def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of prompt."""
-        prompt_dict = super().dict(**kwargs)
-        prompt_dict["_type"] = self._prompt_type
-        return prompt_dict
-
-    def save(self, file_path: Union[Path, str]) -> None:
-        """Save the prompt.
-
-        Args:
-            file_path: Path to directory to save prompt to.
-
-        Example:
-        .. code-block:: python
-
-            prompt.save(file_path="path/prompt.yaml")
-        """
-        if self.partial_variables:
-            raise ValueError("Cannot save prompt with partial variables.")
-        # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
-
-        directory_path = save_path.parent
-        directory_path.mkdir(parents=True, exist_ok=True)
-
-        # Fetch dictionary to save
-        prompt_dict = self.dict()
-
-        if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
-                json.dump(prompt_dict, f, indent=4)
-        elif save_path.suffix == ".yaml":
-            with open(file_path, "w") as f:
-                yaml.dump(prompt_dict, f, default_flow_style=False)
-        else:
-            raise ValueError(f"{save_path} must be json or yaml")
-
-
 class StringPromptTemplate(BasePromptTemplate, ABC):
     """String prompt should expose the format method, returning a prompt."""
 
@@ -8,9 +8,10 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
 from pydantic import Field, root_validator
 
 from langchain.load.serializable import Serializable
-from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
+from langchain.prompts.base import StringPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
+    BasePromptTemplate,
     PromptValue,
 )
 from langchain.schema.messages import (
@@ -8,10 +8,9 @@ from typing import Union
 import yaml
 
 from langchain.output_parsers.regex import RegexParser
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLLMOutputParser, NoOpOutputParser
+from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, NoOpOutputParser
 from langchain.utilities.loading import try_load_from_hub
 
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -2,9 +2,8 @@ from typing import Any, Dict, List, Tuple
 
 from pydantic import root_validator
 
-from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import BaseChatPromptTemplate
-from langchain.schema import PromptValue
+from langchain.schema import BasePromptTemplate, PromptValue
 
 
 def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
@@ -1,7 +1,7 @@
 """Filter that uses an LLM to drop documents that aren't relevant to the query."""
 from typing import Any, Callable, Dict, Optional, Sequence
 
-from langchain import BasePromptTemplate, LLMChain, PromptTemplate
+from langchain import LLMChain, PromptTemplate
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.output_parsers.boolean import BooleanOutputParser
@@ -9,7 +9,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
 from langchain.retrievers.document_compressors.chain_filter_prompt import (
     prompt_template,
 )
-from langchain.schema import Document
+from langchain.schema import BasePromptTemplate, Document
 
 
 def _get_default_chain_prompt() -> PromptTemplate:
@@ -28,6 +28,7 @@ from langchain.schema.output_parser import (
     OutputParserException,
 )
 from langchain.schema.prompt import PromptValue
+from langchain.schema.prompt_template import BasePromptTemplate
 from langchain.schema.retriever import BaseRetriever
 
 RUN_KEY = "__run"
@@ -64,4 +65,5 @@ __all__ = [
     "NoOpOutputParser",
     "BaseOutputParser",
     "BaseLLMOutputParser",
+    "BasePromptTemplate",
 ]
langchain/schema/prompt_template.py (new file, +139)
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Mapping, Optional, Union
+
+import yaml
+from pydantic import Field, root_validator
+
+from langchain.load.serializable import Serializable
+from langchain.schema import BaseOutputParser, PromptValue
+
+
+class BasePromptTemplate(Serializable, ABC):
+    """Base class for all prompt templates, returning a prompt."""
+
+    input_variables: List[str]
+    """A list of the names of the variables the prompt template expects."""
+    output_parser: Optional[BaseOutputParser] = None
+    """How to parse the output of calling an LLM on this formatted prompt."""
+    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
+        default_factory=dict
+    )
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @abstractmethod
+    def format_prompt(self, **kwargs: Any) -> PromptValue:
+        """Create Chat Messages."""
+
+    @root_validator()
+    def validate_variable_names(cls, values: Dict) -> Dict:
+        """Validate variable names do not include restricted names."""
+        if "stop" in values["input_variables"]:
+            raise ValueError(
+                "Cannot have an input variable named 'stop', as it is used internally,"
+                " please rename."
+            )
+        if "stop" in values["partial_variables"]:
+            raise ValueError(
+                "Cannot have an partial variable named 'stop', as it is used "
+                "internally, please rename."
+            )
+
+        overall = set(values["input_variables"]).intersection(
+            values["partial_variables"]
+        )
+        if overall:
+            raise ValueError(
+                f"Found overlapping input and partial variables: {overall}"
+            )
+        return values
+
+    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
+        """Return a partial of the prompt template."""
+        prompt_dict = self.__dict__.copy()
+        prompt_dict["input_variables"] = list(
+            set(self.input_variables).difference(kwargs)
+        )
+        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
+        return type(self)(**prompt_dict)
+
+    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
+        # Get partial params:
+        partial_kwargs = {
+            k: v if isinstance(v, str) else v()
+            for k, v in self.partial_variables.items()
+        }
+        return {**partial_kwargs, **kwargs}
+
+    @abstractmethod
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+
+    @property
+    def _prompt_type(self) -> str:
+        """Return the prompt type key."""
+        raise NotImplementedError
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of prompt."""
+        prompt_dict = super().dict(**kwargs)
+        prompt_dict["_type"] = self._prompt_type
+        return prompt_dict
+
+    def save(self, file_path: Union[Path, str]) -> None:
+        """Save the prompt.
+
+        Args:
+            file_path: Path to directory to save prompt to.
+
+        Example:
+        .. code-block:: python
+
+            prompt.save(file_path="path/prompt.yaml")
+        """
+        if self.partial_variables:
+            raise ValueError("Cannot save prompt with partial variables.")
+        # Convert file to Path object.
+        if isinstance(file_path, str):
+            save_path = Path(file_path)
+        else:
+            save_path = file_path
+
+        directory_path = save_path.parent
+        directory_path.mkdir(parents=True, exist_ok=True)
+
+        # Fetch dictionary to save
+        prompt_dict = self.dict()
+
+        if save_path.suffix == ".json":
+            with open(file_path, "w") as f:
+                json.dump(prompt_dict, f, indent=4)
+        elif save_path.suffix == ".yaml":
+            with open(file_path, "w") as f:
+                yaml.dump(prompt_dict, f, default_flow_style=False)
+        else:
+            raise ValueError(f"{save_path} must be json or yaml")
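For context on the relocated base class, a toy subclass sketch showing how partial() and _merge_partial_and_user_variables interact. ToyPrompt and its format implementation are hypothetical, not part of this commit, and StringPromptValue is assumed to still be importable from langchain.prompts.base, as the prompts/base.py hunk above indicates.

from typing import Any

from langchain.prompts.base import StringPromptValue
from langchain.schema import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate


class ToyPrompt(BasePromptTemplate):
    template: str  # hypothetical field, for demonstration only

    def format(self, **kwargs: Any) -> str:
        # fold stored partial_variables into the caller's kwargs
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        return self.template.format(**kwargs)

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        return StringPromptValue(text=self.format(**kwargs))


prompt = ToyPrompt(input_variables=["a", "b"], template="{a} and {b}")
partial = prompt.partial(a="foo")   # input_variables shrinks to ["b"]
print(partial.format(b="bar"))      # -> "foo and bar"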