mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-13 14:21:27 +00:00
Compare commits
2 Commits
langchain-
...
nc/19jun/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e69f2396aa | ||
|
|
e59f800cea |
@@ -1,402 +0,0 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
# Map from a per-key Event to the value set for that key. The Event object
# doubles as the dict key so getters can wait on it before reading.
Values = Dict[Union[asyncio.Event, threading.Event], Any]
# Config ids for context getters/setters look like "__context__/<key>/get".
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
|
||||
|
||||
|
||||
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
    """Store ``value`` under ``done`` and wake async waiters; return ``value``.

    The value is recorded before the event is set so waiters observe it
    as soon as they wake.
    """
    values[done] = value
    done.set()
    return value
|
||||
|
||||
|
||||
async def _agetter(done: asyncio.Event, values: Values) -> Any:
    """Asynchronously wait until ``done`` fires, then read its stored value."""
    await done.wait()
    return values[done]
|
||||
|
||||
|
||||
def _setter(done: threading.Event, values: Values, value: T) -> T:
    """Thread-based twin of ``_asetter``: record the value, then signal."""
    values[done] = value
    done.set()
    return value
|
||||
|
||||
|
||||
def _getter(done: threading.Event, values: Values) -> Any:
    """Thread-based twin of ``_agetter``: wait for the signal, then read."""
    done.wait()
    return values[done]
|
||||
|
||||
|
||||
def _key_from_id(id_: str) -> str:
    """Extract the context key from a config id.

    Ids have the form ``"__context__/<key>/get"`` or ``"__context__/<key>/set"``.

    Args:
        id_: The config id to parse.

    Returns:
        The bare context key (everything between prefix and suffix).

    Raises:
        ValueError: If the id lacks the context prefix or a known suffix.
            (Previously a missing prefix leaked a bare ``IndexError``.)
    """
    _, sep, wout_prefix = id_.partition(CONTEXT_CONFIG_PREFIX)
    if not sep:
        # No "__context__/" prefix at all: reject with the same error as
        # an unknown suffix, instead of an IndexError from split()[1].
        raise ValueError(f"Invalid context config id {id_}")
    if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
        return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
    elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
        return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
    else:
        raise ValueError(f"Invalid context config id {id_}")
|
||||
|
||||
|
||||
def _config_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
    setter: Callable,
    getter: Callable,
    event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
) -> RunnableConfig:
    """Patch ``config`` with callables implementing context get/set.

    For every context key declared by ``steps`` this validates usage
    (exactly one setter per key, no mutually-dependent keys, getter/setter
    ordering) and installs zero-arg getter and one-arg setter callables
    backed by a shared Event + value store.

    Args:
        config: The runnable config to patch.
        steps: The runnables whose ``config_specs`` are scanned.
        setter: Store-and-signal callable (sync or async flavor).
        getter: Wait-and-read callable matching ``setter``'s flavor.
        event_cls: Event class matching the setter/getter flavor.

    Returns:
        The patched config, or ``config`` unchanged if it already
        contains context entries.
    """
    # Already patched (e.g. by an enclosing runnable): don't patch twice.
    if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
        return config

    # All context specs, each paired with the index of the declaring step.
    context_specs = [
        (spec, i)
        for i, step in enumerate(steps)
        for spec in step.config_specs
        if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
    ]
    # Group specs by context key. Sorting by the full id keeps specs for
    # the same key adjacent, which groupby requires.
    grouped_by_key = {
        key: list(group)
        for key, group in groupby(
            sorted(context_specs, key=lambda s: s[0].id),
            key=lambda s: _key_from_id(s[0].id),
        )
    }
    # For each key, the set of other context keys its specs depend on.
    deps_by_key = {
        key: set(
            _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
        )
        for key, group in grouped_by_key.items()
    }

    values: Values = {}
    # One Event per context key, created lazily on first access.
    events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
        event_cls
    )
    context_funcs: Dict[str, Callable[[], Any]] = {}
    for key, group in grouped_by_key.items():
        getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
        setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]

        # Two keys depending on each other would each wait on the other's
        # Event forever.
        for dep in deps_by_key[key]:
            if key in deps_by_key[dep]:
                raise ValueError(
                    f"Deadlock detected between context keys {key} and {dep}"
                )
        if len(setters) != 1:
            raise ValueError(f"Expected exactly one setter for context key {key}")
        setter_idx = setters[0][1]
        # NOTE(review): the check rejects any getter whose step index is
        # *before* the setter's, yet the message says "after" — the wording
        # looks inverted relative to the condition; confirm intended text.
        if any(getter_idx < setter_idx for _, getter_idx in getters):
            raise ValueError(
                f"Context setter for key {key} must be defined after all getters."
            )

        if getters:
            # All getter ids for one key are identical, so one entry suffices.
            context_funcs[getters[0][0].id] = partial(getter, events[key], values)
        context_funcs[setters[0][0].id] = partial(setter, events[key], values)

    return patch_config(config, configurable=context_funcs)
|
||||
|
||||
|
||||
def aconfig_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
) -> RunnableConfig:
    """Asynchronously patch a runnable config with context getters and setters.

    Uses ``asyncio.Event`` plus the async setter/getter helpers, so the
    installed callables must be awaited.

    Args:
        config: The runnable config.
        steps: The runnable steps.

    Returns:
        The patched runnable config.
    """
    return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
|
||||
|
||||
|
||||
def config_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
) -> RunnableConfig:
    """Patch a runnable config with context getters and setters.

    Synchronous flavor: uses ``threading.Event`` plus the blocking
    setter/getter helpers.

    Args:
        config: The runnable config.
        steps: The runnable steps.

    Returns:
        The patched runnable config.
    """
    return _config_with_context(config, steps, _setter, _getter, threading.Event)
|
||||
|
||||
|
||||
@beta()
class ContextGet(RunnableSerializable):
    """Get a context value."""

    # Optional scope prefix, inserted between the context prefix and the key.
    prefix: str = ""

    # A single context key, or a list of keys fetched together as a dict.
    key: Union[str, List[str]]

    def __str__(self) -> str:
        return f"ContextGet({_print_keys(self.key)})"

    @property
    def ids(self) -> List[str]:
        # One "__context__/<prefix><key>/get" config id per requested key.
        prefix = self.prefix + "/" if self.prefix else ""
        keys = self.key if isinstance(self.key, list) else [self.key]
        return [
            f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
            for k in keys
        ]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        # Advertise one configurable (a zero-arg getter callable) per id.
        return super().config_specs + [
            ConfigurableFieldSpec(
                id=id_,
                annotation=Callable[[], Any],
            )
            for id_ in self.ids
        ]

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
        """Read the stored value(s); returns a dict when ``key`` is a list."""
        config = ensure_config(config)
        configurable = config.get("configurable", {})
        if isinstance(self.key, list):
            return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
        else:
            return configurable[self.ids[0]]()

    async def ainvoke(
        self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Any:
        """Async ``invoke``; multiple keys are awaited concurrently."""
        config = ensure_config(config)
        configurable = config.get("configurable", {})
        if isinstance(self.key, list):
            values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
            return {key: value for key, value in zip(self.key, values)}
        else:
            return await configurable[self.ids[0]]()
|
||||
|
||||
|
||||
# Anything acceptable as the value side of a ContextSet: a Runnable, a
# sync or async callable applied to the input, or a plain constant value.
SetValue = Union[
    Runnable[Input, Output],
    Callable[[Input], Output],
    Callable[[Input], Awaitable[Output]],
    Any,
]
|
||||
|
||||
|
||||
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
    """Coerce a SetValue into a Runnable.

    Runnables and callables are coerced directly; a plain value is wrapped
    in a lambda that ignores its input and always returns that value.
    """
    if isinstance(value, Runnable) or callable(value):
        return coerce_to_runnable(value)
    # Constant value: return it regardless of the incoming input.
    return coerce_to_runnable(lambda _: value)
|
||||
|
||||
|
||||
@beta()
class ContextSet(RunnableSerializable):
    """Set a context value."""

    # Optional scope prefix, inserted between the context prefix and the key.
    prefix: str = ""

    # Maps each context key to an optional Runnable that transforms the
    # input before storing it; None stores the input unchanged.
    keys: Mapping[str, Optional[Runnable]]

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        key: Optional[str] = None,
        value: Optional[SetValue] = None,
        prefix: str = "",
        **kwargs: SetValue,
    ):
        """Accept a single (key, value) pair and/or keyword pairs.

        Args:
            key: A single context key to set.
            value: The value (or Runnable/callable producing it) for ``key``.
            prefix: Optional scope prefix applied to every key.
            **kwargs: Additional key -> value mappings.
        """
        if key is not None:
            kwargs[key] = value
        super().__init__(  # type: ignore[call-arg]
            keys={
                k: _coerce_set_value(v) if v is not None else None
                for k, v in kwargs.items()
            },
            prefix=prefix,
        )

    def __str__(self) -> str:
        return f"ContextSet({_print_keys(list(self.keys.keys()))})"

    @property
    def ids(self) -> List[str]:
        # One "__context__/<prefix><key>/set" config id per context key.
        prefix = self.prefix + "/" if self.prefix else ""
        return [
            f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
            for key in self.keys
        ]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        # Specs declared by the mapper runnables themselves.
        mapper_config_specs = [
            s
            for mapper in self.keys.values()
            if mapper is not None
            for s in mapper.config_specs
        ]
        # A mapper that *gets* a key this instance *sets* would wait on
        # its own value — reject the cycle up front.
        for spec in mapper_config_specs:
            if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
                getter_key = spec.id.split("/")[1]
                if getter_key in self.keys:
                    raise ValueError(
                        f"Circular reference in context setter for key {getter_key}"
                    )
        return super().config_specs + [
            ConfigurableFieldSpec(
                id=id_,
                annotation=Callable[[], Any],
            )
            for id_ in self.ids
        ]

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
        """Store the (possibly mapped) input under each key; pass input through."""
        config = ensure_config(config)
        configurable = config.get("configurable", {})
        for id_, mapper in zip(self.ids, self.keys.values()):
            if mapper is not None:
                configurable[id_](mapper.invoke(input, config))
            else:
                configurable[id_](input)
        return input

    async def ainvoke(
        self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Any:
        """Async ``invoke``: awaits mappers and the async setter callables."""
        config = ensure_config(config)
        configurable = config.get("configurable", {})
        for id_, mapper in zip(self.ids, self.keys.values()):
            if mapper is not None:
                await configurable[id_](await mapper.ainvoke(input, config))
            else:
                await configurable[id_](input)
        return input
|
||||
|
||||
|
||||
class Context:
    """
    Context for a runnable.

    The `Context` class provides methods for creating context scopes,
    getters, and setters within a runnable. It allows for managing
    and accessing contextual information throughout the execution
    of a program.

    Example:
        .. code-block:: python

            from langchain_core.beta.runnables.context import Context
            from langchain_core.runnables.passthrough import RunnablePassthrough
            from langchain_core.prompts.prompt import PromptTemplate
            from langchain_core.output_parsers.string import StrOutputParser
            from tests.unit_tests.fake.llm import FakeListLLM

            chain = (
                Context.setter("input")
                | {
                    "context": RunnablePassthrough()
                    | Context.setter("context"),
                    "question": RunnablePassthrough(),
                }
                | PromptTemplate.from_template("{context} {question}")
                | FakeListLLM(responses=["hello"])
                | StrOutputParser()
                | {
                    "result": RunnablePassthrough(),
                    "context": Context.getter("context"),
                    "input": Context.getter("input"),
                }
            )

            # Use the chain
            output = chain.invoke("What's your name?")
            print(output["result"])  # Output: "hello"
            print(output["context"])  # Output: "What's your name?"
            print(output["input"])  # Output: "What's your name?"
    """

    @staticmethod
    def create_scope(scope: str, /) -> "PrefixContext":
        """Create a context scope.

        Args:
            scope: The scope.

        Returns:
            The context scope.
        """
        return PrefixContext(prefix=scope)

    @staticmethod
    def getter(key: Union[str, List[str]], /) -> ContextGet:
        """Return a runnable that reads the value(s) stored under ``key``."""
        return ContextGet(key=key)

    @staticmethod
    def setter(
        _key: Optional[str] = None,
        _value: Optional[SetValue] = None,
        /,
        **kwargs: SetValue,
    ) -> ContextSet:
        """Return a runnable that stores value(s) under the given key(s)."""
        return ContextSet(_key, _value, prefix="", **kwargs)
|
||||
|
||||
|
||||
class PrefixContext:
    """Context for a runnable with a prefix."""

    # Scope prefix applied to every key created through this instance.
    prefix: str = ""

    def __init__(self, prefix: str = ""):
        self.prefix = prefix

    def getter(self, key: Union[str, List[str]], /) -> ContextGet:
        """Return a ContextGet for ``key`` scoped under this prefix."""
        return ContextGet(key=key, prefix=self.prefix)

    def setter(
        self,
        _key: Optional[str] = None,
        _value: Optional[SetValue] = None,
        /,
        **kwargs: SetValue,
    ) -> ContextSet:
        """Return a ContextSet for the given key(s) scoped under this prefix."""
        return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
|
||||
|
||||
|
||||
def _print_keys(keys: Union[str, Sequence[str]]) -> str:
|
||||
if isinstance(keys, str):
|
||||
return f"'{keys}'"
|
||||
else:
|
||||
return ", ".join(f"'{k}'" for k in keys)
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
|
||||
class Document(Serializable):
|
||||
@@ -21,18 +21,14 @@ class Document(Serializable):
|
||||
)
|
||||
"""
|
||||
|
||||
page_content: str
|
||||
page_content: str = field(kw_only=False)
|
||||
"""String text."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
type: Literal["Document"] = "Document"
|
||||
|
||||
def __init__(self, page_content: str, **kwargs: Any) -> None:
|
||||
"""Pass page_content in as positional or named arg."""
|
||||
super().__init__(page_content=page_content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
"""Select examples based on length."""
|
||||
import re
|
||||
from dataclasses import field
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from langchain_core.example_selectors.base import BaseExampleSelector
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, validator
|
||||
|
||||
|
||||
def _get_length_based(text: str) -> int:
|
||||
return len(re.split("\n| ", text))
|
||||
|
||||
|
||||
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
class LengthBasedExampleSelector(BaseExampleSelector, Serializable):
|
||||
"""Select examples based on length."""
|
||||
|
||||
examples: List[dict]
|
||||
@@ -26,7 +27,19 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
max_length: int = 2048
|
||||
"""Max length for the prompt, beyond which examples are cut."""
|
||||
|
||||
example_text_lengths: List[int] = [] #: :meta private:
|
||||
example_text_lengths: List[int] = field(default_factory=list) #: :meta private:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# validate example_text_lengths
|
||||
# Check if text lengths were passed in
|
||||
if self.example_text_lengths:
|
||||
return
|
||||
else:
|
||||
# If they were not, calculate them
|
||||
example_prompt = self.example_prompt
|
||||
get_text_length = self.get_text_length
|
||||
string_examples = [example_prompt.format(**eg) for eg in self.examples]
|
||||
self.example_text_lengths = [get_text_length(eg) for eg in string_examples]
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> None:
|
||||
"""Add new example to list."""
|
||||
@@ -38,18 +51,6 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Add new example to list."""
|
||||
self.add_example(example)
|
||||
|
||||
@validator("example_text_lengths", always=True)
|
||||
def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
|
||||
"""Calculate text lengths if they don't exist."""
|
||||
# Check if text lengths were passed in
|
||||
if v:
|
||||
return v
|
||||
# If they were not, calculate them
|
||||
example_prompt = values["example_prompt"]
|
||||
get_text_length = values["get_text_length"]
|
||||
string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
|
||||
return [get_text_length(eg) for eg in string_examples]
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Select which examples to use based on the input lengths."""
|
||||
inputs = " ".join(input_variables.values())
|
||||
|
||||
@@ -4,13 +4,13 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
@@ -27,7 +27,6 @@ from typing import (
|
||||
from langchain_core.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.indexing.base import RecordManager
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
# Magic UUID to use as a namespace for hashing.
|
||||
@@ -55,23 +54,22 @@ def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID:
|
||||
class _HashedDocument(Document):
|
||||
"""A hashed document with a unique ID."""
|
||||
|
||||
uid: str
|
||||
hash_: str
|
||||
uid: Optional[str] = None
|
||||
hash_: str = field(init=False)
|
||||
"""The hash of the document including content and metadata."""
|
||||
content_hash: str
|
||||
content_hash: str = field(init=False)
|
||||
"""The hash of the document content."""
|
||||
metadata_hash: str
|
||||
metadata_hash: str = field(init=False)
|
||||
"""The hash of the document metadata."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
@root_validator(pre=True)
|
||||
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def __post_init__(self) -> None:
|
||||
"""Root validator to calculate content and metadata hash."""
|
||||
content = values.get("page_content", "")
|
||||
metadata = values.get("metadata", {})
|
||||
content = self.page_content
|
||||
metadata = self.metadata
|
||||
|
||||
forbidden_keys = ("hash_", "content_hash", "metadata_hash")
|
||||
|
||||
@@ -92,15 +90,12 @@ class _HashedDocument(Document):
|
||||
f"Please use a dict that can be serialized using json."
|
||||
)
|
||||
|
||||
values["content_hash"] = content_hash
|
||||
values["metadata_hash"] = metadata_hash
|
||||
values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
|
||||
self.content_hash = content_hash
|
||||
self.metadata_hash = metadata_hash
|
||||
self.hash_ = str(_hash_string_to_uuid(content_hash + metadata_hash))
|
||||
|
||||
_uid = values.get("uid", None)
|
||||
|
||||
if _uid is None:
|
||||
values["uid"] = values["hash_"]
|
||||
return values
|
||||
if self.uid is None:
|
||||
self.uid = self.hash_
|
||||
|
||||
def to_document(self) -> Document:
|
||||
"""Return a Document object."""
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
@@ -73,7 +75,9 @@ def _get_verbosity() -> bool:
|
||||
|
||||
|
||||
class BaseLanguageModel(
|
||||
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC
|
||||
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar],
|
||||
Generic[LanguageModelOutputVar],
|
||||
ABC,
|
||||
):
|
||||
"""Abstract base class for interfacing with language models.
|
||||
|
||||
@@ -90,16 +94,16 @@ class BaseLanguageModel(
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
verbose: bool = field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callbacks: Callbacks = field(default=None, metadata={"exclude": True})
|
||||
"""Callbacks to add to the run trace."""
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
tags: Optional[List[str]] = field(default=None, metadata={"exclude": True})
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
metadata: Optional[Dict[str, Any]] = field(default=None, metadata={"exclude": True})
|
||||
"""Metadata to add to the run trace."""
|
||||
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
|
||||
default=None, exclude=True
|
||||
custom_get_token_ids: Optional[Callable[[str], List[int]]] = field(
|
||||
default=None, metadata={"exclude": True}
|
||||
)
|
||||
"""Optional encoder to use for counting tokens."""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import inspect
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -136,7 +137,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
| `_astream` | Use to implement async version of `_stream` | Optional |
|
||||
""" # noqa: E501
|
||||
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = field(
|
||||
default=None, metadata={"exclude": True}
|
||||
)
|
||||
"""[DEPRECATED] Callback manager to add to the run trace."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import field
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@@ -58,7 +59,6 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@@ -224,24 +224,19 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
It should take in a prompt and return a string."""
|
||||
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = field(
|
||||
default=None, metadata={"exclude": True}
|
||||
)
|
||||
"""[DEPRECATED]"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
def __post_init__(self) -> None:
|
||||
"""Post-init method."""
|
||||
if self.callback_manager is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
self.callbacks = self.callback_manager
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
|
||||
260
libs/core/langchain_core/load/dataclass_ext.py
Normal file
260
libs/core/langchain_core/load/dataclass_ext.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import sys
|
||||
from dataclasses import MISSING, Field
|
||||
from types import FunctionType
|
||||
from typing import Iterable
|
||||
|
||||
# Attribute name under which dataclasses stores its Field mapping on a class.
_FIELDS = "__dataclass_fields__"
# Name of the user-defined post-init hook, invoked at the end of __init__.
_POST_INIT_NAME = "__post_init__"


class _HAS_DEFAULT_FACTORY_CLASS:
    # Sentinel type: its single instance marks "call the default_factory"
    # in generated __init__ signatures (mirrors CPython's dataclasses).
    def __repr__(self):
        return "<factory>"


_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
|
||||
|
||||
|
||||
def _fields_in_init_order(fields: Iterable[Field]):
|
||||
# Returns the fields as __init__ will output them. It returns 2 tuples:
|
||||
# the first for normal args, and the second for keyword args.
|
||||
|
||||
return (
|
||||
tuple(f for f in fields if f.init and not f.kw_only),
|
||||
tuple(f for f in fields if f.init and f.kw_only),
|
||||
)
|
||||
|
||||
|
||||
def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
    """Build a function named ``name`` from source text and return it.

    The function source is wrapped in an outer ``__create_fn__`` whose
    parameters carry the ``locals`` values into the inner function's
    closure (mirrors CPython's dataclasses codegen).
    """
    # Note that we may mutate locals. Callers beware!
    # The only callers are internal to this module, so no
    # worries about external callers.
    if locals is None:
        locals = {}
    return_annotation = ""
    if return_type is not MISSING:
        locals["_return_type"] = return_type
        return_annotation = "->_return_type"
    args = ",".join(args)
    # Body lines are indented two spaces: one deeper than the " def" below.
    body = "\n".join(f"  {b}" for b in body)

    # Compute the text of the entire function.
    txt = f" def {name}({args}){return_annotation}:\n{body}"

    local_vars = ", ".join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    # Call the wrapper to bind locals and obtain the inner function.
    return ns["__create_fn__"](**locals)
|
||||
|
||||
|
||||
def _field_assign(frozen, name, value, self_name):
|
||||
# If we're a frozen class, then assign to our fields in __init__
|
||||
# via object.__setattr__. Otherwise, just use a simple
|
||||
# assignment.
|
||||
#
|
||||
# self_name is what "self" is called in this function: don't
|
||||
# hard-code "self", since that might be a field name.
|
||||
if frozen:
|
||||
return (
|
||||
f"__dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})" # noqa: E501
|
||||
)
|
||||
return f"{self_name}.{name}={value}"
|
||||
|
||||
|
||||
def _field_init(f, frozen, globals, self_name, slots):
    # Return the text of the line in the body of __init__ that will
    # initialize this field. Returns None when no line is needed
    # (class-attribute default suffices, or the field is an InitVar).
    # Defaults are exposed to the generated code through ``globals``
    # under the name "_dflt_<field>".

    default_name = f"_dflt_{f.name}"
    if f.default_factory is not MISSING:
        if f.init:
            # This field has a default factory. If a parameter is
            # given, use it. If not, call the factory.
            globals[default_name] = f.default_factory
            value = (
                f"{default_name}() "
                f"if {f.name} is _HAS_DEFAULT_FACTORY "
                f"else {f.name}"
            )
        else:
            # This is a field that's not in the __init__ params, but
            # has a default factory function. It needs to be
            # initialized here by calling the factory function,
            # because there's no other way to initialize it.

            # For a field initialized with a default=defaultvalue, the
            # class dict just has the default value
            # (cls.fieldname=defaultvalue). But that won't work for a
            # default factory, the factory must be called in __init__
            # and we must assign that to self.fieldname. We can't
            # fall back to the class dict's value, both because it's
            # not set, and because it might be different per-class
            # (which, after all, is why we have a factory function!).

            globals[default_name] = f.default_factory
            value = f"{default_name}()"
    else:
        # No default factory.
        if f.init:
            if f.default is MISSING:
                # There's no default, just do an assignment.
                value = f.name
            elif f.default is not MISSING:
                globals[default_name] = f.default
                value = f.name
        else:
            # If the class has slots, then initialize this field.
            if slots and f.default is not MISSING:
                globals[default_name] = f.default
                value = default_name
            else:
                # This field does not need initialization: reading from it will
                # just use the class attribute that contains the default.
                # Signify that to the caller by returning None.
                return None

    # Only test this now, so that we can create variables for the
    # default. However, return None to signify that we're not going
    # to actually do the assignment statement for InitVars.
    # NOTE(review): compares the *name* of the private field-type
    # sentinel rather than its identity (CPython uses `is`) — confirm
    # every supplied Field carries a _field_type with a .name attribute.
    if f._field_type.name == "_FIELD_INITVAR":
        return None

    # Now, actually generate the field assignment.
    return _field_assign(frozen, f.name, value, self_name)
|
||||
|
||||
|
||||
def _init_param(f):
|
||||
# Return the __init__ parameter string for this field. For
|
||||
# example, the equivalent of 'x:int=3' (except instead of 'int',
|
||||
# reference a variable set to int, and instead of '3', reference a
|
||||
# variable set to 3).
|
||||
if f.default is MISSING and f.default_factory is MISSING:
|
||||
# There's no default, and no default_factory, just output the
|
||||
# variable name and type.
|
||||
default = ""
|
||||
elif f.default is not MISSING:
|
||||
# There's a default, this will be the name that's used to look
|
||||
# it up.
|
||||
default = f"=_dflt_{f.name}"
|
||||
elif f.default_factory is not MISSING:
|
||||
# There's a factory function. Set a marker.
|
||||
default = "=_HAS_DEFAULT_FACTORY"
|
||||
return f"{f.name}:_type_{f.name}{default}"
|
||||
|
||||
|
||||
def _init_fn(
    fields, std_fields, kw_only_fields, frozen, has_post_init, self_name, globals, slots
):
    # Generate and return the __init__ function for a dataclass-like class.
    # fields contains both real fields and InitVar pseudo-fields.

    # Make sure we don't have fields without defaults following fields
    # with defaults. This actually would be caught when exec-ing the
    # function source code, but catching it here gives a better error
    # message, and future-proofs us in case we build up the function
    # using ast.

    seen_default = False
    for f in std_fields:
        # Only consider the non-kw-only fields in the __init__ call.
        if f.init:
            if not (f.default is MISSING and f.default_factory is MISSING):
                seen_default = True
            elif seen_default:
                raise TypeError(
                    f"non-default argument {f.name!r} " "follows default argument"
                )

    # Names injected into the generated function's closure: each field's
    # annotation plus the sentinels the generated code references.
    locals = {f"_type_{f.name}": f.type for f in fields}
    locals.update(
        {
            "MISSING": MISSING,
            "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY,
            "__dataclass_builtins_object__": object,
        }
    )

    body_lines = []
    for f in fields:
        line = _field_init(f, frozen, locals, self_name, slots)
        # line is None means that this field doesn't require
        # initialization (it's a pseudo-field). Just skip it.
        if line:
            body_lines.append(line)

    # Does this class have a post-init function?
    if has_post_init:
        # InitVar pseudo-fields are forwarded to __post_init__.
        params_str = ",".join(
            f.name for f in fields if f._field_type.name == "_FIELD_INITVAR"
        )
        body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})")

    # If no body lines, use 'pass'.
    if not body_lines:
        body_lines = ["pass"]

    _init_params = [_init_param(f) for f in std_fields]
    if kw_only_fields:
        # Add the keyword-only args. Because the * can only be added if
        # there's at least one keyword-only arg, there needs to be a test here
        # (instead of just concatenting the lists together).
        _init_params += ["*"]
        _init_params += [_init_param(f) for f in kw_only_fields]
    return _create_fn(
        "__init__",
        [self_name] + _init_params,
        body_lines,
        locals=locals,
        globals=globals,
        return_type=None,
    )
|
||||
|
||||
|
||||
def _set_qualname(cls, value):
|
||||
# Ensure that the functions returned from _create_fn uses the proper
|
||||
# __qualname__ (the class they belong to).
|
||||
if isinstance(value, FunctionType):
|
||||
value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
|
||||
return value
|
||||
|
||||
|
||||
def _set_new_attribute(cls, name, value):
|
||||
_set_qualname(cls, value)
|
||||
setattr(cls, name, value)
|
||||
return False
|
||||
|
||||
|
||||
def set_init(cls) -> None:
|
||||
if cls.__module__ in sys.modules:
|
||||
globals = sys.modules[cls.__module__].__dict__
|
||||
else:
|
||||
# Theoretically this can happen if someone writes
|
||||
# a custom string to cls.__module__. In which case
|
||||
# such dataclass won't be fully introspectable
|
||||
# (w.r.t. typing.get_type_hints) but will still function
|
||||
# correctly.
|
||||
globals = {}
|
||||
|
||||
all_fields = getattr(cls, _FIELDS, None)
|
||||
all_init_fields = tuple(
|
||||
f for f in all_fields.values() if f._field_type.name != "_FIELD_CLASSVAR"
|
||||
)
|
||||
std_init_fields, kw_only_init_fields = _fields_in_init_order(all_init_fields)
|
||||
has_post_init = hasattr(cls, _POST_INIT_NAME)
|
||||
_set_new_attribute(
|
||||
cls,
|
||||
"__default_init__",
|
||||
_init_fn(
|
||||
all_init_fields,
|
||||
std_init_fields,
|
||||
kw_only_init_fields,
|
||||
False,
|
||||
has_post_init,
|
||||
# The name to use for the "self"
|
||||
# param in __init__. Use "self"
|
||||
# if possible.
|
||||
"__dataclass_self__" if "self" in all_fields else "self",
|
||||
globals,
|
||||
False,
|
||||
),
|
||||
)
|
||||
@@ -1,10 +1,16 @@
|
||||
from abc import ABC
|
||||
from collections import deque
|
||||
from dataclasses import MISSING, dataclass, fields, replace
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
@@ -12,7 +18,7 @@ from typing import (
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.load.dataclass_ext import set_init
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
@@ -44,24 +50,7 @@ class SerializedNotImplemented(BaseSerialized):
|
||||
repr: Optional[str]
|
||||
|
||||
|
||||
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
"""Try to determine if a value is different from the default.
|
||||
|
||||
Args:
|
||||
value: The value.
|
||||
key: The key.
|
||||
model: The model.
|
||||
|
||||
Returns:
|
||||
Whether the value is different from the default.
|
||||
"""
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
class Serializable(ABC):
|
||||
"""Serializable base class.
|
||||
|
||||
This class is used to serialize objects to JSON.
|
||||
@@ -83,6 +72,48 @@ class Serializable(BaseModel, ABC):
|
||||
as part of the serialized representation..
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
dataclass(kw_only=True)(cls)
|
||||
set_init(cls)
|
||||
|
||||
def __default_init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, Any]]:
|
||||
return iter(self.__dict__.items())
|
||||
|
||||
def copy(
|
||||
self, deep: Optional[bool] = False, update: Optional[Dict[str, Any]] = None
|
||||
) -> "Serializable":
|
||||
"""Create a copy of the object."""
|
||||
if deep:
|
||||
copied = {
|
||||
k: v.copy() if hasattr(v, "copy") and callable(v.copy) else v
|
||||
for k, v in self.__dict__.items()
|
||||
}
|
||||
else:
|
||||
copied = {}
|
||||
return replace(self, **{**copied, **(update or {})})
|
||||
|
||||
def dict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
|
||||
"""Return the object as a dictionary."""
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
if hasattr(v, "dict") and callable(v.dict):
|
||||
return v.dict()
|
||||
if _sequence_like(v):
|
||||
return v.__class__(convert(x) for x in v)
|
||||
if isinstance(v, dict):
|
||||
return {k: convert(x) for k, x in v.items()}
|
||||
return v
|
||||
|
||||
return {
|
||||
k: convert(v)
|
||||
for k, v in self.__dict__.items()
|
||||
if exclude is None or k not in exclude
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Is this class serializable?"""
|
||||
@@ -123,16 +154,6 @@ class Serializable(BaseModel, ABC):
|
||||
"""
|
||||
return [*cls.get_lc_namespace(), cls.__name__]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
if not self.is_lc_serializable():
|
||||
return self.to_json_not_implemented()
|
||||
@@ -140,10 +161,7 @@ class Serializable(BaseModel, ABC):
|
||||
secrets = dict()
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {
|
||||
k: getattr(self, k, v)
|
||||
for k, v in self
|
||||
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
|
||||
and _is_field_useful(self, k, v)
|
||||
k: getattr(self, k, v) for k, v in self if _is_field_useful(self, k, v)
|
||||
}
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
@@ -211,10 +229,18 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
Returns:
|
||||
Whether the field is useful.
|
||||
"""
|
||||
field = inst.__fields__.get(key)
|
||||
field = next((f for f in fields(inst) if f.name == key), None)
|
||||
if not field:
|
||||
return False
|
||||
return field.required is True or value or field.get_default() != value
|
||||
if field.metadata.get("exclude"):
|
||||
return False
|
||||
if not field.init:
|
||||
return False
|
||||
if field.default is not MISSING:
|
||||
return value != field.default
|
||||
if field.default_factory is not MISSING:
|
||||
return value != field.default_factory()
|
||||
return True
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
@@ -267,3 +293,27 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, deque)) and not _is_namedtuple(
|
||||
type(v)
|
||||
)
|
||||
|
||||
|
||||
def _is_namedtuple(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`
|
||||
"""
|
||||
|
||||
return _lenient_issubclass(type_, tuple) and hasattr(type_, "_fields")
|
||||
|
||||
|
||||
def _lenient_issubclass(
|
||||
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
|
||||
) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
@@ -15,7 +16,6 @@ from langchain_core.messages.tool import (
|
||||
default_tool_chunk_parser,
|
||||
default_tool_parser,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import (
|
||||
parse_partial_json,
|
||||
@@ -62,9 +62,9 @@ class AIMessage(BaseMessage):
|
||||
At the moment, this is ignored by most models. Usage is discouraged.
|
||||
"""
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||||
"""If provided, tool calls associated with the message."""
|
||||
invalid_tool_calls: List[InvalidToolCall] = []
|
||||
invalid_tool_calls: List[InvalidToolCall] = field(default_factory=list)
|
||||
"""If provided, tool calls with parsing errors associated with the message."""
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
"""If provided, usage metadata for a message, such as token counts.
|
||||
@@ -75,12 +75,6 @@ class AIMessage(BaseMessage):
|
||||
type: Literal["ai"] = "ai"
|
||||
"""The type of the message (used for deserialization)."""
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg."""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@@ -94,27 +88,23 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=True)
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||
def __post_init__(self) -> None:
|
||||
raw_tool_calls = (self.additional_kwargs or {}).get("tool_calls")
|
||||
tool_calls = (
|
||||
values.get("tool_calls")
|
||||
or values.get("invalid_tool_calls")
|
||||
or values.get("tool_call_chunks")
|
||||
self.tool_calls
|
||||
or self.invalid_tool_calls
|
||||
or getattr(self, "tool_call_chunks", None)
|
||||
)
|
||||
if raw_tool_calls and not tool_calls:
|
||||
try:
|
||||
if issubclass(cls, AIMessageChunk): # type: ignore
|
||||
values["tool_call_chunks"] = default_tool_chunk_parser(
|
||||
raw_tool_calls
|
||||
)
|
||||
if isinstance(self, AIMessageChunk): # type: ignore
|
||||
self.tool_call_chunks = default_tool_chunk_parser(raw_tool_calls)
|
||||
else:
|
||||
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
|
||||
values["tool_calls"] = tool_calls
|
||||
values["invalid_tool_calls"] = invalid_tool_calls
|
||||
self.tool_calls = tool_calls
|
||||
self.invalid_tool_calls = invalid_tool_calls
|
||||
except Exception:
|
||||
pass
|
||||
return values
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
"""Return a pretty representation of the message."""
|
||||
@@ -148,9 +138,6 @@ class AIMessage(BaseMessage):
|
||||
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
||||
|
||||
|
||||
AIMessage.update_forward_refs()
|
||||
|
||||
|
||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"""Message chunk from an AI."""
|
||||
|
||||
@@ -159,7 +146,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
# non-chunk variant.
|
||||
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment]
|
||||
|
||||
tool_call_chunks: List[ToolCallChunk] = []
|
||||
tool_call_chunks: List[ToolCallChunk] = field(default_factory=list)
|
||||
"""If provided, tool call chunks associated with the message."""
|
||||
|
||||
@classmethod
|
||||
@@ -175,35 +162,36 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def init_tool_calls(cls, values: dict) -> dict:
|
||||
if not values["tool_call_chunks"]:
|
||||
if values["tool_calls"]:
|
||||
values["tool_call_chunks"] = [
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
if not self.tool_call_chunks:
|
||||
if self.tool_calls:
|
||||
self.tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
index=None,
|
||||
)
|
||||
for tc in values["tool_calls"]
|
||||
for tc in self.tool_calls
|
||||
]
|
||||
if values["invalid_tool_calls"]:
|
||||
tool_call_chunks = values.get("tool_call_chunks", [])
|
||||
if self.invalid_tool_calls:
|
||||
tool_call_chunks = self.tool_call_chunks or []
|
||||
tool_call_chunks.extend(
|
||||
[
|
||||
ToolCallChunk(
|
||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
||||
)
|
||||
for tc in values["invalid_tool_calls"]
|
||||
for tc in self.invalid_tool_calls
|
||||
]
|
||||
)
|
||||
values["tool_call_chunks"] = tool_call_chunks
|
||||
self.tool_call_chunks = tool_call_chunks
|
||||
|
||||
return values
|
||||
return
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
for chunk in values["tool_call_chunks"]:
|
||||
for chunk in self.tool_call_chunks:
|
||||
try:
|
||||
args_ = parse_partial_json(chunk["args"])
|
||||
if isinstance(args_, dict):
|
||||
@@ -225,9 +213,8 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
values["tool_calls"] = tool_calls
|
||||
values["invalid_tool_calls"] = invalid_tool_calls
|
||||
return values
|
||||
self.tool_calls = tool_calls
|
||||
self.invalid_tool_calls = invalid_tool_calls
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -18,17 +18,17 @@ class BaseMessage(Serializable):
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: Union[str, List[Union[str, Dict]]]
|
||||
content: Union[str, List[Union[str, Dict]]] = field(kw_only=False)
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
additional_kwargs: dict = field(default_factory=dict)
|
||||
"""Reserved for additional payload data associated with the message.
|
||||
|
||||
For example, for a message from an AI, this could include tool calls as
|
||||
encoded by the model provider.
|
||||
"""
|
||||
|
||||
response_metadata: dict = Field(default_factory=dict)
|
||||
response_metadata: dict = field(default_factory=dict)
|
||||
"""Response metadata. For example: response headers, logprobs, token counts."""
|
||||
|
||||
type: str
|
||||
@@ -51,15 +51,6 @@ class BaseMessage(Serializable):
|
||||
"""An optional unique identifier for the message. This should ideally be
|
||||
provided by the provider/model which created the message."""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg."""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
|
||||
@@ -23,9 +23,6 @@ class ChatMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
|
||||
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
"""Chat Message chunk."""
|
||||
|
||||
|
||||
@@ -31,9 +31,6 @@ class FunctionMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
|
||||
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
"""Function Message chunk."""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@@ -41,15 +41,6 @@ class HumanMessage(BaseMessage):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg."""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
|
||||
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
"""Human Message chunk."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@@ -36,15 +36,6 @@ class SystemMessage(BaseMessage):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg."""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
|
||||
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
"""System Message chunk."""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -43,15 +43,6 @@ class ToolMessage(BaseMessage):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg."""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
ToolMessage.update_forward_refs()
|
||||
|
||||
|
||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
"""Tool Message chunk."""
|
||||
|
||||
@@ -184,19 +184,10 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
determine which schema to use.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_schema(cls, values: Dict) -> Dict:
|
||||
schema = values["pydantic_schema"]
|
||||
if "args_only" not in values:
|
||||
values["args_only"] = isinstance(schema, type) and issubclass(
|
||||
schema, BaseModel
|
||||
)
|
||||
elif values["args_only"] and isinstance(schema, Dict):
|
||||
raise ValueError(
|
||||
"If multiple pydantic schemas are provided then args_only should be"
|
||||
" False."
|
||||
)
|
||||
return values
|
||||
def __post_init__(self) -> None:
|
||||
schema = self.pydantic_schema
|
||||
if self.args_only:
|
||||
self.args_only = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
_result = super().parse_result(result)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
|
||||
@@ -30,17 +29,16 @@ class ChatGeneration(Generation):
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def __post_init__(self) -> None:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
try:
|
||||
text = ""
|
||||
if isinstance(values["message"].content, str):
|
||||
text = values["message"].content
|
||||
if isinstance(self.message.content, str):
|
||||
text = self.message.content
|
||||
# HACK: Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(values["message"].content, list):
|
||||
for block in values["message"].content:
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text = block
|
||||
break
|
||||
@@ -51,10 +49,9 @@ class ChatGeneration(Generation):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
values["text"] = text
|
||||
self.text = text
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
class ChatResult(Serializable):
|
||||
"""Use to represent the result of a chat model call with a single prompt.
|
||||
|
||||
This container is used internally by some implementations of chat model,
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import field
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
class LLMResult(Serializable):
|
||||
"""A container for results of an LLM call.
|
||||
|
||||
Both chat models and LLMs generate an LLMResult object. This object contains
|
||||
@@ -41,7 +42,8 @@ class LLMResult(BaseModel):
|
||||
accessing relevant information from standardized fields present in
|
||||
AIMessage.
|
||||
"""
|
||||
run: Optional[List[RunInfo]] = None
|
||||
|
||||
run: Optional[List[RunInfo]] = field(compare=False, default=None)
|
||||
"""List of metadata info for model call for each input."""
|
||||
|
||||
def flatten(self) -> List[LLMResult]:
|
||||
@@ -79,12 +81,3 @@ class LLMResult(BaseModel):
|
||||
)
|
||||
)
|
||||
return llm_results
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.generations == other.generations
|
||||
and self.llm_output == other.llm_output
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -25,7 +26,7 @@ from langchain_core.prompt_values import (
|
||||
PromptValue,
|
||||
StringPromptValue,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.runnables.utils import create_model
|
||||
@@ -42,14 +43,14 @@ class BasePromptTemplate(
|
||||
):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
input_variables: List[str] = field(default_factory=list)
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
input_types: Dict[str, Any] = Field(default_factory=dict)
|
||||
input_types: Dict[str, Any] = field(default_factory=dict)
|
||||
"""A dictionary of the types of the variables the prompt template expects.
|
||||
If not provided, all variables are assumed to be strings."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
partial_variables: Mapping[str, Any] = field(default_factory=dict)
|
||||
"""A dictionary of the partial variables the prompt template carries.
|
||||
|
||||
Partial variables populate the template so that you don't need to
|
||||
@@ -59,28 +60,22 @@ class BasePromptTemplate(
|
||||
tags: Optional[List[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
if "stop" in self.input_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
if "stop" in self.partial_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
if overall := set(self.input_variables).intersection(self.partial_variables):
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@@ -92,11 +87,6 @@ class BasePromptTemplate(
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -38,7 +39,6 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -46,6 +46,8 @@ from langchain_core.utils.interactive_env import is_interactive_env
|
||||
class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""Base class for message prompt templates."""
|
||||
|
||||
input_variables: List[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
@@ -78,15 +80,6 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""
|
||||
return self.format_messages(**kwargs)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variables.
|
||||
"""
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
"""Human-readable representation."""
|
||||
raise NotImplementedError
|
||||
@@ -162,7 +155,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
# ])
|
||||
"""
|
||||
|
||||
variable_name: str
|
||||
variable_name: str = field(kw_only=False)
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
optional: bool = False
|
||||
@@ -170,14 +163,15 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
list. If False then a named argument with name `variable_name` must be passed
|
||||
in, even if the value is an empty list."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.input_variables:
|
||||
self.input_variables = [self.variable_name] if not self.optional else []
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
@@ -199,15 +193,6 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
)
|
||||
return convert_to_messages(value)
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return [self.variable_name] if not self.optional else []
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
var = "{" + self.variable_name + "}"
|
||||
if html:
|
||||
@@ -229,9 +214,13 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
|
||||
prompt: StringPromptTemplate
|
||||
"""String prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
additional_kwargs: dict = field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.input_variables:
|
||||
self.input_variables = self.prompt.input_variables
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@@ -323,16 +312,6 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
return [await self.aformat(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
# TODO: Handle partials
|
||||
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
|
||||
@@ -392,11 +371,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
|
||||
]
|
||||
"""Prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
additional_kwargs: dict = field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
_msg_class: Type[BaseMessage]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.input_variables:
|
||||
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
|
||||
self.input_variables = [
|
||||
iv for prompt in prompts for iv in prompt.input_variables
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@@ -519,18 +505,6 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
return [await self.aformat(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
|
||||
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
|
||||
return input_variables
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@@ -809,8 +783,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
input_variables: List[str]
|
||||
"""List of input variables in template messages. Used for validation."""
|
||||
messages: List[MessageLike]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
@@ -846,8 +818,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate input variables.
|
||||
|
||||
If input_variables is not set, it will be set to the union of
|
||||
@@ -859,28 +830,25 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
Returns:
|
||||
Validated values.
|
||||
"""
|
||||
messages = values["messages"]
|
||||
messages = self.messages
|
||||
input_vars = set()
|
||||
input_types: Dict[str, Any] = values.get("input_types", {})
|
||||
for message in messages:
|
||||
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
|
||||
input_vars.update(message.input_variables)
|
||||
if isinstance(message, MessagesPlaceholder):
|
||||
if message.variable_name not in input_types:
|
||||
input_types[message.variable_name] = List[AnyMessage]
|
||||
if "partial_variables" in values:
|
||||
input_vars = input_vars - set(values["partial_variables"])
|
||||
if "input_variables" in values and values.get("validate_template"):
|
||||
if input_vars != set(values["input_variables"]):
|
||||
if message.variable_name not in self.input_types:
|
||||
self.input_types[message.variable_name] = List[AnyMessage]
|
||||
if self.partial_variables:
|
||||
input_vars = input_vars - set(self.partial_variables)
|
||||
if self.validate_template:
|
||||
if input_vars != set(self.input_variables):
|
||||
raise ValueError(
|
||||
"Got mismatched input_variables. "
|
||||
f"Expected: {input_vars}. "
|
||||
f"Got: {values['input_variables']}"
|
||||
f"Got: {self.input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(input_vars)
|
||||
values["input_types"] = input_types
|
||||
return values
|
||||
self.input_variables = sorted(input_vars)
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
@@ -18,43 +19,11 @@ from langchain_core.prompts.string import (
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
class _FewShotPromptTemplateMixin:
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Optional[BaseExampleSelector] = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
if examples and example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
if examples is None and example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
"""Get the examples to use for formatting the prompt.
|
||||
|
||||
@@ -92,9 +61,42 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
class FewShotPromptTemplate(StringPromptTemplate, _FewShotPromptTemplateMixin):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Optional[BaseExampleSelector] = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
if self.examples and self.example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
if self.examples is None and self.example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
if self.validate_template:
|
||||
check_valid_template(
|
||||
self.prefix + self.suffix,
|
||||
self.template_format,
|
||||
self.input_variables + list(self.partial_variables),
|
||||
)
|
||||
|
||||
self.input_variables = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
self.prefix + self.suffix, self.template_format
|
||||
)
|
||||
if var not in self.partial_variables
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
@@ -103,9 +105,6 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""PromptTemplate used to format an individual example."""
|
||||
|
||||
@@ -121,31 +120,6 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
template_format: Literal["f-string", "jinja2"] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
check_valid_template(
|
||||
values["prefix"] + values["suffix"],
|
||||
values["template_format"],
|
||||
values["input_variables"] + list(values["partial_variables"]),
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["prefix"] + values["suffix"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
# Get the examples to use.
|
||||
@@ -309,23 +283,33 @@ class FewShotChatMessagePromptTemplate(
|
||||
chain.invoke({"input": "What's 3+3?"})
|
||||
"""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Optional[BaseExampleSelector] = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
if self.examples and self.example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
if self.examples is None and self.example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
input_variables: List[str] = Field(default_factory=list)
|
||||
"""A list of the names of the variables the prompt template will use
|
||||
to pass to the example_selector, if provided."""
|
||||
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
|
||||
"""The class to format each example."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
|
||||
|
||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
@@ -27,9 +26,6 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
suffix: StringPromptTemplate
|
||||
"""A PromptTemplate to put after the examples."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_separator: str = "\n\n"
|
||||
"""String separator used to join the prefix, the examples, and suffix."""
|
||||
|
||||
@@ -47,32 +43,25 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
def __post_init__(self) -> None:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
if examples and example_selector:
|
||||
if self.examples and self.example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
if examples is None and example_selector is None:
|
||||
if self.examples is None and self.example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
input_variables = values["input_variables"]
|
||||
expected_input_variables = set(values["suffix"].input_variables)
|
||||
expected_input_variables |= set(values["partial_variables"])
|
||||
if values["prefix"] is not None:
|
||||
expected_input_variables |= set(values["prefix"].input_variables)
|
||||
if self.validate_template:
|
||||
input_variables = self.input_variables
|
||||
expected_input_variables = set(self.suffix.input_variables)
|
||||
expected_input_variables |= set(self.partial_variables)
|
||||
if self.prefix is not None:
|
||||
expected_input_variables |= set(self.prefix.input_variables)
|
||||
missing_vars = expected_input_variables.difference(input_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
@@ -80,18 +69,11 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
f"prefix/suffix expected {expected_input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(
|
||||
set(values["suffix"].input_variables)
|
||||
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||
- set(values["partial_variables"])
|
||||
self.input_variables = sorted(
|
||||
set(self.suffix.input_variables)
|
||||
| set(self.prefix.input_variables if self.prefix else [])
|
||||
- set(self.partial_variables)
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
|
||||
@@ -10,21 +10,19 @@ from langchain_core.utils import image as image_utils
|
||||
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
"""Image prompt template for a multimodal model."""
|
||||
|
||||
input_variables: List[str] = Field(default_factory=list)
|
||||
|
||||
template: dict = Field(default_factory=dict)
|
||||
"""Template for the prompt."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
if "input_variables" not in kwargs:
|
||||
kwargs["input_variables"] = []
|
||||
|
||||
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
|
||||
def __post_init__(self) -> None:
|
||||
overlap = set(self.input_variables) & set(("url", "path", "detail"))
|
||||
if overlap:
|
||||
raise ValueError(
|
||||
"input_variables for the image template cannot contain"
|
||||
" any of 'url', 'path', or 'detail'."
|
||||
f" Found: {overlap}"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
|
||||
@@ -33,16 +33,14 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "pipeline"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
def __post_init__(self) -> None:
|
||||
"""Get input variables."""
|
||||
created_variables = set()
|
||||
all_variables = set()
|
||||
for k, prompt in values["pipeline_prompts"]:
|
||||
for k, prompt in self.pipeline_prompts:
|
||||
created_variables.add(k)
|
||||
all_variables.update(prompt.input_variables)
|
||||
values["input_variables"] = list(all_variables.difference(created_variables))
|
||||
return values
|
||||
self.input_variables = list(all_variables.difference(created_variables))
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
for k, prompt in self.pipeline_prompts:
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain_core.prompts.string import (
|
||||
get_template_variables,
|
||||
mustache_schema,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
@@ -74,43 +74,22 @@ class PromptTemplate(StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_init_validation(cls, values: Dict) -> Dict:
|
||||
def __post_init__(self) -> None:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values.get("template") is None:
|
||||
# Will let pydantic fail with a ValidationError if template
|
||||
# is not provided.
|
||||
return values
|
||||
|
||||
# Set some default values based on the field defaults
|
||||
values.setdefault("template_format", "f-string")
|
||||
values.setdefault("partial_variables", {})
|
||||
|
||||
if values.get("validate_template"):
|
||||
if values["template_format"] == "mustache":
|
||||
if self.validate_template:
|
||||
if self.template_format == "mustache":
|
||||
raise ValueError("Mustache templates cannot be validated.")
|
||||
|
||||
if "input_variables" not in values:
|
||||
raise ValueError(
|
||||
"Input variables must be provided to validate the template."
|
||||
)
|
||||
all_inputs = self.input_variables + list(self.partial_variables)
|
||||
check_valid_template(self.template, self.template_format, all_inputs)
|
||||
|
||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
)
|
||||
|
||||
if values["template_format"]:
|
||||
values["input_variables"] = [
|
||||
if self.template_format:
|
||||
self.input_variables = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["template"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
for var in get_template_variables(self.template, self.template_format)
|
||||
if var not in self.partial_variables
|
||||
]
|
||||
|
||||
return values
|
||||
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
if self.template_format != "mustache":
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@@ -7,8 +7,9 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
from contextvars import copy_context
|
||||
from dataclasses import field, fields
|
||||
from functools import wraps
|
||||
from itertools import groupby, tee
|
||||
from itertools import tee
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -25,7 +26,6 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -2061,7 +2061,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
from langchain_core.runnables.configurable import RunnableConfigurableFields
|
||||
|
||||
for key in kwargs:
|
||||
if key not in self.__fields__:
|
||||
if key not in [f.name for f in fields(self)]:
|
||||
raise ValueError(
|
||||
f"Configuration key {key} not found in {self}: "
|
||||
f"available keys are {self.__fields__.keys()}"
|
||||
@@ -2275,7 +2275,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# the last type.
|
||||
first: Runnable[Input, Any]
|
||||
"""The first runnable in the sequence."""
|
||||
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
|
||||
middle: List[Runnable[Any, Any]] = field(default_factory=list)
|
||||
"""The middle runnables in the sequence."""
|
||||
last: Runnable[Any, Output]
|
||||
"""The last runnable in the sequence."""
|
||||
@@ -2289,7 +2289,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
last: Optional[Runnable[Any, Any]] = None,
|
||||
) -> None:
|
||||
"""Create a new RunnableSequence.
|
||||
|
||||
Args:
|
||||
steps: The steps to include in the sequence.
|
||||
"""
|
||||
@@ -2306,9 +2305,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
raise ValueError(
|
||||
f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
|
||||
)
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
return self.__default_init__(
|
||||
first=steps_flat[0],
|
||||
middle=list(steps_flat[1:-1]),
|
||||
middle=steps_flat[1:-1],
|
||||
last=steps_flat[-1],
|
||||
name=name,
|
||||
)
|
||||
@@ -2350,48 +2349,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
from langchain_core.beta.runnables.context import (
|
||||
CONTEXT_CONFIG_PREFIX,
|
||||
_key_from_id,
|
||||
)
|
||||
|
||||
# get all specs
|
||||
all_specs = [
|
||||
(spec, idx)
|
||||
for idx, step in enumerate(self.steps)
|
||||
for spec in step.config_specs
|
||||
]
|
||||
# calculate context dependencies
|
||||
specs_by_pos = groupby(
|
||||
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
|
||||
lambda x: x[1],
|
||||
)
|
||||
next_deps: Set[str] = set()
|
||||
deps_by_pos: Dict[int, Set[str]] = {}
|
||||
for pos, specs in specs_by_pos:
|
||||
deps_by_pos[pos] = next_deps
|
||||
next_deps = next_deps | {spec[0].id for spec in specs}
|
||||
# assign context dependencies
|
||||
for pos, (spec, idx) in enumerate(all_specs):
|
||||
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
|
||||
all_specs[pos] = (
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
annotation=spec.annotation,
|
||||
name=spec.name,
|
||||
default=spec.default,
|
||||
description=spec.description,
|
||||
is_shared=spec.is_shared,
|
||||
dependencies=[
|
||||
d
|
||||
for d in deps_by_pos[idx]
|
||||
if _key_from_id(d) != _key_from_id(spec.id)
|
||||
]
|
||||
+ (spec.dependencies or []),
|
||||
),
|
||||
idx,
|
||||
)
|
||||
|
||||
return get_unique_config_specs(spec for spec, _ in all_specs)
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
@@ -2478,10 +2441,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = config_with_context(ensure_config(config), self.steps)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@@ -2516,10 +2477,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = aconfig_with_context(ensure_config(config), self.steps)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@@ -2556,17 +2515,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
config_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@@ -2682,17 +2637,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
aconfig_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@@ -2810,16 +2761,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = config_with_context(config, self.steps)
|
||||
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast(Iterator[Output], input)
|
||||
for idx, step in enumerate(steps):
|
||||
for idx, step in enumerate(self.steps):
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{idx+1}")
|
||||
)
|
||||
@@ -2838,17 +2784,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = aconfig_with_context(config, self.steps)
|
||||
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast(AsyncIterator[Output], input)
|
||||
for idx, step in enumerate(steps):
|
||||
for idx, step in enumerate(self.steps):
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
|
||||
@@ -2988,31 +2929,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
print(output) # noqa: T201
|
||||
"""
|
||||
|
||||
steps__: Mapping[str, Runnable[Input, Any]]
|
||||
steps__: Mapping[str, Runnable[Input, Any]] = field(kw_only=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
steps__: Optional[
|
||||
Mapping[
|
||||
str,
|
||||
Union[
|
||||
Runnable[Input, Any],
|
||||
Callable[[Input], Any],
|
||||
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
**kwargs: Union[
|
||||
Runnable[Input, Any],
|
||||
Callable[[Input], Any],
|
||||
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
|
||||
],
|
||||
) -> None:
|
||||
merged = {**steps__} if steps__ is not None else {}
|
||||
merged.update(kwargs)
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
|
||||
)
|
||||
def __post_init__(self) -> None:
|
||||
for key, step in self.steps__.items():
|
||||
self.steps__[key] = coerce_to_runnable(step)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -4435,7 +4356,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
bound: Runnable[Input, Output]
|
||||
"""The underlying runnable that this runnable delegates to."""
|
||||
|
||||
kwargs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
kwargs: Mapping[str, Any] = field(default_factory=dict)
|
||||
"""kwargs to pass to the underlying runnable when running.
|
||||
|
||||
For example, when the runnable binding is invoked the underlying
|
||||
@@ -4443,10 +4364,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
kwargs.
|
||||
"""
|
||||
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
config: RunnableConfig = field(default_factory=dict)
|
||||
"""The config to bind to the underlying runnable."""
|
||||
|
||||
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
|
||||
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
"""The config factories to bind to the underlying runnable."""
|
||||
@@ -4464,51 +4385,6 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bound: Runnable[Input, Output],
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config_factories: Optional[
|
||||
List[Callable[[RunnableConfig], RunnableConfig]]
|
||||
] = None,
|
||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a RunnableBinding from a runnable and kwargs.
|
||||
|
||||
Args:
|
||||
bound: The underlying runnable that this runnable delegates calls to.
|
||||
kwargs: optional kwargs to pass to the underlying runnable, when running
|
||||
the underlying runnable (e.g., via `invoke`, `batch`,
|
||||
`transform`, or `stream` or async variants)
|
||||
config: config_factories:
|
||||
config_factories: optional list of config factories to apply to the
|
||||
custom_input_type: Specify to override the input type of the underlying
|
||||
runnable with a custom type.
|
||||
custom_output_type: Specify to override the output type of the underlying
|
||||
runnable with a custom type.
|
||||
**other_kwargs: Unpacked into the base class.
|
||||
"""
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
bound=bound,
|
||||
kwargs=kwargs or {},
|
||||
config=config or {},
|
||||
config_factories=config_factories or [],
|
||||
custom_input_type=custom_input_type,
|
||||
custom_output_type=custom_output_type,
|
||||
**other_kwargs,
|
||||
)
|
||||
# if we don't explicitly set config to the TypedDict here,
|
||||
# the pydantic init above will strip out any of the "extra"
|
||||
# fields even though total=False on the typed dict.
|
||||
self.config = config or {}
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
) -> str:
|
||||
@@ -4801,9 +4677,6 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
"""Wrap a Runnable with additional functionality.
|
||||
|
||||
|
||||
@@ -407,13 +407,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
}
|
||||
|
||||
if configurable:
|
||||
init_params = {
|
||||
k: v
|
||||
for k, v in self.default.__dict__.items()
|
||||
if k in self.default.__fields__
|
||||
}
|
||||
return (
|
||||
self.default.__class__(**{**init_params, **configurable}),
|
||||
self.default.copy(update=configurable),
|
||||
config,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -325,7 +325,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
),
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
self.__default_init__(
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key=input_messages_key,
|
||||
output_messages_key=output_messages_key,
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
from dataclasses import field
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -165,7 +166,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
afunc = func
|
||||
func = None
|
||||
|
||||
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg]
|
||||
self.__default_init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -365,10 +366,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
# returns {'input': 5, 'add_step': {'added': 15}}
|
||||
"""
|
||||
|
||||
mapper: RunnableParallel[Dict[str, Any]]
|
||||
|
||||
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
|
||||
mapper: RunnableParallel[Dict[str, Any]] = field(kw_only=False)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -26,6 +26,7 @@ import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextvars import copy_context
|
||||
from dataclasses import field
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
@@ -41,7 +42,6 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
@@ -51,10 +51,8 @@ from langchain_core.prompts import (
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
ValidationError,
|
||||
create_model,
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
@@ -193,9 +191,11 @@ class ChildTool(BaseTool):
|
||||
verbose: bool = False
|
||||
"""Whether to log the tool's progress."""
|
||||
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callbacks: Callbacks = field(default=None, metadata={"exclude": True})
|
||||
"""Callbacks to be called during tool execution."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = field(
|
||||
default=None, metadata={"exclude": True}
|
||||
)
|
||||
"""Deprecated. Please use callbacks instead."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the tool. Defaults to None
|
||||
@@ -220,11 +220,6 @@ class ChildTool(BaseTool):
|
||||
] = False
|
||||
"""Handle the content of the ValidationError thrown."""
|
||||
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
@@ -309,16 +304,14 @@ class ChildTool(BaseTool):
|
||||
}
|
||||
return tool_input
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
if self.callback_manager is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
self.callbacks = self.callback_manager
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
@@ -663,15 +656,6 @@ class Tool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize tool."""
|
||||
super(Tool, self).__init__( # type: ignore[call-arg]
|
||||
name=name, func=func, description=description, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
@@ -703,7 +687,7 @@ class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
||||
args_schema: Type[BaseModel]
|
||||
"""The input arguments' schema."""
|
||||
func: Optional[Callable[..., Any]]
|
||||
"""The function to run when the tool is called."""
|
||||
@@ -972,7 +956,8 @@ def tool(
|
||||
class RetrieverInput(BaseModel):
|
||||
"""Input to the retriever."""
|
||||
|
||||
query: str = Field(description="query to look up in retriever")
|
||||
query: str
|
||||
"""query to look up in retriever"""
|
||||
|
||||
|
||||
def _get_relevant_documents(
|
||||
|
||||
@@ -17,7 +17,7 @@ def selector() -> LengthBasedExampleSelector:
|
||||
"""Get length based selector to use in tests."""
|
||||
prompts = PromptTemplate(input_variables=["question"], template="{question}")
|
||||
selector = LengthBasedExampleSelector(
|
||||
examples=EXAMPLES,
|
||||
examples=EXAMPLES.copy(),
|
||||
example_prompt=prompts,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ def test_serdes_message() -> None:
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"type": "ai",
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"invalid_tool_calls": [
|
||||
@@ -47,7 +46,6 @@ def test_serdes_message_chunk() -> None:
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
|
||||
"kwargs": {
|
||||
"type": "AIMessageChunk",
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"invalid_tool_calls": [
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_base_generation_parser() -> None:
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
|
||||
chain = model | StrInvertCase()
|
||||
print(chain)
|
||||
assert chain.invoke("") == "HeLLO"
|
||||
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ def test_prompt_from_template_with_partial_variables() -> None:
|
||||
def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables: list = []
|
||||
input_variables: list = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
|
||||
@@ -1,801 +1,4 @@
|
||||
# serializer version: 1
|
||||
# name: test_fallbacks[chain]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps__": {
|
||||
"buz": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(lambda x: x)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableParallel<buz>",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "Parallel<buz>Input"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "schema",
|
||||
"data": "Parallel<buz>Output"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"name": "RunnableLambda"
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 2
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string"
|
||||
},
|
||||
"name": "PromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"name": "FakeListLLM"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableSequence",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"name": "FakeListLLM"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 3
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string"
|
||||
},
|
||||
"name": "PromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"name": "FakeListLLM"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableSequence",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"language_models",
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"name": "FakeListLLM"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 3
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "RunnableWithFallbacks",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableSequence",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "Parallel<buz>Input"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"name": "RunnableLambda"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "schema",
|
||||
"data": "FakeListLLMOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 3
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[chain_pass_exceptions]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps__": {
|
||||
"text": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"kwargs": {},
|
||||
"name": "RunnablePassthrough",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PassthroughInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PassthroughOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableParallel<text>",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "Parallel<text>Input"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "schema",
|
||||
"data": "Parallel<text>Output"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 2
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_raise_error)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_dont_raise_error)"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": "exception"
|
||||
},
|
||||
"name": "RunnableWithFallbacks",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "_raise_error_input"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "_raise_error_output"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "RunnableSequence",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "Parallel<text>Input"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "schema",
|
||||
"data": "_raise_error_output"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 2,
|
||||
"target": 3
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[llm]
|
||||
'''
|
||||
{
|
||||
@@ -817,7 +20,7 @@
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)",
|
||||
"repr": "FakeListLLM(name=None, cache=None, verbose=False, callbacks=None, tags=None, metadata=None, custom_get_token_ids=None, callback_manager=None, responses=['foo'], sleep=None, i=1)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@@ -867,7 +70,7 @@
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])",
|
||||
"repr": "FakeListLLM(name=None, cache=None, verbose=False, callbacks=None, tags=None, metadata=None, custom_get_token_ids=None, callback_manager=None, responses=['bar'], sleep=None, i=0)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@@ -907,17 +110,6 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "RunnableWithFallbacks",
|
||||
@@ -982,7 +174,7 @@
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)",
|
||||
"repr": "FakeListLLM(name=None, cache=None, verbose=False, callbacks=None, tags=None, metadata=None, custom_get_token_ids=None, callback_manager=None, responses=['foo'], sleep=None, i=1)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@@ -1032,7 +224,7 @@
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['baz'], i=1)",
|
||||
"repr": "FakeListLLM(name=None, cache=None, verbose=False, callbacks=None, tags=None, metadata=None, custom_get_token_ids=None, callback_manager=None, responses=['baz'], sleep=None, i=1)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@@ -1081,7 +273,7 @@
|
||||
"fake",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])",
|
||||
"repr": "FakeListLLM(name=None, cache=None, verbose=False, callbacks=None, tags=None, metadata=None, custom_get_token_ids=None, callback_manager=None, responses=['bar'], sleep=None, i=0)",
|
||||
"name": "FakeListLLM",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@@ -1121,17 +313,6 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "RunnableWithFallbacks",
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
RunnableConfig,
|
||||
@@ -11,22 +10,11 @@ from langchain_core.runnables import (
|
||||
|
||||
|
||||
class MyRunnable(RunnableSerializable[str, str]):
|
||||
my_property: str = Field(alias="my_property_alias")
|
||||
my_property: str
|
||||
_my_hidden_property: str = ""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def my_error(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "_my_hidden_property" in values:
|
||||
raise ValueError("Cannot set _my_hidden_property")
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values["_my_hidden_property"] = values["my_property"]
|
||||
return values
|
||||
def __post_init__(self) -> None:
|
||||
self._my_hidden_property = self.my_property
|
||||
|
||||
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||
return input + self._my_hidden_property
|
||||
@@ -77,42 +65,6 @@ def test_doubly_set_configurable() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_alias_set_configurable() -> None:
|
||||
runnable = MyRunnable(my_property="a") # type: ignore
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property_alias",
|
||||
name="My property alias",
|
||||
description="The property to test alias",
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
configurable_runnable.invoke(
|
||||
"d", config=RunnableConfig(configurable={"my_property_alias": "c"})
|
||||
)
|
||||
== "dc"
|
||||
)
|
||||
|
||||
|
||||
def test_field_alias_set_configurable() -> None:
|
||||
runnable = MyRunnable(my_property_alias="a")
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property",
|
||||
name="My property alias",
|
||||
description="The property to test alias",
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
configurable_runnable.invoke(
|
||||
"d", config=RunnableConfig(configurable={"my_property": "c"})
|
||||
)
|
||||
== "dc"
|
||||
)
|
||||
|
||||
|
||||
def test_config_passthrough() -> None:
|
||||
runnable = MyRunnable(my_property="a") # type: ignore
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
|
||||
@@ -1,411 +0,0 @@
|
||||
from typing import Any, Callable, List, NamedTuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.beta.runnables.context import Context
|
||||
from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompt_values import StringPromptValue
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import aadd, add
|
||||
|
||||
|
||||
class _TestCase(NamedTuple):
|
||||
input: Any
|
||||
output: Any
|
||||
|
||||
|
||||
def seq_naive_rag() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"result": RunnablePassthrough(),
|
||||
"context": Context.getter("context"),
|
||||
"input": Context.getter("input"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_alt() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| Context.setter("result")
|
||||
| Context.getter(["context", "input", "result"])
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_scoped() -> Runnable:
    """Like seq_naive_rag_alt, but additionally exercises a named context
    scope ("a_scope") whose keys are independent of the default scope."""
    docs = [
        "Hi there!",
        "How are you?",
        "What's your name?",
    ]

    # Fake retriever: ignores the query and always returns the same documents.
    fetch = RunnableLambda(lambda _: docs)
    template = PromptTemplate.from_template("{context} {question}")
    model = FakeListLLM(responses=["hello"])

    scoped = Context.create_scope("a_scope")

    return (
        Context.setter("input")
        | {
            "context": fetch | Context.setter("context"),
            # Scoped "context" key: set and immediately read back within
            # the named scope, without touching the default-scope key.
            "scoped": scoped.setter("context") | scoped.getter("context"),
            "question": RunnablePassthrough(),
        }
        | template
        | model
        | StrOutputParser()
        | Context.setter("result")
        | Context.getter(["context", "input", "result"])
    )
|
||||
|
||||
|
||||
# (runnable-or-factory, scenarios) pairs consumed by test_context_runnables.
# The seq_naive_rag* entries are zero-arg factories; the test calls them so
# each parametrized run builds a fresh chain.
test_cases = [
    (
        # Set then get the same key: the chain behaves as identity.
        Context.setter("foo") | Context.getter("foo"),
        (
            _TestCase("foo", "foo"),
            _TestCase("bar", "bar"),
        ),
    ),
    (
        # Getter inside a mapping step that runs after the setter.
        Context.setter("input") | {"bar": Context.getter("input")},
        (
            _TestCase("foo", {"bar": "foo"}),
            _TestCase("bar", {"bar": "bar"}),
        ),
    ),
    (
        # Setter inside a mapping step, getter afterwards at the top level.
        {"bar": Context.setter("input")} | Context.getter("input"),
        (
            _TestCase("foo", "foo"),
            _TestCase("bar", "bar"),
        ),
    ),
    (
        # Stash the rendered prompt and surface it next to the LLM response.
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt")
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
            }
        ),
        (
            _TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
            ),
            _TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
            ),
        ),
    ),
    (
        # Positional key plus a keyword mapper: stores both the raw prompt
        # ("prompt") and a derived string ("prompt_str").
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt", prompt_str=lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            _TestCase(
                {"foo": "foo", "bar": "bar"},
                {
                    "response": "hello",
                    "prompt": StringPromptValue(text="foo bar"),
                    "prompt_str": "foo bar",
                },
            ),
            _TestCase(
                {"foo": "bar", "bar": "foo"},
                {
                    "response": "hello",
                    "prompt": StringPromptValue(text="bar foo"),
                    "prompt_str": "bar foo",
                },
            ),
        ),
    ),
    (
        # Keyword-only setter: only the derived string is stored.
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter(prompt_str=lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            _TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt_str": "foo bar"},
            ),
            _TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt_str": "bar foo"},
            ),
        ),
    ),
    (
        # Positional key with a positional mapper function.
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt_str", lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            _TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt_str": "foo bar"},
            ),
            _TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt_str": "bar foo"},
            ),
        ),
    ),
    (
        # Same prompt round-trip as above, but with a streaming LLM.
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt")
            | FakeStreamingListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
            }
        ),
        (
            _TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
            ),
            _TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
            ),
        ),
    ),
    (
        # Factory: naive RAG chain (final mapping of getters).
        seq_naive_rag,
        (
            _TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            _TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
    (
        # Factory: naive RAG chain, multi-key getter variant.
        seq_naive_rag_alt,
        (
            _TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            _TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
    (
        # Factory: naive RAG chain with an additional named scope.
        seq_naive_rag_scoped,
        (
            _TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            _TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables(
    runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase]
) -> None:
    """Exercise every execution surface (invoke/ainvoke/batch/abatch/
    stream/astream) of a context-using runnable against its scenarios."""
    if not isinstance(runnable, Runnable):
        runnable = runnable()  # materialize factories so each run is fresh
    first, second = cases[0], cases[1]
    inputs = [case.input for case in cases]
    expected = [case.output for case in cases]

    # Single sync/async invocation.
    assert runnable.invoke(first.input) == first.output
    assert await runnable.ainvoke(second.input) == second.output
    # Sync/async batching over every scenario at once.
    assert runnable.batch(inputs) == expected
    assert await runnable.abatch(inputs) == expected
    # Sync/async streaming, chunks folded back into a single value.
    assert add(runnable.stream(first.input)) == first.output
    assert await aadd(runnable.astream(second.input)) == second.output
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_not_found() -> None:
    """Reading a context key that no step in the sequence sets must fail."""
    chain: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")

    with pytest.raises(ValueError):
        chain.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_order() -> None:
    """A getter placed before its setter in the sequence must fail."""
    chain: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")

    with pytest.raises(ValueError):
        chain.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_deadlock() -> None:
    """Two parallel branches that each read the key the other writes must
    be rejected rather than deadlock at runtime."""
    chain: Runnable = {
        "bar": Context.setter("input") | Context.getter("foo"),
        "foo": Context.setter("foo") | Context.getter("input"),
    } | RunnablePassthrough()

    with pytest.raises(ValueError):
        chain.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_circular_ref() -> None:
    """A setter whose value is produced by a getter in the same mapping
    (circular reference) must fail."""
    chain: Runnable = {
        "bar": Context.setter(input=Context.getter("input"))
    } | Context.getter("foo")

    with pytest.raises(ValueError):
        chain.invoke("foo")
|
||||
|
||||
|
||||
async def test_runnable_seq_streaming_chunks() -> None:
    """Streaming a chain with a context getter must emit the context value
    as its own chunk alongside the token-by-token LLM output, with the sync
    and async streaming APIs producing the same chunks."""
    chain: Runnable = (
        PromptTemplate.from_template("{foo} {bar}")
        | Context.setter("prompt")
        | FakeStreamingListLLM(responses=["hello"])
        | StrOutputParser()
        | {
            "response": RunnablePassthrough(),
            "prompt": Context.getter("prompt"),
        }
    )

    payload = {"foo": "foo", "bar": "bar"}
    chunks = list(chain.stream(payload))
    achunks = [chunk async for chunk in chain.astream(payload)]

    # Every sync chunk appears in the async stream and vice versa.
    for chunk in chunks:
        assert chunk in achunks
    for chunk in achunks:
        assert chunk in chunks

    # Five response chunks ("h", "e", "l", "l", "o") plus one prompt chunk.
    assert len(chunks) == 6
    assert [chunk for chunk in chunks if chunk.get("response")] == [
        {"response": "h"},
        {"response": "e"},
        {"response": "l"},
        {"response": "l"},
        {"response": "o"},
    ]
    assert [chunk for chunk in chunks if chunk.get("prompt")] == [
        {"prompt": StringPromptValue(text="foo bar")},
    ]
|
||||
@@ -12,6 +12,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.runnables.base import RunnableSequence
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
@@ -145,6 +146,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
|
||||
@@ -306,7 +306,6 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"definitions": {
|
||||
"Document": {
|
||||
"title": "Document",
|
||||
"description": AnyStr(),
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"page_content": {"title": "Page Content", "type": "string"},
|
||||
|
||||
@@ -558,15 +558,6 @@ def test_missing_docstring() -> None:
|
||||
return "API result"
|
||||
|
||||
|
||||
def test_create_tool_positional_args() -> None:
|
||||
"""Test that positional arguments are allowed."""
|
||||
test_tool = Tool("test_name", lambda x: x, "test_description")
|
||||
assert test_tool.invoke("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.is_single_input
|
||||
|
||||
|
||||
def test_create_tool_keyword_args() -> None:
|
||||
"""Test that keyword arguments are allowed."""
|
||||
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
||||
|
||||
Reference in New Issue
Block a user