mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
docstrings update (#12093)
Added missed docstrings. Added missed Args:, Returns: Raises:
This commit is contained in:
parent
ba20c14e28
commit
11f13aed53
@ -31,7 +31,7 @@ from langchain.schema.messages import (
|
||||
async def aenumerate(
|
||||
iterable: AsyncIterator[Any], start: int = 0
|
||||
) -> AsyncIterator[tuple[int, Any]]:
|
||||
"""Async version of enumerate."""
|
||||
"""Async version of enumerate function."""
|
||||
i = start
|
||||
async for x in iterable:
|
||||
yield i, x
|
||||
@ -39,6 +39,14 @@ async def aenumerate(
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dictionary to a LangChain message.
|
||||
|
||||
Args:
|
||||
_dict: The dictionary.
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
"""
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
@ -60,6 +68,14 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
message: The LangChain message.
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
@ -122,6 +138,8 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str
|
||||
|
||||
|
||||
class ChatCompletion:
|
||||
"""Chat completion."""
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def create(
|
||||
@ -217,7 +235,14 @@ def _has_assistant_message(session: ChatSession) -> bool:
|
||||
def convert_messages_for_finetuning(
|
||||
sessions: Iterable[ChatSession],
|
||||
) -> List[List[dict]]:
|
||||
"""Convert messages to a list of lists of dictionaries for fine-tuning."""
|
||||
"""Convert messages to a list of lists of dictionaries for fine-tuning.
|
||||
|
||||
Args:
|
||||
sessions: The chat sessions.
|
||||
|
||||
Returns:
|
||||
The list of lists of dictionaries.
|
||||
"""
|
||||
return [
|
||||
[convert_message_to_dict(s) for s in session["messages"]]
|
||||
for session in sessions
|
||||
|
@ -6,6 +6,14 @@ from langchain.schema.agent import AgentAction
|
||||
def format_xml(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> str:
|
||||
"""Format the intermediate steps as XML.
|
||||
|
||||
Args:
|
||||
intermediate_steps: The intermediate steps.
|
||||
|
||||
Returns:
|
||||
The intermediate steps as XML.
|
||||
"""
|
||||
log = ""
|
||||
for action, observation in intermediate_steps:
|
||||
log += (
|
||||
|
@ -108,6 +108,17 @@ def fix_filter_directive(
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
allowed_attributes: Optional[Sequence[str]] = None,
|
||||
) -> Optional[FilterDirective]:
|
||||
"""Fix invalid filter directive.
|
||||
|
||||
Args:
|
||||
filter: Filter directive to fix.
|
||||
allowed_comparators: allowed comparators. Defaults to all comparators.
|
||||
allowed_operators: allowed operators. Defaults to all operators.
|
||||
allowed_attributes: allowed attributes. Defaults to all attributes.
|
||||
|
||||
Returns:
|
||||
Fixed filter directive.
|
||||
"""
|
||||
if (
|
||||
not (allowed_comparators or allowed_operators or allowed_attributes)
|
||||
) or not filter:
|
||||
@ -154,6 +165,14 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
|
||||
|
||||
|
||||
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
|
||||
"""Construct examples from input-output pairs.
|
||||
|
||||
Args:
|
||||
input_output_pairs: Sequence of input-output pairs.
|
||||
|
||||
Returns:
|
||||
List of examples.
|
||||
"""
|
||||
examples = []
|
||||
for i, (_input, output) in enumerate(input_output_pairs):
|
||||
structured_request = (
|
||||
@ -192,6 +211,9 @@ def get_query_constructor_prompt(
|
||||
schema_prompt: Prompt for describing query schema. Should have string input
|
||||
variables allowed_comparators and allowed_operators.
|
||||
**kwargs: Additional named params to pass to FewShotPromptTemplate init.
|
||||
|
||||
Returns:
|
||||
A prompt template that can be used to construct queries.
|
||||
"""
|
||||
default_schema_prompt = (
|
||||
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
|
||||
|
@ -22,6 +22,17 @@ from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatRes
|
||||
|
||||
|
||||
def get_role(message: BaseMessage) -> str:
|
||||
"""Get the role of the message.
|
||||
|
||||
Args:
|
||||
message: The message.
|
||||
|
||||
Returns:
|
||||
The role of the message.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is of an unknown type.
|
||||
"""
|
||||
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
||||
return "User"
|
||||
elif isinstance(message, AIMessage):
|
||||
@ -38,6 +49,16 @@ def get_cohere_chat_request(
|
||||
connectors: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the request for the Cohere chat API.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
connectors: The connectors.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The request for the Cohere chat API.
|
||||
"""
|
||||
documents = (
|
||||
None
|
||||
if "source_documents" not in kwargs
|
||||
|
@ -49,6 +49,17 @@ _PDF_FILTER_WITHOUT_LOSS = [
|
||||
def extract_from_images_with_rapidocr(
|
||||
images: Sequence[Union[Iterable[np.ndarray], bytes]]
|
||||
) -> str:
|
||||
"""Extract text from images with RapidOCR.
|
||||
|
||||
Args:
|
||||
images: Images to extract text from.
|
||||
|
||||
Returns:
|
||||
Text extracted from images.
|
||||
|
||||
Raises:
|
||||
ImportError: If `rapidocr-onnxruntime` package is not installed.
|
||||
"""
|
||||
try:
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
except ImportError:
|
||||
|
@ -61,7 +61,8 @@ def create_llm_result(
|
||||
|
||||
|
||||
class Anyscale(BaseOpenAI):
|
||||
"""Wrapper around Anyscale Endpoint.
|
||||
"""Anyscale large language models.
|
||||
|
||||
To use, you should have the environment variable ``ANYSCALE_API_BASE`` and
|
||||
``ANYSCALE_API_KEY``set with your Anyscale Endpoint, or pass it as a named
|
||||
parameter to the constructor.
|
||||
|
@ -8,7 +8,8 @@ from langchain.llms.base import LLM
|
||||
|
||||
|
||||
class NIBittensorLLM(LLM):
|
||||
"""
|
||||
"""NIBittensor LLMs
|
||||
|
||||
NIBittensorLLM is created by Neural Internet (https://neuralinternet.ai/),
|
||||
powered by Bittensor, a decentralized network full of different AI models.
|
||||
|
||||
|
@ -18,6 +18,8 @@ from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class TrainResult(TypedDict):
|
||||
"""Train result."""
|
||||
|
||||
loss: float
|
||||
|
||||
|
||||
|
@ -21,8 +21,7 @@ class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
|
||||
|
||||
|
||||
class JavelinAIGateway(LLM):
|
||||
"""
|
||||
Wrapper around completions LLMs in the Javelin AI Gateway.
|
||||
"""Javelin AI Gateway LLMs.
|
||||
|
||||
To use, you should have the ``javelin_sdk`` python package installed.
|
||||
For more information, see https://docs.getjavelin.io
|
||||
|
@ -83,7 +83,16 @@ DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
|
||||
def check_valid_template(
|
||||
template: str, template_format: str, input_variables: List[str]
|
||||
) -> None:
|
||||
"""Check that template string is valid."""
|
||||
"""Check that template string is valid.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
template_format: The template format. Should be one of "f-string" or "jinja2".
|
||||
input_variables: The input variables.
|
||||
|
||||
Raises:
|
||||
ValueError: If the template format is not supported.
|
||||
"""
|
||||
if template_format not in DEFAULT_FORMATTER_MAPPING:
|
||||
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
|
||||
raise ValueError(
|
||||
@ -101,6 +110,18 @@ def check_valid_template(
|
||||
|
||||
|
||||
def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
"""Get the variables from the template.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
template_format: The template format. Should be one of "f-string" or "jinja2".
|
||||
|
||||
Returns:
|
||||
The variables from the template.
|
||||
|
||||
Raises:
|
||||
ValueError: If the template format is not supported.
|
||||
"""
|
||||
if template_format == "jinja2":
|
||||
# Get the variables for the template
|
||||
input_variables = _get_jinja2_variables_from_template(template)
|
||||
|
@ -8,6 +8,8 @@ from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
|
||||
class SearchDepth(Enum):
|
||||
"""Search depth as enumerator."""
|
||||
|
||||
BASIC = "basic"
|
||||
ADVANCED = "advanced"
|
||||
|
||||
@ -31,7 +33,7 @@ class TavilySearchAPIRetriever(BaseRetriever):
|
||||
try:
|
||||
from tavily import Client
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Tavily python package not found. "
|
||||
"Please install it with `pip install tavily-python`."
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""LangChain Runnables and the LangChain Expression Language (LCEL).
|
||||
"""LangChain **Runnable** and the **LangChain Expression Language (LCEL)**.
|
||||
|
||||
The LangChain Expression Language (LCEL) offers a declarative method to build
|
||||
production-grade programs that harness the power of LLMs.
|
||||
@ -6,10 +6,10 @@ production-grade programs that harness the power of LLMs.
|
||||
Programs created using LCEL and LangChain Runnables inherently support
|
||||
synchronous, asynchronous, batch, and streaming operations.
|
||||
|
||||
Support for async allows servers hosting LCEL based programs to scale better
|
||||
Support for **async** allows servers hosting LCEL based programs to scale better
|
||||
for higher concurrent loads.
|
||||
|
||||
Streaming of intermediate outputs as they're being generated allows for
|
||||
**Streaming** of intermediate outputs as they're being generated allows for
|
||||
creating more responsive UX.
|
||||
|
||||
This module contains schema and implementation of LangChain Runnables primitives.
|
||||
|
@ -2627,6 +2627,14 @@ RunnableLike = Union[
|
||||
|
||||
|
||||
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
||||
"""Coerce a runnable-like object into a Runnable.
|
||||
|
||||
Args:
|
||||
thing: A runnable-like object.
|
||||
|
||||
Returns:
|
||||
A Runnable.
|
||||
"""
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
|
||||
|
@ -35,6 +35,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class EmptyDict(TypedDict, total=False):
|
||||
"""Empty dict type."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -85,6 +87,15 @@ class RunnableConfig(TypedDict, total=False):
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
"""Ensure that a config is a dict with all keys present.
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig], optional): The config to ensure.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The ensured config.
|
||||
"""
|
||||
empty = RunnableConfig(
|
||||
tags=[],
|
||||
metadata={},
|
||||
@ -101,9 +112,21 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
def get_config_list(
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
"""
|
||||
Helper method to get a list of configs from a single config or a list of
|
||||
configs, useful for subclasses overriding batch() or abatch().
|
||||
"""Get a list of configs from a single config or a list of configs.
|
||||
|
||||
It is useful for subclasses overriding batch() or abatch().
|
||||
|
||||
Args:
|
||||
config (Optional[Union[RunnableConfig, List[RunnableConfig]]]):
|
||||
The config or list of configs.
|
||||
length (int): The length of the list.
|
||||
|
||||
Returns:
|
||||
List[RunnableConfig]: The list of configs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the length of the list is not equal to the length of the inputs.
|
||||
|
||||
"""
|
||||
if length < 0:
|
||||
raise ValueError(f"length must be >= 0, but got {length}")
|
||||
@ -129,9 +152,27 @@ def patch_config(
|
||||
run_name: Optional[str] = None,
|
||||
configurable: Optional[Dict[str, Any]] = None,
|
||||
) -> RunnableConfig:
|
||||
"""Patch a config with new values.
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig]): The config to patch.
|
||||
copy_locals (bool, optional): Whether to copy locals. Defaults to False.
|
||||
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
|
||||
Defaults to None.
|
||||
recursion_limit (Optional[int], optional): The recursion limit to set.
|
||||
Defaults to None.
|
||||
max_concurrency (Optional[int], optional): The max concurrency to set.
|
||||
Defaults to None.
|
||||
run_name (Optional[str], optional): The run name to set. Defaults to None.
|
||||
configurable (Optional[Dict[str, Any]], optional): The configurable to set.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The patched config.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
if callbacks is not None:
|
||||
# If we're replacing callbacks we need to unset run_name
|
||||
# If we're replacing callbacks, we need to unset run_name
|
||||
# As that should apply only to the same run as the original callbacks
|
||||
config["callbacks"] = callbacks
|
||||
if "run_name" in config:
|
||||
@ -148,9 +189,17 @@ def patch_config(
|
||||
|
||||
|
||||
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
"""Merge multiple configs into one.
|
||||
|
||||
Args:
|
||||
*configs (Optional[RunnableConfig]): The configs to merge.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The merged config.
|
||||
"""
|
||||
base: RunnableConfig = {}
|
||||
# Even though the keys aren't literals this is correct
|
||||
# because both dicts are same type
|
||||
# Even though the keys aren't literals, this is correct
|
||||
# because both dicts are the same type
|
||||
for config in (c for c in configs if c is not None):
|
||||
for key in config:
|
||||
if key == "metadata":
|
||||
@ -184,7 +233,22 @@ def call_func_with_variable_args(
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
"""Call function that may optionally accept a run_manager and/or config.
|
||||
|
||||
Args:
|
||||
func (Union[Callable[[Input], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
|
||||
The function to call.
|
||||
input (Input): The input to the function.
|
||||
run_manager (CallbackManagerForChainRun): The run manager to
|
||||
pass to the function.
|
||||
config (RunnableConfig): The config to pass to the function.
|
||||
**kwargs (Any): The keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
Output: The output of the function.
|
||||
"""
|
||||
if accepts_config(func):
|
||||
if run_manager is not None:
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -210,7 +274,22 @@ async def acall_func_with_variable_args(
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
"""Call function that may optionally accept a run_manager and/or config.
|
||||
|
||||
Args:
|
||||
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
|
||||
AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input,
|
||||
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
|
||||
The function to call.
|
||||
input (Input): The input to the function.
|
||||
run_manager (AsyncCallbackManagerForChainRun): The run manager
|
||||
to pass to the function.
|
||||
config (RunnableConfig): The config to pass to the function.
|
||||
**kwargs (Any): The keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
Output: The output of the function.
|
||||
"""
|
||||
if accepts_config(func):
|
||||
if run_manager is not None:
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -222,6 +301,14 @@ async def acall_func_with_variable_args(
|
||||
|
||||
|
||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
"""Get a callback manager for a config.
|
||||
|
||||
Args:
|
||||
config (RunnableConfig): The config.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The callback manager.
|
||||
"""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
return CallbackManager.configure(
|
||||
@ -234,6 +321,14 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
def get_async_callback_manager_for_config(
|
||||
config: RunnableConfig,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Get an async callback manager for a config.
|
||||
|
||||
Args:
|
||||
config (RunnableConfig): The config.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The async callback manager.
|
||||
"""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
return AsyncCallbackManager.configure(
|
||||
@ -245,5 +340,13 @@ def get_async_callback_manager_for_config(
|
||||
|
||||
@contextmanager
|
||||
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
|
||||
"""Get an executor for a config.
|
||||
|
||||
Args:
|
||||
config (RunnableConfig): The config.
|
||||
|
||||
Yields:
|
||||
Generator[Executor, None, None]: The executor.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
|
||||
yield executor
|
||||
|
@ -7,6 +7,7 @@ def sanitize(
|
||||
"""
|
||||
Sanitize input string or dict of strings by replacing sensitive data with
|
||||
placeholders.
|
||||
|
||||
It returns the sanitized input string or dict of strings and the secure
|
||||
context as a dict following the format:
|
||||
{
|
||||
@ -29,6 +30,10 @@ def sanitize(
|
||||
}
|
||||
|
||||
The `secure_context` needs to be passed to the `desanitize` function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a string or dict of strings.
|
||||
ImportError: If the `opaqueprompts` Python package is not installed.
|
||||
"""
|
||||
try:
|
||||
import opaqueprompts as op
|
||||
|
Loading…
Reference in New Issue
Block a user