mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-08 18:19:21 +00:00
Compare commits
14 Commits
langchain-
...
harrison/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea6fa901c8 | ||
|
|
086ea90625 | ||
|
|
e1d5926fdc | ||
|
|
d41ec72cb6 | ||
|
|
2e13b33bc6 | ||
|
|
855ce1a0aa | ||
|
|
0091bb0cda | ||
|
|
84cd181835 | ||
|
|
96c5d4536f | ||
|
|
9d5a51ad8f | ||
|
|
5c74c93c61 | ||
|
|
38dc42b242 | ||
|
|
73f808ad68 | ||
|
|
b0f8c287c3 |
@@ -14,15 +14,6 @@ from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForChainRun,
|
||||
CallbackManagerForToolRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
@@ -34,6 +25,15 @@ from langchain.schema import (
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForChainRun,
|
||||
CallbackManagerForToolRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
@@ -6,8 +6,8 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
||||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from langchain.agents.agent_toolkits.openapi.prompt import (
|
||||
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
||||
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.memory import ReadOnlySharedMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.requests.tool import BaseRequestsTool
|
||||
|
||||
@@ -16,9 +16,9 @@ from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.agents.types import AgentType
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
@@ -9,8 +9,8 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
|
||||
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.utilities.powerbi import PowerBIDataset
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
|
||||
)
|
||||
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.utilities.powerbi import PowerBIDataset
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
@@ -13,6 +12,7 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.powerbi.prompt import (
|
||||
|
||||
@@ -7,8 +7,8 @@ from langchain.agents.agent_toolkits.python.prompt import PREFIX
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.agents.types import AgentType
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools.python.tool import PythonREPLTool
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF
|
||||
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -12,13 +12,13 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, SystemMessage
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import (
|
||||
VectorStoreToolkit,
|
||||
)
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from langchain.agents.chat.prompt import (
|
||||
SYSTEM_MESSAGE_SUFFIX,
|
||||
)
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@@ -19,6 +18,7 @@ from langchain.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction, BasePromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
|
||||
TEMPLATE_TOOL_RESPONSE,
|
||||
)
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@@ -22,6 +21,7 @@ from langchain.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Optional, Sequence
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from mypy_extensions import Arg, KwArg
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
|
||||
@@ -11,9 +11,9 @@ from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
BaseMessagePromptTemplate,
|
||||
@@ -22,6 +20,8 @@ from langchain.schema import (
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
BaseMessagePromptTemplate,
|
||||
@@ -22,6 +20,8 @@ from langchain.schema import (
|
||||
BasePromptTemplate,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
|
||||
@@ -8,7 +8,6 @@ from langchain.agents.structured_chat.output_parser import (
|
||||
StructuredChatOutputParserWithRetries,
|
||||
)
|
||||
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@@ -16,6 +15,7 @@ from langchain.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction, BasePromptTemplate
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Interface for tools."""
|
||||
from typing import Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def import_aim() -> Any:
|
||||
|
||||
@@ -2,8 +2,8 @@ import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import import_pandas
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import arthurai
|
||||
|
||||
@@ -1,545 +1,7 @@
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
from langchain.schema.callbacks.base import (
|
||||
AsyncCallbackHandler,
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
)
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.output import LLMResult
|
||||
|
||||
|
||||
class RetrieverManagerMixin:
|
||||
"""Mixin for Retriever callbacks."""
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever errors."""
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever ends running."""
|
||||
|
||||
|
||||
class LLMManagerMixin:
|
||||
"""Mixin for LLM callbacks."""
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class ToolManagerMixin:
|
||||
"""Mixin for tool callbacks."""
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
|
||||
class CallbackManagerMixin:
|
||||
"""Mixin for callback manager."""
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
|
||||
class RunManagerMixin:
|
||||
"""Mixin for run manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
raise_error: bool = False
|
||||
|
||||
run_inline: bool = False
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever start."""
|
||||
|
||||
async def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever end."""
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error."""
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||
self.tags = tags or []
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Whether the callback manager is async."""
|
||||
return False
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
if handler not in self.handlers:
|
||||
self.handlers.append(handler)
|
||||
if inherit and handler not in self.inheritable_handlers:
|
||||
self.inheritable_handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
self.add_handler(handler, inherit=inherit)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
|
||||
for tag in tags:
|
||||
if tag in self.tags:
|
||||
self.remove_tags([tag])
|
||||
self.tags.extend(tags)
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
|
||||
def remove_tags(self, tags: List[str]) -> None:
|
||||
for tag in tags:
|
||||
self.tags.remove(tag)
|
||||
self.inheritable_tags.remove(tag)
|
||||
|
||||
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
|
||||
self.metadata.update(metadata)
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
|
||||
def remove_metadata(self, keys: List[str]) -> None:
|
||||
for key in keys:
|
||||
self.metadata.pop(key)
|
||||
self.inheritable_metadata.pop(key)
|
||||
__all__ = ["BaseCallbackManager", "AsyncCallbackHandler", "BaseCallbackHandler"]
|
||||
|
||||
@@ -3,7 +3,6 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
|
||||
load_json,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def import_clearml() -> Any:
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
@@ -13,6 +12,7 @@ from langchain.callbacks.utils import (
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
LANGCHAIN_MODEL_NAME = "langchain-model"
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@ import os
|
||||
from typing import Any, Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def import_context() -> Any:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Callback Handler that writes to a file."""
|
||||
from typing import Any, Dict, Optional, TextIO, cast
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class FileCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -5,7 +5,6 @@ import logging
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import flytekit
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def _default_approve(_input: str) -> bool:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def import_infino() -> Any:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,6 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
@@ -16,6 +15,7 @@ from langchain.callbacks.utils import (
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
# GPT-4 input
|
||||
|
||||
@@ -5,11 +5,11 @@ import datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.callbacks.base import AsyncCallbackHandler
|
||||
|
||||
# TODO If used by two LLM runs in parallel this won't work as expected
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import sys
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -2,13 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
||||
LLMThoughtLabeler as LLMThoughtLabeler,
|
||||
)
|
||||
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
||||
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
@@ -5,9 +5,9 @@ from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
@@ -7,9 +7,9 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.output import ChatGeneration, LLMResult
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.callbacks.base import BaseCallbackHandler
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -5,15 +5,15 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
|
||||
from pydantic import BaseModel, Field
|
||||
from requests import Response
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
|
||||
from langchain.chains.api.openapi.response_chain import APIResponderChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.requests import Requests
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.openapi.utils.api_models import APIOperation
|
||||
|
||||
|
||||
@@ -1,598 +1,6 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
"""Purely for backwards compatibility.
|
||||
|
||||
import yaml
|
||||
from pydantic import Field, root_validator, validator
|
||||
Chain used to be defined here before moving to schema."""
|
||||
from langchain.schema.chain import Chain
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Abstract base class for creating structured sequences of calls to components.
|
||||
|
||||
Chains should be used to encode a sequence of calls to components like
|
||||
models, document retrievers, other chains, etc., and provide a simple interface
|
||||
to this sequence.
|
||||
|
||||
The Chain interface makes it easy to create apps that are:
|
||||
- Stateful: add Memory to any Chain to give it state,
|
||||
- Observable: pass Callbacks to a Chain to execute additional functionality,
|
||||
like logging, outside the main sequence of component calls,
|
||||
- Composable: the Chain API is flexible enough that it is easy to combine
|
||||
Chains with other components, including other Chains.
|
||||
|
||||
The main methods exposed by chains are:
|
||||
- `__call__`: Chains are callable. The `__call__` method is the primary way to
|
||||
execute a Chain. This takes inputs as a dictionary and returns a
|
||||
dictionary output.
|
||||
- `run`: A convenience method that takes inputs as args/kwargs and returns the
|
||||
output as a string. This method can only be used for a subset of chains and
|
||||
cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
Memory is a class that gets called at the start
|
||||
and at the end of every chain. At the start, memory loads variables and passes
|
||||
them along in the chain. At the end, it saves any returned variables.
|
||||
There are many different types of memory - please see memory docs
|
||||
for the full catalog."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Optional list of callback handlers (or callback manager). Defaults to None.
|
||||
Callback handlers are called throughout the lifecycle of a call to a chain,
|
||||
starting with on_chain_start, ending with on_chain_end or on_chain_error.
|
||||
Each custom chain can optionally call additional callback methods, see Callback docs
|
||||
for full details."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Deprecated, use `callbacks` instead."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||
will be printed to the console. Defaults to `langchain.verbose` value."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the chain. Defaults to None
|
||||
These tags will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""Set the chain verbosity.
|
||||
|
||||
Defaults to the global setting if not specified by the user.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the keys expected to be in the chain input."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the keys expected to be in the chain output."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.__call__`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.acall`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prepare chain outputs, and save info about this run to memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of chain inputs, including any inputs added by chain
|
||||
memory.
|
||||
outputs: Dictionary of initial chain outputs.
|
||||
return_only_outputs: Whether to only return the chain outputs. If False,
|
||||
inputs are also added to the final outputs.
|
||||
|
||||
Returns:
|
||||
A dict of the final chain outputs.
|
||||
"""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Validate and prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
return self.output_keys[0]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Convenience method for executing chain when there's a single string output.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this method
|
||||
can only be used for chains that return a single string output. If a Chain
|
||||
has more outputs, a non-string output, or you want to return the inputs/run
|
||||
info along with the outputs, use `Chain.__call__`.
|
||||
|
||||
The other difference is that this method expects inputs to be passed directly in
|
||||
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
|
||||
a single input dictionary with all the inputs.
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output as a string.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
chain.run("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
chain.run(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
# Run at start to make sure this is possible/defined
|
||||
_output_key = self._run_output_key
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Convenience method for executing chain when there's a single string output.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this method
|
||||
can only be used for chains that return a single string output. If a Chain
|
||||
has more outputs, a non-string output, or you want to return the inputs/run
|
||||
info along with the outputs, use `Chain.__call__`.
|
||||
|
||||
The other difference is that this method expects inputs to be passed directly in
|
||||
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
|
||||
a single input dictionary with all the inputs.
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output as a string.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
await chain.arun("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
await chain.arun(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
elif args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (
|
||||
await self.acall(
|
||||
args[0], callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (
|
||||
await self.acall(
|
||||
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
|
||||
method.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the chain.
|
||||
|
||||
Example:
|
||||
..code-block:: python
|
||||
|
||||
chain.dict(exclude_unset=True)
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict(**kwargs)
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
__all__ = ["Chain"]
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
@@ -6,9 +6,9 @@ from typing import Any, Callable, List, Optional, Protocol, Tuple
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
class CombineDocsProtocol(Protocol):
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
)
|
||||
@@ -14,6 +13,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate, format_document
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
)
|
||||
@@ -12,6 +11,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate, format_document
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""Chain for applying constitutional principles to the outputs of another chain."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.constitutional_ai.principles import PRINCIPLES
|
||||
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -9,18 +9,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
@@ -7,10 +7,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.flare.prompts import (
|
||||
PROMPT,
|
||||
QUESTION_GENERATOR_PROMPT,
|
||||
@@ -19,6 +15,8 @@ from langchain.chains.flare.prompts import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
@@ -5,8 +5,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
GREMLIN_GENERATION_PROMPT,
|
||||
@@ -14,6 +12,8 @@ from langchain.chains.graph_qa.prompts import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.hugegraph import HugeGraph
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.kuzu_graph import KuzuGraph
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.nebula_graph import NebulaGraph
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import (
|
||||
SPARQL_GENERATION_SELECT_PROMPT,
|
||||
SPARQL_GENERATION_UPDATE_PROMPT,
|
||||
@@ -18,6 +16,8 @@ from langchain.chains.graph_qa.prompts import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.rdf_graph import RdfGraph
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ from typing import Any, Dict, List, Optional
|
||||
import numpy as np
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -6,14 +6,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import Extra, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
@@ -24,6 +16,14 @@ from langchain.schema import (
|
||||
NoOpOutputParser,
|
||||
PromptValue,
|
||||
)
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.schema import BasePromptTemplate, OutputParserException
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_checker.prompt import (
|
||||
CHECK_ASSERTIONS_PROMPT,
|
||||
@@ -17,6 +15,8 @@ from langchain.chains.llm_checker.prompt import (
|
||||
)
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from typing import Any, Dict, List, Optional
|
||||
import numexpr
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501
|
||||
|
||||
@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||
|
||||
@@ -7,14 +7,14 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_symbolic_math.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
|
||||
class LLMSymbolicMathChain(Chain):
|
||||
|
||||
@@ -7,7 +7,6 @@ import yaml
|
||||
|
||||
from langchain.chains import ReduceDocumentsChain
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
@@ -30,6 +29,7 @@ from langchain.prompts.loading import (
|
||||
load_prompt,
|
||||
load_prompt_from_config,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
|
||||
|
||||
@@ -9,15 +9,15 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.chains import ReduceDocumentsChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions.utils import (
|
||||
_convert_schema,
|
||||
@@ -14,6 +13,7 @@ from langchain.output_parsers.openai_functions import (
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -8,14 +8,14 @@ from openapi_schema_pydantic import Parameter
|
||||
from requests import Response
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import APIOperation
|
||||
from langchain.utilities.openapi import OpenAPISpec
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
|
||||
from langchain.output_parsers.openai_functions import (
|
||||
@@ -8,6 +7,7 @@ from langchain.output_parsers.openai_functions import (
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
@@ -9,12 +9,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains import ReduceDocumentsChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
@@ -27,6 +22,11 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
|
||||
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
@@ -5,13 +5,13 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Load question answering chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains import ReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -18,6 +16,8 @@ from langchain.chains.question_answering import (
|
||||
from langchain.chains.question_answering.map_rerank_prompt import (
|
||||
PROMPT as MAP_RERANK_PROMPT,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import Callbacks
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
|
||||
@@ -8,11 +8,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -20,6 +15,11 @@ from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema.chain import Chain
|
||||
|
||||
|
||||
class Route(NamedTuple):
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE
|
||||
from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
|
||||
from langchain.chains.router.base import MultiRouteChain
|
||||
@@ -15,6 +14,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.schema.chain import Chain
|
||||
|
||||
|
||||
class SequentialChain(Chain):
|
||||
|
||||
@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Chain that runs an arbitrary python function."""
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema.chain import Chain
|
||||
|
||||
|
||||
class TransformChain(Chain):
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.anthropic import _AnthropicCommon
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
|
||||
@@ -8,14 +8,6 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.dump import dumpd, dumps
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
@@ -24,6 +16,14 @@ from langchain.schema import (
|
||||
PromptValue,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain.schema.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.schema.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
|
||||
@@ -7,12 +7,12 @@ from typing import Any, Callable, List, Mapping, Optional
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
|
||||
@@ -22,10 +22,6 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
@@ -36,6 +32,10 @@ from langchain.schema import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user