Compare commits

...

7 Commits

Author SHA1 Message Date
Nuno Campos
d29bda6e63 Allow 3rd party packages to register callback modifiers 2023-07-09 21:35:38 +01:00
Nuno Campos
1e34b827ca Separate CallbackManager configure logic from loading handlers from context/env 2023-07-09 21:21:46 +01:00
Harrison Chase
b7d661d01d move to core 2023-07-09 12:14:47 -04:00
Harrison Chase
5c74c93c61 move chain to schema 2023-07-09 11:38:24 -04:00
Harrison Chase
38dc42b242 cr 2023-07-09 11:33:47 -04:00
Harrison Chase
73f808ad68 move callbacks to schema 2023-07-09 11:21:41 -04:00
Harrison Chase
b0f8c287c3 move callbacks base to schema 2023-07-09 09:51:51 -04:00
311 changed files with 2442 additions and 2346 deletions

View File

@@ -5,6 +5,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks.context import add_handlers_from_langchain_context
from langchain.chains import (
ConversationChain,
LLMBashChain,
@@ -42,6 +43,7 @@ from langchain.prompts import (
Prompt,
PromptTemplate,
)
from langchain.schema.callbacks.manager import callback_modifiers
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.sql_database import SQLDatabase
from langchain.utilities.arxiv import ArxivAPIWrapper
@@ -54,6 +56,8 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain.vectorstores import FAISS, ElasticVectorSearch
callback_modifiers.append(add_handlers_from_langchain_context)
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:

View File

@@ -14,15 +14,6 @@ from pydantic import BaseModel, root_validator
from langchain.agents.agent_types import AgentType
from langchain.agents.tools import InvalidTool
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForChainRun,
CallbackManagerForToolRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.prompts.few_shot import FewShotPromptTemplate
@@ -34,6 +25,15 @@ from langchain.schema import (
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForChainRun,
CallbackManagerForToolRun,
Callbacks,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.tools.base import BaseTool

View File

@@ -6,8 +6,8 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -9,8 +9,8 @@ from langchain.agents.agent_toolkits.openapi.prompt import (
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool

View File

@@ -16,9 +16,9 @@ from langchain.agents.agent_toolkits.pandas.prompt import (
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain.tools.python.tool import PythonAstREPLTool

View File

@@ -9,8 +9,8 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities.powerbi import PowerBIDataset

View File

@@ -9,10 +9,10 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
)
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.utilities.powerbi import PowerBIDataset

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.prompts import PromptTemplate
@@ -13,6 +12,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.powerbi.prompt import (

View File

@@ -7,8 +7,8 @@ from langchain.agents.agent_toolkits.python.prompt import PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain.tools.python.tool import PythonREPLTool

View File

@@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.tools.python.tool import PythonAstREPLTool

View File

@@ -6,8 +6,8 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -12,13 +12,13 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage

View File

@@ -8,8 +8,8 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import (
VectorStoreToolkit,
)
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -11,7 +11,6 @@ from langchain.agents.chat.prompt import (
SYSTEM_MESSAGE_SUFFIX,
)
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
@@ -19,6 +18,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BasePromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -10,9 +10,9 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
TEMPLATE_TOOL_RESPONSE,
)
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
@@ -22,6 +21,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.tools.base import BaseTool

View File

@@ -4,7 +4,7 @@ from typing import Any, Optional, Sequence
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -6,8 +6,8 @@ from mypy_extensions import Arg, KwArg
from langchain.agents.tools import Tool
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import Callbacks
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain

View File

@@ -11,9 +11,9 @@ from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -7,8 +7,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from pydantic import root_validator
from langchain.agents import BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
@@ -22,6 +20,8 @@ from langchain.schema import (
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import Callbacks
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import (
AIMessage,

View File

@@ -7,8 +7,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from pydantic import root_validator
from langchain.agents import BaseMultiActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
@@ -22,6 +20,8 @@ from langchain.schema import (
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import Callbacks
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import (
AIMessage,

View File

@@ -8,7 +8,6 @@ from langchain.agents.structured_chat.output_parser import (
StructuredChatOutputParserWithRetries,
)
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
@@ -16,6 +15,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BasePromptTemplate
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

View File

@@ -1,7 +1,7 @@
"""Interface for tools."""
from typing import Optional
from langchain.callbacks.manager import (
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)

View File

@@ -6,16 +6,16 @@ from langchain.callbacks.arize_callback import ArizeCallbackHandler
from langchain.callbacks.arthur_callback import ArthurCallbackHandler
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.context import (
get_openai_callback,
tracing_enabled,
wandb_tracing_enabled,
)
from langchain.callbacks.context_callback import ContextCallbackHandler
from langchain.callbacks.file import FileCallbackHandler
from langchain.callbacks.flyte_callback import FlyteCallbackHandler
from langchain.callbacks.human import HumanApprovalCallbackHandler
from langchain.callbacks.infino_callback import InfinoCallbackHandler
from langchain.callbacks.manager import (
get_openai_callback,
tracing_enabled,
wandb_tracing_enabled,
)
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.promptlayer_callback import PromptLayerCallbackHandler

View File

@@ -1,8 +1,8 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
def import_aim() -> Any:

View File

@@ -2,8 +2,8 @@ import os
import warnings
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
class ArgillaCallbackHandler(BaseCallbackHandler):

View File

@@ -1,9 +1,9 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
class ArizeCallbackHandler(BaseCallbackHandler):

View File

@@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union
import numpy as np
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
import arthurai

View File

@@ -3,7 +3,6 @@ from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
load_json,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
def import_clearml() -> Any:

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import langchain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
@@ -13,6 +12,7 @@ from langchain.callbacks.utils import (
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
LANGCHAIN_MODEL_NAME = "langchain-model"

View File

@@ -0,0 +1,349 @@
from __future__ import annotations
import logging
import os
import warnings
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import (
AsyncGenerator,
Generator,
List,
Optional,
Union,
cast,
)
from uuid import UUID
import langchain
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema.callbacks.manager import (
AsyncCallbackManager,
CallbackManager,
)
logger = logging.getLogger(__name__)
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[
Optional[LangChainTracerV1]
] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
wandb_tracing_callback_var: ContextVar[
Optional[WandbTracer]
] = ContextVar( # noqa: E501
"tracing_wandb_callback", default=None
)
tracing_v2_callback_var: ContextVar[
Optional[LangChainTracer]
] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
def _get_debug() -> bool:
return langchain.debug
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get the OpenAI callback handler in a context manager.
which conveniently exposes token and cost information.
Returns:
OpenAICallbackHandler: The OpenAI callback handler.
Example:
>>> with get_openai_callback() as cb:
... # Use the OpenAI callback handler
"""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
openai_callback_var.set(None)
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get the Deprecated LangChainTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
TracerSessionV1: The LangChainTracer session.
Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
) -> Generator[None, None, None]:
"""Get the WandbTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
None
Example:
>>> with wandb_tracing_enabled() as session:
... # Use the WandbTracer session
"""
cb = WandbTracer()
wandb_tracing_callback_var.set(cb)
yield None
wandb_tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[None, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
Args:
project_name (str, optional): The name of the project.
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns:
None
Example:
>>> with tracing_v2_enabled():
... # LangChain code will automatically be traced
"""
# Issue a warning that this is experimental
warnings.warn(
"The tracing v2 API is in development. "
"This is not yet stable and may change in the future."
)
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracer(
example_id=example_id,
project_name=project_name,
tags=tags,
)
tracing_v2_callback_var.set(cb)
yield
tracing_v2_callback_var.set(None)
@contextmanager
def trace_as_chain_group(
group_name: str,
*,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[CallbackManager, None, None]:
"""Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
Returns:
CallbackManager: The callback manager for the chain group.
Example:
>>> with trace_as_chain_group("group_name") as manager:
... # Use the callback manager for the chain group
... llm.predict("Foo", callbacks=manager)
"""
cb = LangChainTracer(
project_name=project_name,
example_id=example_id,
)
cm = CallbackManager.configure(
inheritable_callbacks=[cb],
inheritable_tags=tags,
)
run_manager = cm.on_chain_start({"name": group_name}, {})
yield run_manager.get_child()
run_manager.on_chain_end({})
@asynccontextmanager
async def atrace_as_chain_group(
group_name: str,
*,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
"""Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
Example:
>>> async with atrace_as_chain_group("group_name") as manager:
... # Use the async callback manager for the chain group
... await llm.apredict("Foo", callbacks=manager)
"""
cb = LangChainTracer(
project_name=project_name,
example_id=example_id,
)
cm = AsyncCallbackManager.configure(
inheritable_callbacks=[cb], inheritable_tags=tags
)
run_manager = await cm.on_chain_start({"name": group_name}, {})
try:
yield run_manager.get_child()
finally:
await run_manager.on_chain_end({})
def env_var_is_set(env_var: str) -> bool:
"""Check if an environment variable is set.
Args:
env_var (str): The name of the environment variable.
Returns:
bool: True if the environment variable is set, False otherwise.
"""
return env_var in os.environ and os.environ[env_var] not in (
"",
"0",
"false",
"False",
)
def add_handlers_from_langchain_context(
callback_manager: Union[CallbackManager, AsyncCallbackManager], verbose: bool
) -> None:
tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get()
open_ai = openai_callback_var.get()
tracing_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING")
or tracer is not None
or env_var_is_set("LANGCHAIN_HANDLER")
)
wandb_tracing_enabled_ = (
env_var_is_set("LANGCHAIN_WANDB_TRACING") or wandb_tracer is not None
)
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING_V2") or tracer_v2 is not None
)
tracer_project = os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
)
debug = _get_debug()
if (
verbose
or debug
or tracing_enabled_
or tracing_v2_enabled_
or wandb_tracing_enabled_
or open_ai is not None
):
if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers
):
if debug:
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
if debug and not any(
isinstance(handler, ConsoleCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracerV1)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracerV1()
handler.load_session(tracer_project)
callback_manager.add_handler(handler, True)
if wandb_tracing_enabled_ and not any(
isinstance(handler, WandbTracer) for handler in callback_manager.handlers
):
if wandb_tracer:
callback_manager.add_handler(wandb_tracer, True)
else:
callback_manager.add_handler(WandbTracer(), True)
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
else:
try:
callback_manager.add_handler(
LangChainTracer(project_name=tracer_project), True
)
except Exception as e:
logger.warning(
"Unable to load requested LangChainTracer."
" To disable this warning,"
" unset the LANGCHAIN_TRACING_V2 environment variables.",
e,
)
if open_ai is not None and not any(
isinstance(handler, OpenAICallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(open_ai, True)

View File

@@ -3,11 +3,11 @@ import os
from typing import Any, Dict, List
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import (
BaseMessage,
LLMResult,
)
from langchain.schema.callbacks.base import BaseCallbackHandler
def import_context() -> Any:

View File

@@ -1,9 +1,9 @@
"""Callback Handler that writes to a file."""
from typing import Any, Dict, Optional, TextIO, cast
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
from langchain.schema.callbacks.base import BaseCallbackHandler
class FileCallbackHandler(BaseCallbackHandler):

View File

@@ -5,7 +5,6 @@ import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
import flytekit

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Optional
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.callbacks.base import BaseCallbackHandler
def _default_approve(_input: str) -> bool:

View File

@@ -1,8 +1,8 @@
import time
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
def import_infino() -> Any:

View File

@@ -6,7 +6,6 @@ from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
@@ -16,6 +15,7 @@ from langchain.callbacks.utils import (
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
from langchain.utils import get_from_dict_or_env

View File

@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
MODEL_COST_PER_1K_TOKENS = {
# GPT-4 input

View File

@@ -5,11 +5,11 @@ import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import (
ChatGeneration,
LLMResult,
)
from langchain.schema.callbacks.base import BaseCallbackHandler
from langchain.schema.messages import (
AIMessage,
BaseMessage,

View File

@@ -1,9 +1,9 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
class StdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult
from langchain.schema.callbacks.base import AsyncCallbackHandler
# TODO If used by two LLM runs in parallel this won't work as expected

View File

@@ -2,8 +2,8 @@
import sys
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
class StreamingStdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -2,13 +2,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import (
LLMThoughtLabeler as LLMThoughtLabeler,
)
from langchain.callbacks.streamlit.streamlit_callback_handler import (
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
)
from langchain.schema.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator

View File

@@ -5,9 +5,9 @@ from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator

View File

@@ -7,9 +7,9 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.load.dump import dumpd
from langchain.schema.callbacks.base import BaseCallbackHandler
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.context import tracing_v2_enabled
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run

View File

@@ -4,7 +4,6 @@ from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
@@ -14,6 +13,7 @@ from langchain.callbacks.utils import (
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
def import_wandb() -> Any:

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
from langchain.utils import get_from_env
if TYPE_CHECKING:

View File

@@ -5,15 +5,15 @@ from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.requests import TextRequestsWrapper
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -7,12 +7,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
from pydantic import BaseModel, Field
from requests import Response
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.requests import Requests
from langchain.schema.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.openapi.utils.api_models import APIOperation

View File

@@ -1,598 +1,6 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
"""Purely for backwards compatibility.
import yaml
from pydantic import Field, root_validator, validator
Chain used to be defined here before moving to schema."""
from langchain.schema.chain import Chain
import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
return langchain.verbose
class Chain(Serializable, ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
The main methods exposed by chains are:
- `__call__`: Chains are callable. The `__call__` method is the primary way to
execute a Chain. This takes inputs as a dictionary and returns a
dictionary output.
- `run`: A convenience method that takes inputs as args/kwargs and returns the
output as a string. This method can only be used for a subset of chains and
cannot return as rich of an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated, use `callbacks` instead."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to `langchain.verbose` value."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""Set the chain verbosity.
Defaults to the global setting if not specified by the user.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Return the keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Return the keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.__call__`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.acall`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
inputs: Dictionary of chain inputs, including any inputs added by chain
memory.
outputs: Dictionary of initial chain outputs.
return_only_outputs: Whether to only return the chain outputs. If False,
inputs are also added to the final outputs.
Returns:
A dict of the final chain outputs.
"""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Validate and prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
return self.output_keys[0]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Convenience method for executing chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
has more outputs, a non-string output, or you want to return the inputs/run
info along with the outputs, use `Chain.__call__`.
The other difference is that this method expects inputs to be passed directly in
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
a single input dictionary with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output as a string.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
chain.run("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
chain.run(question=question, context=context)
# -> "The temperature in Boise is..."
"""
# Run at start to make sure this is possible/defined
_output_key = self._run_output_key
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
else:
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Convenience method for executing chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
has more outputs, a non-string output, or you want to return the inputs/run
info along with the outputs, use `Chain.__call__`.
The other difference is that this method expects inputs to be passed directly in
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
a single input dictionary with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output as a string.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
await chain.arun("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
await chain.arun(question=question, context=context)
# -> "The temperature in Boise is..."
"""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
if kwargs and not args:
return (
await self.acall(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
method.
Returns:
A dictionary representation of the chain.
Example:
..code-block:: python
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
__all__ = ["Chain"]

View File

@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
from langchain.callbacks.manager import (
from langchain.docstore.document import Document
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.schema.chain import Chain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

View File

@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.schema.callbacks.manager import Callbacks
class MapReduceDocumentsChain(BaseCombineDocumentsChain):

View File

@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.schema.callbacks.manager import Callbacks
class MapRerankDocumentsChain(BaseCombineDocumentsChain):

View File

@@ -6,9 +6,9 @@ from typing import Any, Callable, List, Optional, Protocol, Tuple
from pydantic import Extra
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.docstore.document import Document
from langchain.schema.callbacks.manager import Callbacks
class CombineDocsProtocol(Protocol):

View File

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Tuple
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
)
@@ -14,6 +13,7 @@ from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate, format_document
from langchain.schema.callbacks.manager import Callbacks
def _get_default_document_prompt() -> PromptTemplate:

View File

@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
)
@@ -12,6 +11,7 @@ from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate, format_document
from langchain.schema.callbacks.manager import Callbacks
def _get_default_document_prompt() -> PromptTemplate:

View File

@@ -1,13 +1,13 @@
"""Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.principles import PRINCIPLES
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -9,18 +9,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.vectorstores.base import VectorStore

View File

@@ -7,10 +7,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from pydantic import Field
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
QUESTION_GENERATOR_PROMPT,
@@ -19,6 +15,8 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.neo4j_graph import Neo4jGraph
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

View File

@@ -5,8 +5,6 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GREMLIN_GENERATION_PROMPT,
@@ -14,6 +12,8 @@ from langchain.chains.graph_qa.prompts import (
from langchain.chains.llm import LLMChain
from langchain.graphs.hugegraph import HugeGraph
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.kuzu_graph import KuzuGraph
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.nebula_graph import NebulaGraph
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
SPARQL_GENERATION_SELECT_PROMPT,
SPARQL_GENERATION_UPDATE_PROMPT,
@@ -18,6 +16,8 @@ from langchain.chains.graph_qa.prompts import (
from langchain.chains.llm import LLMChain
from langchain.graphs.rdf_graph import RdfGraph
from langchain.prompts.base import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -9,11 +9,11 @@ from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -1,340 +1,3 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
from langchain.core.chains.llm import LLMChain
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import Extra, Field
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLLMOutputParser,
BasePromptTemplate,
LLMResult,
NoOpOutputParser,
PromptValue,
)
from langchain.schema.language_model import BaseLanguageModel
class LLMChain(Chain):
"""Chain to run queries against LLMs.
Example:
.. code-block:: python
from langchain import LLMChain, OpenAI, PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
"""Language model to call."""
output_key: str = "text" #: :meta private:
output_parser: BaseLLMOutputParser = Field(default_factory=NoOpOutputParser)
"""Output parser to use.
Defaults to one that takes the most likely string but does not change it
otherwise."""
return_final_only: bool = True
"""Whether to return only the final parsed result. Defaults to True.
If false, will return a bunch of extra information about the generation."""
llm_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
async def agenerate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
return await self.llm.agenerate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
async def aprep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = await self.agenerate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
await run_manager.on_chain_end({"outputs": outputs})
return outputs
@property
def _run_output_key(self) -> str:
return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_generation(result)
def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return generation
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_generation(result)
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
__all__ = ["LLMChain"]

View File

@@ -7,11 +7,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities.bash import BashProcess

View File

@@ -6,8 +6,6 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.prompt import (
CHECK_ASSERTIONS_PROMPT,
@@ -17,6 +15,8 @@ from langchain.chains.llm_checker.prompt import (
)
from langchain.chains.sequential import SequentialChain
from langchain.prompts import PromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -9,14 +9,14 @@ from typing import Any, Dict, List, Optional
import numexpr
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -5,10 +5,10 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.requests import TextRequestsWrapper
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501

View File

@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
PROMPTS_DIR = Path(__file__).parent / "prompts"

View File

@@ -7,7 +7,6 @@ import yaml
from langchain.chains import ReduceDocumentsChain
from langchain.chains.api.base import APIChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -30,6 +29,7 @@ from langchain.prompts.loading import (
load_prompt,
load_prompt_from_config,
)
from langchain.schema.chain import Chain
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"

View File

@@ -9,15 +9,15 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.text_splitter import TextSplitter

View File

@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.utils import get_from_dict_or_env

View File

@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT
from langchain.llms.openai import OpenAI
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -2,7 +2,6 @@ from typing import Any, List
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import (
_convert_schema,
@@ -14,6 +13,7 @@ from langchain.output_parsers.openai_functions import (
PydanticAttrOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -8,14 +8,14 @@ from openapi_schema_pydantic import Parameter
from requests import Response
from langchain import LLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec

View File

@@ -1,6 +1,5 @@
from typing import Any
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
from langchain.output_parsers.openai_functions import (
@@ -8,6 +7,7 @@ from langchain.output_parsers.openai_functions import (
PydanticOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities import PythonREPL

View File

@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

View File

@@ -9,12 +9,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@@ -27,6 +22,11 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
)
from langchain.docstore.document import Document
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -4,14 +4,14 @@ from typing import Any, Dict, List
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):

View File

@@ -5,13 +5,13 @@ from typing import Any, Dict, List
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.vectorstores.base import VectorStore

View File

@@ -1,8 +1,6 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -18,6 +16,8 @@ from langchain.chains.question_answering import (
from langchain.chains.question_answering.map_rerank_prompt import (
PROMPT as MAP_RERANK_PROMPT,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import Callbacks
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.prompt_template import BasePromptTemplate

View File

@@ -8,11 +8,6 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
@@ -20,6 +15,11 @@ from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever, Document
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores.base import VectorStore

View File

@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from pydantic import Extra
from langchain.callbacks.manager import (
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.schema.chain import Chain
class Route(NamedTuple):

View File

@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.router.base import RouterChain
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.vectorstores.base import VectorStore

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional, Type, cast
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE
from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
from langchain.chains.router.base import MultiRouteChain
@@ -15,6 +14,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel

View File

@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
from langchain.input import get_color_mapping
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.input import get_color_mapping
from langchain.schema.chain import Chain
class SequentialChain(Chain):

View File

@@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.sql_database import SQLDatabase
from langchain.tools.sql_database.prompt import QUERY_CHECKER

View File

@@ -1,8 +1,8 @@
"""Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema.callbacks.manager import CallbackManagerForChainRun
from langchain.schema.chain import Chain
class TransformChain(Chain):

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.schema.messages import (
AIMessage,
BaseMessage,

View File

@@ -8,14 +8,6 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence
from pydantic import Field, root_validator
import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd, dumps
from langchain.schema import (
ChatGeneration,
@@ -24,6 +16,14 @@ from langchain.schema import (
PromptValue,
RunInfo,
)
from langchain.schema.callbacks.base import BaseCallbackManager
from langchain.schema.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage

View File

@@ -1,8 +1,8 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema.messages import BaseMessage

View File

@@ -13,15 +13,15 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.schema.messages import (
AIMessage,
BaseMessage,

View File

@@ -7,12 +7,12 @@ from typing import Any, Callable, List, Mapping, Optional
import yaml
from pydantic import Field
from langchain.callbacks.manager import (
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.messages import (
BaseMessage,
HumanMessage,

Some files were not shown because too many files have changed in this diff Show More