mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
BUG: more core fixes (#13665)
Fix some circular deps: - move PromptValue into top level module bc both PromptTemplates and OutputParsers import - move tracer context vars to `tracers.context` and import them in functions in `callbacks.manager` - add core import tests
This commit is contained in:
parent
59df16ab92
commit
c61e30632e
40
.github/workflows/langchain_ci.yml
vendored
40
.github/workflows/langchain_ci.yml
vendored
@ -63,6 +63,46 @@ jobs:
|
|||||||
langchain-core-location: ../core
|
langchain-core-location: ../core
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
# It's possible that langchain works fine with the latest *published* langchain-core,
|
||||||
|
# but is broken with the langchain-core on `master`.
|
||||||
|
#
|
||||||
|
# We want to catch situations like that *before* releasing a new langchain-core, hence this test.
|
||||||
|
test-with-latest-langchain-core:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
working-directory: ${{ env.WORKDIR }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- "3.8"
|
||||||
|
- "3.9"
|
||||||
|
- "3.10"
|
||||||
|
- "3.11"
|
||||||
|
name: test with unpublished langchain-core - Python ${{ matrix.python-version }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||||
|
uses: "./.github/actions/poetry_setup"
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
poetry-version: ${{ env.POETRY_VERSION }}
|
||||||
|
working-directory: ${{ env.WORKDIR }}
|
||||||
|
cache-key: unpublished-langchain-core
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "Running tests with unpublished langchain, installing dependencies with poetry..."
|
||||||
|
poetry install
|
||||||
|
|
||||||
|
echo "Editably installing langchain-core outside of poetry, to avoid messing up lockfile..."
|
||||||
|
poetry run pip install -e ../core
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: make test
|
||||||
|
|
||||||
extended-tests:
|
extended-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
defaults:
|
defaults:
|
||||||
|
@ -28,8 +28,6 @@ from langchain_core.callbacks.manager import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
ParentRunManager,
|
ParentRunManager,
|
||||||
RunManager,
|
RunManager,
|
||||||
env_var_is_set,
|
|
||||||
register_configure_hook,
|
|
||||||
)
|
)
|
||||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
@ -64,6 +62,4 @@ __all__ = [
|
|||||||
"AsyncCallbackManagerForChainGroup",
|
"AsyncCallbackManagerForChainGroup",
|
||||||
"StdOutCallbackHandler",
|
"StdOutCallbackHandler",
|
||||||
"StreamingStdOutCallbackHandler",
|
"StreamingStdOutCallbackHandler",
|
||||||
"env_var_is_set",
|
|
||||||
"register_configure_hook",
|
|
||||||
]
|
]
|
||||||
|
@ -3,13 +3,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import uuid
|
import uuid
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
@ -18,7 +15,6 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -26,14 +22,10 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langsmith import utils as ls_utils
|
|
||||||
from langsmith.run_helpers import get_run_tree_context
|
from langsmith.run_helpers import get_run_tree_context
|
||||||
from tenacity import RetryCallState
|
from tenacity import RetryCallState
|
||||||
|
|
||||||
from langchain_core.agents import (
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
AgentAction,
|
|
||||||
AgentFinish,
|
|
||||||
)
|
|
||||||
from langchain_core.callbacks.base import (
|
from langchain_core.callbacks.base import (
|
||||||
BaseCallbackHandler,
|
BaseCallbackHandler,
|
||||||
BaseCallbackManager,
|
BaseCallbackManager,
|
||||||
@ -48,32 +40,10 @@ from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||||
from langchain_core.tracers import run_collector
|
from langchain_core.utils.env import env_var_is_set
|
||||||
from langchain_core.tracers.langchain import (
|
|
||||||
LangChainTracer,
|
|
||||||
)
|
|
||||||
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
|
||||||
from langchain_core.tracers.schemas import TracerSessionV1
|
|
||||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from langsmith import Client as LangSmithClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
|
|
||||||
"tracing_callback", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
|
||||||
"tracing_callback_v2", default=None
|
|
||||||
)
|
|
||||||
run_collector_var: ContextVar[
|
|
||||||
Optional[run_collector.RunCollectorCallbackHandler]
|
|
||||||
] = ContextVar( # noqa: E501
|
|
||||||
"run_collector", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_debug() -> bool:
|
def _get_debug() -> bool:
|
||||||
from langchain_core.globals import get_debug
|
from langchain_core.globals import get_debug
|
||||||
@ -81,123 +51,6 @@ def _get_debug() -> bool:
|
|||||||
return get_debug()
|
return get_debug()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def tracing_enabled(
|
|
||||||
session_name: str = "default",
|
|
||||||
) -> Generator[TracerSessionV1, None, None]:
|
|
||||||
"""Get the Deprecated LangChainTracer in a context manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_name (str, optional): The name of the session.
|
|
||||||
Defaults to "default".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TracerSessionV1: The LangChainTracer session.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> with tracing_enabled() as session:
|
|
||||||
... # Use the LangChainTracer session
|
|
||||||
"""
|
|
||||||
cb = LangChainTracerV1()
|
|
||||||
session = cast(TracerSessionV1, cb.load_session(session_name))
|
|
||||||
try:
|
|
||||||
tracing_callback_var.set(cb)
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
tracing_callback_var.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def tracing_v2_enabled(
|
|
||||||
project_name: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
|
||||||
tags: Optional[List[str]] = None,
|
|
||||||
client: Optional[LangSmithClient] = None,
|
|
||||||
) -> Generator[LangChainTracer, None, None]:
|
|
||||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
project_name (str, optional): The name of the project.
|
|
||||||
Defaults to "default".
|
|
||||||
example_id (str or UUID, optional): The ID of the example.
|
|
||||||
Defaults to None.
|
|
||||||
tags (List[str], optional): The tags to add to the run.
|
|
||||||
Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> with tracing_v2_enabled():
|
|
||||||
... # LangChain code will automatically be traced
|
|
||||||
|
|
||||||
You can use this to fetch the LangSmith run URL:
|
|
||||||
|
|
||||||
>>> with tracing_v2_enabled() as cb:
|
|
||||||
... chain.invoke("foo")
|
|
||||||
... run_url = cb.get_run_url()
|
|
||||||
"""
|
|
||||||
if isinstance(example_id, str):
|
|
||||||
example_id = UUID(example_id)
|
|
||||||
cb = LangChainTracer(
|
|
||||||
example_id=example_id,
|
|
||||||
project_name=project_name,
|
|
||||||
tags=tags,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
tracing_v2_callback_var.set(cb)
|
|
||||||
yield cb
|
|
||||||
finally:
|
|
||||||
tracing_v2_callback_var.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
|
|
||||||
"""Collect all run traces in context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> with collect_runs() as runs_cb:
|
|
||||||
chain.invoke("foo")
|
|
||||||
run_id = runs_cb.traced_runs[0].id
|
|
||||||
"""
|
|
||||||
cb = run_collector.RunCollectorCallbackHandler()
|
|
||||||
run_collector_var.set(cb)
|
|
||||||
yield cb
|
|
||||||
run_collector_var.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_trace_callbacks(
|
|
||||||
project_name: Optional[str] = None,
|
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
|
||||||
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
|
|
||||||
) -> Callbacks:
|
|
||||||
if _tracing_v2_is_enabled():
|
|
||||||
project_name_ = project_name or _get_tracer_project()
|
|
||||||
tracer = tracing_v2_callback_var.get() or LangChainTracer(
|
|
||||||
project_name=project_name_,
|
|
||||||
example_id=example_id,
|
|
||||||
)
|
|
||||||
if callback_manager is None:
|
|
||||||
cb = cast(Callbacks, [tracer])
|
|
||||||
else:
|
|
||||||
if not any(
|
|
||||||
isinstance(handler, LangChainTracer)
|
|
||||||
for handler in callback_manager.handlers
|
|
||||||
):
|
|
||||||
callback_manager.add_handler(tracer, True)
|
|
||||||
# If it already has a LangChainTracer, we don't need to add another one.
|
|
||||||
# this would likely mess up the trace hierarchy.
|
|
||||||
cb = callback_manager
|
|
||||||
else:
|
|
||||||
cb = None
|
|
||||||
return cb
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def trace_as_chain_group(
|
def trace_as_chain_group(
|
||||||
group_name: str,
|
group_name: str,
|
||||||
@ -239,6 +92,8 @@ def trace_as_chain_group(
|
|||||||
res = llm.predict(llm_input, callbacks=manager)
|
res = llm.predict(llm_input, callbacks=manager)
|
||||||
manager.on_chain_end({"output": res})
|
manager.on_chain_end({"output": res})
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
from langchain_core.tracers.context import _get_trace_callbacks
|
||||||
|
|
||||||
cb = _get_trace_callbacks(
|
cb = _get_trace_callbacks(
|
||||||
project_name, example_id, callback_manager=callback_manager
|
project_name, example_id, callback_manager=callback_manager
|
||||||
)
|
)
|
||||||
@ -310,6 +165,8 @@ async def atrace_as_chain_group(
|
|||||||
res = await llm.apredict(llm_input, callbacks=manager)
|
res = await llm.apredict(llm_input, callbacks=manager)
|
||||||
await manager.on_chain_end({"output": res})
|
await manager.on_chain_end({"output": res})
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
from langchain_core.tracers.context import _get_trace_callbacks
|
||||||
|
|
||||||
cb = _get_trace_callbacks(
|
cb = _get_trace_callbacks(
|
||||||
project_name, example_id, callback_manager=callback_manager
|
project_name, example_id, callback_manager=callback_manager
|
||||||
)
|
)
|
||||||
@ -1850,88 +1707,9 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
|||||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||||
|
|
||||||
|
|
||||||
def env_var_is_set(env_var: str) -> bool:
|
|
||||||
"""Check if an environment variable is set.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env_var (str): The name of the environment variable.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the environment variable is set, False otherwise.
|
|
||||||
"""
|
|
||||||
return env_var in os.environ and os.environ[env_var] not in (
|
|
||||||
"",
|
|
||||||
"0",
|
|
||||||
"false",
|
|
||||||
"False",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _tracing_v2_is_enabled() -> bool:
|
|
||||||
return (
|
|
||||||
env_var_is_set("LANGCHAIN_TRACING_V2")
|
|
||||||
or tracing_v2_callback_var.get() is not None
|
|
||||||
or get_run_tree_context() is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tracer_project() -> str:
|
|
||||||
run_tree = get_run_tree_context()
|
|
||||||
return getattr(
|
|
||||||
run_tree,
|
|
||||||
"session_name",
|
|
||||||
getattr(
|
|
||||||
# Note, if people are trying to nest @traceable functions and the
|
|
||||||
# tracing_v2_enabled context manager, this will likely mess up the
|
|
||||||
# tree structure.
|
|
||||||
tracing_v2_callback_var.get(),
|
|
||||||
"project",
|
|
||||||
# Have to set this to a string even though it always will return
|
|
||||||
# a string because `get_tracer_project` technically can return
|
|
||||||
# None, but only when a specific argument is supplied.
|
|
||||||
# Therefore, this just tricks the mypy type checker
|
|
||||||
str(ls_utils.get_tracer_project()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_configure_hooks: List[
|
|
||||||
Tuple[
|
|
||||||
ContextVar[Optional[BaseCallbackHandler]],
|
|
||||||
bool,
|
|
||||||
Optional[Type[BaseCallbackHandler]],
|
|
||||||
Optional[str],
|
|
||||||
]
|
|
||||||
] = []
|
|
||||||
|
|
||||||
H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
|
H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
|
||||||
|
|
||||||
|
|
||||||
def register_configure_hook(
|
|
||||||
context_var: ContextVar[Optional[Any]],
|
|
||||||
inheritable: bool,
|
|
||||||
handle_class: Optional[Type[BaseCallbackHandler]] = None,
|
|
||||||
env_var: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
if env_var is not None and handle_class is None:
|
|
||||||
raise ValueError(
|
|
||||||
"If env_var is set, handle_class must also be set to a non-None value."
|
|
||||||
)
|
|
||||||
_configure_hooks.append(
|
|
||||||
(
|
|
||||||
# the typings of ContextVar do not have the generic arg set as covariant
|
|
||||||
# so we have to cast it
|
|
||||||
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
|
|
||||||
inheritable,
|
|
||||||
handle_class,
|
|
||||||
env_var,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_configure_hook(run_collector_var, False)
|
|
||||||
|
|
||||||
|
|
||||||
def _configure(
|
def _configure(
|
||||||
callback_manager_cls: Type[T],
|
callback_manager_cls: Type[T],
|
||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
@ -1962,6 +1740,14 @@ def _configure(
|
|||||||
Returns:
|
Returns:
|
||||||
T: The configured callback manager.
|
T: The configured callback manager.
|
||||||
"""
|
"""
|
||||||
|
from langchain_core.tracers.context import (
|
||||||
|
_configure_hooks,
|
||||||
|
_get_tracer_project,
|
||||||
|
_tracing_v2_is_enabled,
|
||||||
|
tracing_callback_var,
|
||||||
|
tracing_v2_callback_var,
|
||||||
|
)
|
||||||
|
|
||||||
run_tree = get_run_tree_context()
|
run_tree = get_run_tree_context()
|
||||||
parent_run_id = None if run_tree is None else getattr(run_tree, "id")
|
parent_run_id = None if run_tree is None else getattr(run_tree, "id")
|
||||||
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
|
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
|
||||||
@ -2009,6 +1795,10 @@ def _configure(
|
|||||||
tracer_project = _get_tracer_project()
|
tracer_project = _get_tracer_project()
|
||||||
debug = _get_debug()
|
debug = _get_debug()
|
||||||
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
|
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
|
||||||
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
|
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
||||||
|
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
|
||||||
if verbose and not any(
|
if verbose and not any(
|
||||||
isinstance(handler, StdOutCallbackHandler)
|
isinstance(handler, StdOutCallbackHandler)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
|
@ -19,7 +19,7 @@ _llm_cache: Optional["BaseCache"] = None
|
|||||||
def set_verbose(value: bool) -> None:
|
def set_verbose(value: bool) -> None:
|
||||||
"""Set a new value for the `verbose` global setting."""
|
"""Set a new value for the `verbose` global setting."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
@ -47,7 +47,7 @@ def set_verbose(value: bool) -> None:
|
|||||||
def get_verbose() -> bool:
|
def get_verbose() -> bool:
|
||||||
"""Get the value of the `verbose` global setting."""
|
"""Get the value of the `verbose` global setting."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
@ -80,7 +80,7 @@ def get_verbose() -> bool:
|
|||||||
def set_debug(value: bool) -> None:
|
def set_debug(value: bool) -> None:
|
||||||
"""Set a new value for the `debug` global setting."""
|
"""Set a new value for the `debug` global setting."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
@ -106,7 +106,7 @@ def set_debug(value: bool) -> None:
|
|||||||
def get_debug() -> bool:
|
def get_debug() -> bool:
|
||||||
"""Get the value of the `debug` global setting."""
|
"""Get the value of the `debug` global setting."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
@ -137,7 +137,7 @@ def get_debug() -> bool:
|
|||||||
def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
||||||
"""Set a new LLM cache, overwriting the previous value, if any."""
|
"""Set a new LLM cache, overwriting the previous value, if any."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
@ -165,7 +165,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
|||||||
def get_llm_cache() -> "BaseCache":
|
def get_llm_cache() -> "BaseCache":
|
||||||
"""Get the value of the `llm_cache` global setting."""
|
"""Get the value of the `llm_cache` global setting."""
|
||||||
try:
|
try:
|
||||||
import langchain
|
import langchain # type: ignore[import]
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
|
from langchain_core.language_models.base import (
|
||||||
|
BaseLanguageModel,
|
||||||
|
LanguageModelInput,
|
||||||
|
get_tokenizer,
|
||||||
|
)
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||||
from langchain_core.language_models.llms import LLM, BaseLLM
|
from langchain_core.language_models.llms import LLM, BaseLLM
|
||||||
|
|
||||||
@ -9,4 +13,5 @@ __all__ = [
|
|||||||
"BaseLLM",
|
"BaseLLM",
|
||||||
"LLM",
|
"LLM",
|
||||||
"LanguageModelInput",
|
"LanguageModelInput",
|
||||||
|
"get_tokenizer",
|
||||||
]
|
]
|
||||||
|
@ -17,7 +17,7 @@ from typing_extensions import TypeAlias
|
|||||||
|
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult
|
||||||
from langchain_core.prompts import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.runnables import RunnableSerializable
|
from langchain_core.runnables import RunnableSerializable
|
||||||
from langchain_core.utils import get_pydantic_field_names
|
from langchain_core.utils import get_pydantic_field_names
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
@lru_cache(maxsize=None) # Cache the tokenizer
|
@lru_cache(maxsize=None) # Cache the tokenizer
|
||||||
def get_tokenizer() -> Any:
|
def get_tokenizer() -> Any:
|
||||||
try:
|
try:
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast # type: ignore[import]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import transformers python package. "
|
"Could not import transformers python package. "
|
||||||
@ -74,8 +74,10 @@ class BaseLanguageModel(
|
|||||||
@property
|
@property
|
||||||
def InputType(self) -> TypeAlias:
|
def InputType(self) -> TypeAlias:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
from langchain_core.prompts.chat import ChatPromptValueConcrete
|
from langchain_core.prompt_values import (
|
||||||
from langchain_core.prompts.string import StringPromptValue
|
ChatPromptValueConcrete,
|
||||||
|
StringPromptValue,
|
||||||
|
)
|
||||||
|
|
||||||
# This is a version of LanguageModelInput which replaces the abstract
|
# This is a version of LanguageModelInput which replaces the abstract
|
||||||
# base class BaseMessage with a union of its subclasses, which makes
|
# base class BaseMessage with a union of its subclasses, which makes
|
||||||
|
@ -39,7 +39,7 @@ from langchain_core.outputs import (
|
|||||||
LLMResult,
|
LLMResult,
|
||||||
RunInfo,
|
RunInfo,
|
||||||
)
|
)
|
||||||
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ from langchain_core.language_models.base import BaseLanguageModel, LanguageModel
|
|||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
||||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||||
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||||
from langchain_core.runnables import RunnableConfig, get_config_list
|
from langchain_core.runnables import RunnableConfig, get_config_list
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from langchain_core.output_parsers.list import (
|
|||||||
MarkdownListOutputParser,
|
MarkdownListOutputParser,
|
||||||
NumberedListOutputParser,
|
NumberedListOutputParser,
|
||||||
)
|
)
|
||||||
from langchain_core.output_parsers.str import StrOutputParser
|
from langchain_core.output_parsers.string import StrOutputParser
|
||||||
from langchain_core.output_parsers.transform import (
|
from langchain_core.output_parsers.transform import (
|
||||||
BaseCumulativeTransformOutputParser,
|
BaseCumulativeTransformOutputParser,
|
||||||
BaseTransformOutputParser,
|
BaseTransformOutputParser,
|
||||||
|
@ -17,11 +17,8 @@ from typing import (
|
|||||||
from typing_extensions import get_args
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage
|
from langchain_core.messages import AnyMessage, BaseMessage
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
ChatGeneration,
|
from langchain_core.prompt_values import PromptValue
|
||||||
Generation,
|
|
||||||
)
|
|
||||||
from langchain_core.prompts.value import PromptValue
|
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
76
libs/core/langchain_core/prompt_values.py
Normal file
76
libs/core/langchain_core/prompt_values.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from langchain_core.load.serializable import Serializable
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AnyMessage,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
get_buffer_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptValue(Serializable, ABC):
|
||||||
|
"""Base abstract class for inputs to any language model.
|
||||||
|
|
||||||
|
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||||
|
ChatModel inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this class is serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Return prompt value as string."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Return prompt as a list of Messages."""
|
||||||
|
|
||||||
|
|
||||||
|
class StringPromptValue(PromptValue):
|
||||||
|
"""String prompt value."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
"""Prompt text."""
|
||||||
|
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Return prompt as string."""
|
||||||
|
return self.text
|
||||||
|
|
||||||
|
def to_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Return prompt as messages."""
|
||||||
|
return [HumanMessage(content=self.text)]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptValue(PromptValue):
|
||||||
|
"""Chat prompt value.
|
||||||
|
|
||||||
|
A type of a prompt value that is built from messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages: Sequence[BaseMessage]
|
||||||
|
"""List of messages."""
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Return prompt as string."""
|
||||||
|
return get_buffer_string(self.messages)
|
||||||
|
|
||||||
|
def to_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Return prompt as a list of messages."""
|
||||||
|
return list(self.messages)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptValueConcrete(ChatPromptValue):
|
||||||
|
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||||
|
For use in external schemas."""
|
||||||
|
|
||||||
|
messages: Sequence[AnyMessage]
|
||||||
|
|
||||||
|
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
@ -23,9 +23,6 @@ from multiple components. Prompt classes and functions make constructing
|
|||||||
AIMessagePromptTemplate
|
AIMessagePromptTemplate
|
||||||
SystemMessagePromptTemplate
|
SystemMessagePromptTemplate
|
||||||
|
|
||||||
PromptValue --> StringPromptValue
|
|
||||||
ChatPromptValue
|
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
@ -33,8 +30,6 @@ from langchain_core.prompts.chat import (
|
|||||||
BaseChatPromptTemplate,
|
BaseChatPromptTemplate,
|
||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
|
||||||
ChatPromptValueConcrete,
|
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
@ -47,7 +42,13 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemp
|
|||||||
from langchain_core.prompts.loading import load_prompt
|
from langchain_core.prompts.loading import load_prompt
|
||||||
from langchain_core.prompts.pipeline import PipelinePromptTemplate
|
from langchain_core.prompts.pipeline import PipelinePromptTemplate
|
||||||
from langchain_core.prompts.prompt import Prompt, PromptTemplate
|
from langchain_core.prompts.prompt import Prompt, PromptTemplate
|
||||||
from langchain_core.prompts.string import StringPromptTemplate, StringPromptValue
|
from langchain_core.prompts.string import (
|
||||||
|
StringPromptTemplate,
|
||||||
|
check_valid_template,
|
||||||
|
get_template_variables,
|
||||||
|
jinja2_formatter,
|
||||||
|
validate_jinja2,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIMessagePromptTemplate",
|
"AIMessagePromptTemplate",
|
||||||
@ -55,8 +56,6 @@ __all__ = [
|
|||||||
"BasePromptTemplate",
|
"BasePromptTemplate",
|
||||||
"ChatMessagePromptTemplate",
|
"ChatMessagePromptTemplate",
|
||||||
"ChatPromptTemplate",
|
"ChatPromptTemplate",
|
||||||
"ChatPromptValue",
|
|
||||||
"ChatPromptValueConcrete",
|
|
||||||
"FewShotPromptTemplate",
|
"FewShotPromptTemplate",
|
||||||
"FewShotPromptWithTemplates",
|
"FewShotPromptWithTemplates",
|
||||||
"FewShotChatMessagePromptTemplate",
|
"FewShotChatMessagePromptTemplate",
|
||||||
@ -65,12 +64,12 @@ __all__ = [
|
|||||||
"PipelinePromptTemplate",
|
"PipelinePromptTemplate",
|
||||||
"Prompt",
|
"Prompt",
|
||||||
"PromptTemplate",
|
"PromptTemplate",
|
||||||
"PromptValue",
|
|
||||||
"StringPromptValue",
|
|
||||||
"StringPromptTemplate",
|
"StringPromptTemplate",
|
||||||
"SystemMessagePromptTemplate",
|
"SystemMessagePromptTemplate",
|
||||||
"load_prompt",
|
"load_prompt",
|
||||||
"format_document",
|
"format_document",
|
||||||
|
"check_valid_template",
|
||||||
|
"get_template_variables",
|
||||||
|
"jinja2_formatter",
|
||||||
|
"validate_jinja2",
|
||||||
]
|
]
|
||||||
|
|
||||||
from langchain_core.prompts.value import PromptValue
|
|
||||||
|
@ -8,8 +8,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers.base import BaseOutputParser
|
||||||
from langchain_core.prompts.value import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
@ -40,8 +40,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
from langchain_core.prompts.chat import ChatPromptValueConcrete
|
from langchain_core.prompt_values import (
|
||||||
from langchain_core.prompts.string import StringPromptValue
|
ChatPromptValueConcrete,
|
||||||
|
StringPromptValue,
|
||||||
|
)
|
||||||
|
|
||||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Literal,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -27,12 +26,11 @@ from langchain_core.messages import (
|
|||||||
ChatMessage,
|
ChatMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
get_buffer_string,
|
|
||||||
)
|
)
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import StringPromptTemplate
|
from langchain_core.prompts.string import StringPromptTemplate
|
||||||
from langchain_core.prompts.value import PromptValue
|
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
|
|
||||||
|
|
||||||
@ -277,33 +275,6 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptValue(PromptValue):
|
|
||||||
"""Chat prompt value.
|
|
||||||
|
|
||||||
A type of a prompt value that is built from messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
messages: Sequence[BaseMessage]
|
|
||||||
"""List of messages."""
|
|
||||||
|
|
||||||
def to_string(self) -> str:
|
|
||||||
"""Return prompt as string."""
|
|
||||||
return get_buffer_string(self.messages)
|
|
||||||
|
|
||||||
def to_messages(self) -> List[BaseMessage]:
|
|
||||||
"""Return prompt as a list of messages."""
|
|
||||||
return list(self.messages)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptValueConcrete(ChatPromptValue):
|
|
||||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
|
||||||
For use in external schemas."""
|
|
||||||
|
|
||||||
messages: Sequence[AnyMessage]
|
|
||||||
|
|
||||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||||
"""Base class for chat prompt templates."""
|
"""Base class for chat prompt templates."""
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Callable, Dict, Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers.string import StrOutputParser
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||||
from langchain_core.prompts.value import PromptValue
|
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.pydantic_v1 import root_validator
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,11 +4,10 @@ from __future__ import annotations
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Callable, Dict, List, Literal, Set
|
from typing import Any, Callable, Dict, List, Set
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage
|
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
from langchain_core.prompts.value import PromptValue
|
|
||||||
from langchain_core.utils.formatting import formatter
|
from langchain_core.utils.formatting import formatter
|
||||||
|
|
||||||
|
|
||||||
@ -149,22 +148,6 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
|
|||||||
return sorted(input_variables)
|
return sorted(input_variables)
|
||||||
|
|
||||||
|
|
||||||
class StringPromptValue(PromptValue):
|
|
||||||
"""String prompt value."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""Prompt text."""
|
|
||||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
|
||||||
|
|
||||||
def to_string(self) -> str:
|
|
||||||
"""Return prompt as string."""
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
def to_messages(self) -> List[BaseMessage]:
|
|
||||||
"""Return prompt as messages."""
|
|
||||||
return [HumanMessage(content=self.text)]
|
|
||||||
|
|
||||||
|
|
||||||
class StringPromptTemplate(BasePromptTemplate, ABC):
|
class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||||
"""String prompt that exposes the format method, returning a prompt."""
|
"""String prompt that exposes the format method, returning a prompt."""
|
||||||
|
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable
|
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
|
|
||||||
class PromptValue(Serializable, ABC):
|
|
||||||
"""Base abstract class for inputs to any language model.
|
|
||||||
|
|
||||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
|
||||||
ChatModel inputs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
"""Return whether this class is serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_string(self) -> str:
|
|
||||||
"""Return prompt value as string."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_messages(self) -> List[BaseMessage]:
|
|
||||||
"""Return prompt as a list of Messages."""
|
|
@ -9,7 +9,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from tenacity import RetryCallState
|
from tenacity import RetryCallState
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.exceptions import TracerException
|
from langchain_core.exceptions import TracerException
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
|
226
libs/core/langchain_core/tracers/context.py
Normal file
226
libs/core/langchain_core/tracers/context.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langsmith import utils as ls_utils
|
||||||
|
from langsmith.run_helpers import get_run_tree_context
|
||||||
|
|
||||||
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
|
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
||||||
|
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||||
|
from langchain_core.tracers.schemas import TracerSessionV1
|
||||||
|
from langchain_core.utils.env import env_var_is_set
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langsmith import Client as LangSmithClient
|
||||||
|
|
||||||
|
from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
|
|
||||||
|
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
|
||||||
|
"tracing_callback", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
||||||
|
"tracing_callback_v2", default=None
|
||||||
|
)
|
||||||
|
run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVar( # noqa: E501
|
||||||
|
"run_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def tracing_enabled(
|
||||||
|
session_name: str = "default",
|
||||||
|
) -> Generator[TracerSessionV1, None, None]:
|
||||||
|
"""Get the Deprecated LangChainTracer in a context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_name (str, optional): The name of the session.
|
||||||
|
Defaults to "default".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TracerSessionV1: The LangChainTracer session.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with tracing_enabled() as session:
|
||||||
|
... # Use the LangChainTracer session
|
||||||
|
"""
|
||||||
|
cb = LangChainTracerV1()
|
||||||
|
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||||
|
try:
|
||||||
|
tracing_callback_var.set(cb)
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
tracing_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def tracing_v2_enabled(
|
||||||
|
project_name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
client: Optional[LangSmithClient] = None,
|
||||||
|
) -> Generator[LangChainTracer, None, None]:
|
||||||
|
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_name (str, optional): The name of the project.
|
||||||
|
Defaults to "default".
|
||||||
|
example_id (str or UUID, optional): The ID of the example.
|
||||||
|
Defaults to None.
|
||||||
|
tags (List[str], optional): The tags to add to the run.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with tracing_v2_enabled():
|
||||||
|
... # LangChain code will automatically be traced
|
||||||
|
|
||||||
|
You can use this to fetch the LangSmith run URL:
|
||||||
|
|
||||||
|
>>> with tracing_v2_enabled() as cb:
|
||||||
|
... chain.invoke("foo")
|
||||||
|
... run_url = cb.get_run_url()
|
||||||
|
"""
|
||||||
|
if isinstance(example_id, str):
|
||||||
|
example_id = UUID(example_id)
|
||||||
|
cb = LangChainTracer(
|
||||||
|
example_id=example_id,
|
||||||
|
project_name=project_name,
|
||||||
|
tags=tags,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
tracing_v2_callback_var.set(cb)
|
||||||
|
yield cb
|
||||||
|
finally:
|
||||||
|
tracing_v2_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
|
||||||
|
"""Collect all run traces in context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with collect_runs() as runs_cb:
|
||||||
|
chain.invoke("foo")
|
||||||
|
run_id = runs_cb.traced_runs[0].id
|
||||||
|
"""
|
||||||
|
cb = RunCollectorCallbackHandler()
|
||||||
|
run_collector_var.set(cb)
|
||||||
|
yield cb
|
||||||
|
run_collector_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_trace_callbacks(
|
||||||
|
project_name: Optional[str] = None,
|
||||||
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
|
||||||
|
) -> Callbacks:
|
||||||
|
if _tracing_v2_is_enabled():
|
||||||
|
project_name_ = project_name or _get_tracer_project()
|
||||||
|
tracer = tracing_v2_callback_var.get() or LangChainTracer(
|
||||||
|
project_name=project_name_,
|
||||||
|
example_id=example_id,
|
||||||
|
)
|
||||||
|
if callback_manager is None:
|
||||||
|
from langchain_core.callbacks.base import Callbacks
|
||||||
|
|
||||||
|
cb = cast(Callbacks, [tracer])
|
||||||
|
else:
|
||||||
|
if not any(
|
||||||
|
isinstance(handler, LangChainTracer)
|
||||||
|
for handler in callback_manager.handlers
|
||||||
|
):
|
||||||
|
callback_manager.add_handler(tracer, True)
|
||||||
|
# If it already has a LangChainTracer, we don't need to add another one.
|
||||||
|
# this would likely mess up the trace hierarchy.
|
||||||
|
cb = callback_manager
|
||||||
|
else:
|
||||||
|
cb = None
|
||||||
|
return cb
|
||||||
|
|
||||||
|
|
||||||
|
def _tracing_v2_is_enabled() -> bool:
|
||||||
|
return (
|
||||||
|
env_var_is_set("LANGCHAIN_TRACING_V2")
|
||||||
|
or tracing_v2_callback_var.get() is not None
|
||||||
|
or get_run_tree_context() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tracer_project() -> str:
|
||||||
|
run_tree = get_run_tree_context()
|
||||||
|
return getattr(
|
||||||
|
run_tree,
|
||||||
|
"session_name",
|
||||||
|
getattr(
|
||||||
|
# Note, if people are trying to nest @traceable functions and the
|
||||||
|
# tracing_v2_enabled context manager, this will likely mess up the
|
||||||
|
# tree structure.
|
||||||
|
tracing_v2_callback_var.get(),
|
||||||
|
"project",
|
||||||
|
# Have to set this to a string even though it always will return
|
||||||
|
# a string because `get_tracer_project` technically can return
|
||||||
|
# None, but only when a specific argument is supplied.
|
||||||
|
# Therefore, this just tricks the mypy type checker
|
||||||
|
str(ls_utils.get_tracer_project()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_configure_hooks: List[
|
||||||
|
Tuple[
|
||||||
|
ContextVar[Optional[BaseCallbackHandler]],
|
||||||
|
bool,
|
||||||
|
Optional[Type[BaseCallbackHandler]],
|
||||||
|
Optional[str],
|
||||||
|
]
|
||||||
|
] = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_configure_hook(
|
||||||
|
context_var: ContextVar[Optional[Any]],
|
||||||
|
inheritable: bool,
|
||||||
|
handle_class: Optional[Type[BaseCallbackHandler]] = None,
|
||||||
|
env_var: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
if env_var is not None and handle_class is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If env_var is set, handle_class must also be set to a non-None value."
|
||||||
|
)
|
||||||
|
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||||
|
|
||||||
|
_configure_hooks.append(
|
||||||
|
(
|
||||||
|
# the typings of ContextVar do not have the generic arg set as covariant
|
||||||
|
# so we have to cast it
|
||||||
|
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
|
||||||
|
inheritable,
|
||||||
|
handle_class,
|
||||||
|
env_var,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_configure_hook(run_collector_var, False)
|
@ -11,9 +11,9 @@ from uuid import UUID
|
|||||||
import langsmith
|
import langsmith
|
||||||
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
|
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
|
||||||
|
|
||||||
from langchain_core.callbacks import manager
|
|
||||||
from langchain_core.tracers import langchain as langchain_tracer
|
from langchain_core.tracers import langchain as langchain_tracer
|
||||||
from langchain_core.tracers.base import BaseTracer
|
from langchain_core.tracers.base import BaseTracer
|
||||||
|
from langchain_core.tracers.context import tracing_v2_enabled
|
||||||
from langchain_core.tracers.langchain import _get_executor
|
from langchain_core.tracers.langchain import _get_executor
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
if self.project_name is None:
|
if self.project_name is None:
|
||||||
eval_result = self.client.evaluate_run(run, evaluator)
|
eval_result = self.client.evaluate_run(run, evaluator)
|
||||||
eval_results = [eval_result]
|
eval_results = [eval_result]
|
||||||
with manager.tracing_v2_enabled(
|
with tracing_v2_enabled(
|
||||||
project_name=self.project_name, tags=["eval"], client=self.client
|
project_name=self.project_name, tags=["eval"], client=self.client
|
||||||
) as cb:
|
) as cb:
|
||||||
reference_example = (
|
reference_example = (
|
||||||
|
@ -15,7 +15,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import jsonpatch
|
import jsonpatch # type: ignore[import]
|
||||||
from anyio import create_memory_object_stream
|
from anyio import create_memory_object_stream
|
||||||
|
|
||||||
from langchain_core.load import load
|
from langchain_core.load import load
|
||||||
|
20
libs/core/langchain_core/utils/env.py
Normal file
20
libs/core/langchain_core/utils/env.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def env_var_is_set(env_var: str) -> bool:
|
||||||
|
"""Check if an environment variable is set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_var (str): The name of the environment variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the environment variable is set, False otherwise.
|
||||||
|
"""
|
||||||
|
return env_var in os.environ and os.environ[env_var] not in (
|
||||||
|
"",
|
||||||
|
"0",
|
||||||
|
"false",
|
||||||
|
"False",
|
||||||
|
)
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.0.2"
|
version = "0.0.3"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
authors = []
|
authors = []
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@ -51,7 +51,6 @@ select = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
ignore_missing_imports = "True"
|
|
||||||
disallow_untyped_defs = "True"
|
disallow_untyped_defs = "True"
|
||||||
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
||||||
|
|
||||||
|
0
libs/core/tests/unit_tests/callbacks/__init__.py
Normal file
0
libs/core/tests/unit_tests/callbacks/__init__.py
Normal file
37
libs/core/tests/unit_tests/callbacks/test_imports.py
Normal file
37
libs/core/tests/unit_tests/callbacks/test_imports.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from langchain_core.callbacks import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"RetrieverManagerMixin",
|
||||||
|
"LLMManagerMixin",
|
||||||
|
"ChainManagerMixin",
|
||||||
|
"ToolManagerMixin",
|
||||||
|
"Callbacks",
|
||||||
|
"CallbackManagerMixin",
|
||||||
|
"RunManagerMixin",
|
||||||
|
"BaseCallbackHandler",
|
||||||
|
"AsyncCallbackHandler",
|
||||||
|
"BaseCallbackManager",
|
||||||
|
"BaseRunManager",
|
||||||
|
"RunManager",
|
||||||
|
"ParentRunManager",
|
||||||
|
"AsyncRunManager",
|
||||||
|
"AsyncParentRunManager",
|
||||||
|
"CallbackManagerForLLMRun",
|
||||||
|
"AsyncCallbackManagerForLLMRun",
|
||||||
|
"CallbackManagerForChainRun",
|
||||||
|
"AsyncCallbackManagerForChainRun",
|
||||||
|
"CallbackManagerForToolRun",
|
||||||
|
"AsyncCallbackManagerForToolRun",
|
||||||
|
"CallbackManagerForRetrieverRun",
|
||||||
|
"AsyncCallbackManagerForRetrieverRun",
|
||||||
|
"CallbackManager",
|
||||||
|
"CallbackManagerForChainGroup",
|
||||||
|
"AsyncCallbackManager",
|
||||||
|
"AsyncCallbackManagerForChainGroup",
|
||||||
|
"StdOutCallbackHandler",
|
||||||
|
"StreamingStdOutCallbackHandler",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
13
libs/core/tests/unit_tests/example_selectors/test_imports.py
Normal file
13
libs/core/tests/unit_tests/example_selectors/test_imports.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from langchain_core.example_selectors import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseExampleSelector",
|
||||||
|
"LengthBasedExampleSelector",
|
||||||
|
"MaxMarginalRelevanceExampleSelector",
|
||||||
|
"SemanticSimilarityExampleSelector",
|
||||||
|
"sorted_values",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
15
libs/core/tests/unit_tests/language_models/test_imports.py
Normal file
15
libs/core/tests/unit_tests/language_models/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from langchain_core.language_models import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseLanguageModel",
|
||||||
|
"BaseChatModel",
|
||||||
|
"SimpleChatModel",
|
||||||
|
"BaseLLM",
|
||||||
|
"LLM",
|
||||||
|
"LanguageModelInput",
|
||||||
|
"get_tokenizer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/load/__init__.py
Normal file
0
libs/core/tests/unit_tests/load/__init__.py
Normal file
7
libs/core/tests/unit_tests/load/test_imports.py
Normal file
7
libs/core/tests/unit_tests/load/test_imports.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain_core.load import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/messages/__init__.py
Normal file
0
libs/core/tests/unit_tests/messages/__init__.py
Normal file
28
libs/core/tests/unit_tests/messages/test_imports.py
Normal file
28
libs/core/tests/unit_tests/messages/test_imports.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from langchain_core.messages import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"AIMessage",
|
||||||
|
"AIMessageChunk",
|
||||||
|
"AnyMessage",
|
||||||
|
"BaseMessage",
|
||||||
|
"BaseMessageChunk",
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatMessageChunk",
|
||||||
|
"FunctionMessage",
|
||||||
|
"FunctionMessageChunk",
|
||||||
|
"HumanMessage",
|
||||||
|
"HumanMessageChunk",
|
||||||
|
"SystemMessage",
|
||||||
|
"SystemMessageChunk",
|
||||||
|
"ToolMessage",
|
||||||
|
"ToolMessageChunk",
|
||||||
|
"get_buffer_string",
|
||||||
|
"messages_from_dict",
|
||||||
|
"messages_to_dict",
|
||||||
|
"message_to_dict",
|
||||||
|
"merge_content",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
18
libs/core/tests/unit_tests/output_parsers/test_imports.py
Normal file
18
libs/core/tests/unit_tests/output_parsers/test_imports.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from langchain_core.output_parsers import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseLLMOutputParser",
|
||||||
|
"BaseGenerationOutputParser",
|
||||||
|
"BaseOutputParser",
|
||||||
|
"ListOutputParser",
|
||||||
|
"CommaSeparatedListOutputParser",
|
||||||
|
"NumberedListOutputParser",
|
||||||
|
"MarkdownListOutputParser",
|
||||||
|
"StrOutputParser",
|
||||||
|
"BaseTransformOutputParser",
|
||||||
|
"BaseCumulativeTransformOutputParser",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/outputs/__init__.py
Normal file
0
libs/core/tests/unit_tests/outputs/__init__.py
Normal file
15
libs/core/tests/unit_tests/outputs/test_imports.py
Normal file
15
libs/core/tests/unit_tests/outputs/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from langchain_core.outputs import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"ChatGeneration",
|
||||||
|
"ChatGenerationChunk",
|
||||||
|
"ChatResult",
|
||||||
|
"Generation",
|
||||||
|
"GenerationChunk",
|
||||||
|
"LLMResult",
|
||||||
|
"RunInfo",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -10,6 +10,7 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
@ -17,7 +18,6 @@ from langchain_core.prompts.chat import (
|
|||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
_convert_to_message,
|
_convert_to_message,
|
||||||
|
@ -6,14 +6,10 @@ EXPECTED_ALL = [
|
|||||||
"BasePromptTemplate",
|
"BasePromptTemplate",
|
||||||
"ChatMessagePromptTemplate",
|
"ChatMessagePromptTemplate",
|
||||||
"ChatPromptTemplate",
|
"ChatPromptTemplate",
|
||||||
"ChatPromptValueConcrete",
|
|
||||||
"FewShotPromptTemplate",
|
"FewShotPromptTemplate",
|
||||||
"FewShotPromptWithTemplates",
|
"FewShotPromptWithTemplates",
|
||||||
"FewShotChatMessagePromptTemplate",
|
"FewShotChatMessagePromptTemplate",
|
||||||
"format_document",
|
"format_document",
|
||||||
"ChatPromptValue",
|
|
||||||
"PromptValue",
|
|
||||||
"StringPromptValue",
|
|
||||||
"HumanMessagePromptTemplate",
|
"HumanMessagePromptTemplate",
|
||||||
"MessagesPlaceholder",
|
"MessagesPlaceholder",
|
||||||
"PipelinePromptTemplate",
|
"PipelinePromptTemplate",
|
||||||
@ -22,6 +18,10 @@ EXPECTED_ALL = [
|
|||||||
"StringPromptTemplate",
|
"StringPromptTemplate",
|
||||||
"SystemMessagePromptTemplate",
|
"SystemMessagePromptTemplate",
|
||||||
"load_prompt",
|
"load_prompt",
|
||||||
|
"check_valid_template",
|
||||||
|
"get_template_variables",
|
||||||
|
"jinja2_formatter",
|
||||||
|
"validate_jinja2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
28
libs/core/tests/unit_tests/runnables/test_imports.py
Normal file
28
libs/core/tests/unit_tests/runnables/test_imports.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from langchain_core.runnables import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"ConfigurableField",
|
||||||
|
"ConfigurableFieldSingleOption",
|
||||||
|
"ConfigurableFieldMultiOption",
|
||||||
|
"patch_config",
|
||||||
|
"RouterInput",
|
||||||
|
"RouterRunnable",
|
||||||
|
"Runnable",
|
||||||
|
"RunnableSerializable",
|
||||||
|
"RunnableBinding",
|
||||||
|
"RunnableBranch",
|
||||||
|
"RunnableConfig",
|
||||||
|
"RunnableGenerator",
|
||||||
|
"RunnableLambda",
|
||||||
|
"RunnableMap",
|
||||||
|
"RunnableParallel",
|
||||||
|
"RunnablePassthrough",
|
||||||
|
"RunnableSequence",
|
||||||
|
"RunnableWithFallbacks",
|
||||||
|
"get_config_list",
|
||||||
|
"add",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -23,7 +23,6 @@ from typing_extensions import TypedDict
|
|||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
Callbacks,
|
Callbacks,
|
||||||
atrace_as_chain_group,
|
atrace_as_chain_group,
|
||||||
collect_runs,
|
|
||||||
trace_as_chain_group,
|
trace_as_chain_group,
|
||||||
)
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -39,13 +38,12 @@ from langchain_core.output_parsers import (
|
|||||||
CommaSeparatedListOutputParser,
|
CommaSeparatedListOutputParser,
|
||||||
StrOutputParser,
|
StrOutputParser,
|
||||||
)
|
)
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
|
||||||
from langchain_core.prompts import (
|
from langchain_core.prompts import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
StringPromptValue,
|
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
@ -75,6 +73,7 @@ from langchain_core.tracers import (
|
|||||||
RunLog,
|
RunLog,
|
||||||
RunLogPatch,
|
RunLogPatch,
|
||||||
)
|
)
|
||||||
|
from langchain_core.tracers.context import collect_runs
|
||||||
from tests.unit_tests.fake.chat_model import FakeListChatModel
|
from tests.unit_tests.fake.chat_model import FakeListChatModel
|
||||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||||
|
|
||||||
|
15
libs/core/tests/unit_tests/tracers/test_imports.py
Normal file
15
libs/core/tests/unit_tests/tracers/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from langchain_core.tracers import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseTracer",
|
||||||
|
"EvaluatorCallbackHandler",
|
||||||
|
"LangChainTracer",
|
||||||
|
"ConsoleCallbackHandler",
|
||||||
|
"Run",
|
||||||
|
"RunLog",
|
||||||
|
"RunLogPatch",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from langchain.callbacks import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.fake.llm import FakeListLLM
|
||||||
|
|
||||||
|
|
||||||
def test_collect_runs() -> None:
|
def test_collect_runs() -> None:
|
||||||
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
|
llm = FakeListLLM(responses=["hello"])
|
||||||
with collect_runs() as cb:
|
with collect_runs() as cb:
|
||||||
llm.predict("hi")
|
llm.predict("hi")
|
||||||
assert cb.traced_runs
|
assert cb.traced_runs
|
@ -2,9 +2,8 @@ import uuid
|
|||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.prompts.base import StringPromptValue
|
|
||||||
from langchain.prompts.chat import ChatPromptValue
|
|
||||||
from langchain.schema import AIMessage, HumanMessage
|
from langchain.schema import AIMessage, HumanMessage
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
|
||||||
|
|
||||||
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
||||||
from langchain_experimental.comprehend_moderation.prompt_safety import (
|
from langchain_experimental.comprehend_moderation.prompt_safety import (
|
||||||
|
@ -29,14 +29,16 @@ from langchain_core.callbacks.manager import (
|
|||||||
RunManager,
|
RunManager,
|
||||||
ahandle_event,
|
ahandle_event,
|
||||||
atrace_as_chain_group,
|
atrace_as_chain_group,
|
||||||
collect_runs,
|
|
||||||
env_var_is_set,
|
|
||||||
handle_event,
|
handle_event,
|
||||||
register_configure_hook,
|
|
||||||
trace_as_chain_group,
|
trace_as_chain_group,
|
||||||
|
)
|
||||||
|
from langchain_core.tracers.context import (
|
||||||
|
collect_runs,
|
||||||
|
register_configure_hook,
|
||||||
tracing_enabled,
|
tracing_enabled,
|
||||||
tracing_v2_enabled,
|
tracing_v2_enabled,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.env import env_var_is_set
|
||||||
|
|
||||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||||
from langchain.callbacks.tracers.wandb import WandbTracer
|
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||||
@ -122,6 +124,6 @@ __all__ = [
|
|||||||
"trace_as_chain_group",
|
"trace_as_chain_group",
|
||||||
"handle_event",
|
"handle_event",
|
||||||
"ahandle_event",
|
"ahandle_event",
|
||||||
"env_var_is_set",
|
|
||||||
"Callbacks",
|
"Callbacks",
|
||||||
|
"env_var_is_set",
|
||||||
]
|
]
|
||||||
|
@ -12,8 +12,8 @@ from langchain_core.load.dump import dumpd
|
|||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
from langchain_core.pydantic_v1 import Extra, Field
|
from langchain_core.pydantic_v1 import Extra, Field
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
|
@ -9,7 +9,7 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.prompts import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
@ -13,7 +13,7 @@ from typing import (
|
|||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.outputs import GenerationChunk
|
from langchain_core.outputs import GenerationChunk
|
||||||
from langchain_core.prompts import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
check_package_version,
|
check_package_version,
|
||||||
|
@ -5,8 +5,8 @@ from typing import Any, TypeVar
|
|||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
|
|
||||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||||
{prompt}
|
{prompt}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import (
|
||||||
from langchain_core.prompts.base import (
|
BasePromptTemplate,
|
||||||
StringPromptTemplate,
|
StringPromptTemplate,
|
||||||
StringPromptValue,
|
|
||||||
check_valid_template,
|
check_valid_template,
|
||||||
get_template_variables,
|
get_template_variables,
|
||||||
jinja2_formatter,
|
jinja2_formatter,
|
||||||
@ -13,7 +12,6 @@ __all__ = [
|
|||||||
"validate_jinja2",
|
"validate_jinja2",
|
||||||
"check_valid_template",
|
"check_valid_template",
|
||||||
"get_template_variables",
|
"get_template_variables",
|
||||||
"StringPromptValue",
|
|
||||||
"StringPromptTemplate",
|
"StringPromptTemplate",
|
||||||
"BasePromptTemplate",
|
"BasePromptTemplate",
|
||||||
]
|
]
|
||||||
|
@ -5,8 +5,6 @@ from langchain_core.prompts.chat import (
|
|||||||
BaseStringMessagePromptTemplate,
|
BaseStringMessagePromptTemplate,
|
||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
|
||||||
ChatPromptValueConcrete,
|
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
@ -20,8 +18,6 @@ __all__ = [
|
|||||||
"HumanMessagePromptTemplate",
|
"HumanMessagePromptTemplate",
|
||||||
"AIMessagePromptTemplate",
|
"AIMessagePromptTemplate",
|
||||||
"SystemMessagePromptTemplate",
|
"SystemMessagePromptTemplate",
|
||||||
"ChatPromptValue",
|
|
||||||
"ChatPromptValueConcrete",
|
|
||||||
"BaseChatPromptTemplate",
|
"BaseChatPromptTemplate",
|
||||||
"ChatPromptTemplate",
|
"ChatPromptTemplate",
|
||||||
]
|
]
|
||||||
|
@ -31,12 +31,16 @@ from langchain_core.outputs import (
|
|||||||
LLMResult,
|
LLMResult,
|
||||||
RunInfo,
|
RunInfo,
|
||||||
)
|
)
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptValue, format_document
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
|
|
||||||
RUN_KEY = "__run"
|
RUN_KEY = "__run"
|
||||||
|
|
||||||
|
# Backwards compatibility.
|
||||||
Memory = BaseMemory
|
Memory = BaseMemory
|
||||||
|
_message_to_dict = message_to_dict
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseCache",
|
"BaseCache",
|
||||||
@ -56,6 +60,7 @@ __all__ = [
|
|||||||
"messages_from_dict",
|
"messages_from_dict",
|
||||||
"messages_to_dict",
|
"messages_to_dict",
|
||||||
"message_to_dict",
|
"message_to_dict",
|
||||||
|
"_message_to_dict",
|
||||||
"_message_from_dict",
|
"_message_from_dict",
|
||||||
"get_buffer_string",
|
"get_buffer_string",
|
||||||
"RunInfo",
|
"RunInfo",
|
||||||
|
@ -16,14 +16,16 @@ from langchain_core.callbacks.manager import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
ParentRunManager,
|
ParentRunManager,
|
||||||
RunManager,
|
RunManager,
|
||||||
collect_runs,
|
|
||||||
env_var_is_set,
|
|
||||||
handle_event,
|
handle_event,
|
||||||
register_configure_hook,
|
|
||||||
trace_as_chain_group,
|
trace_as_chain_group,
|
||||||
|
)
|
||||||
|
from langchain_core.tracers.context import (
|
||||||
|
collect_runs,
|
||||||
|
register_configure_hook,
|
||||||
tracing_enabled,
|
tracing_enabled,
|
||||||
tracing_v2_enabled,
|
tracing_v2_enabled,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.env import env_var_is_set
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"tracing_enabled",
|
"tracing_enabled",
|
||||||
@ -48,6 +50,6 @@ __all__ = [
|
|||||||
"CallbackManagerForChainGroup",
|
"CallbackManagerForChainGroup",
|
||||||
"AsyncCallbackManager",
|
"AsyncCallbackManager",
|
||||||
"AsyncCallbackManagerForChainGroup",
|
"AsyncCallbackManagerForChainGroup",
|
||||||
"env_var_is_set",
|
|
||||||
"register_configure_hook",
|
"register_configure_hook",
|
||||||
|
"env_var_is_set",
|
||||||
]
|
]
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
from langchain_core.documents import BaseDocumentTransformer, Document
|
from langchain_core.document_transformers import BaseDocumentTransformer
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
__all__ = ["Document", "BaseDocumentTransformer"]
|
__all__ = ["Document", "BaseDocumentTransformer"]
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
from langchain_core.language_models import BaseLanguageModel, get_tokenizer
|
from langchain_core.language_models import BaseLanguageModel, get_tokenizer
|
||||||
|
from langchain_core.language_models.base import _get_token_ids_default_method
|
||||||
|
|
||||||
__all__ = ["get_tokenizer", "BaseLanguageModel"]
|
__all__ = ["get_tokenizer", "BaseLanguageModel", "_get_token_ids_default_method"]
|
||||||
|
@ -13,12 +13,17 @@ from langchain_core.messages import (
|
|||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
|
_message_from_dict,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
merge_content,
|
merge_content,
|
||||||
|
message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Backwards compatibility.
|
||||||
|
_message_to_dict = message_to_dict
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_buffer_string",
|
"get_buffer_string",
|
||||||
"BaseMessage",
|
"BaseMessage",
|
||||||
@ -38,4 +43,7 @@ __all__ = [
|
|||||||
"ChatMessageChunk",
|
"ChatMessageChunk",
|
||||||
"messages_to_dict",
|
"messages_to_dict",
|
||||||
"messages_from_dict",
|
"messages_from_dict",
|
||||||
|
"_message_to_dict",
|
||||||
|
"_message_from_dict",
|
||||||
|
"message_to_dict",
|
||||||
]
|
]
|
||||||
|
@ -1,19 +1,23 @@
|
|||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
BaseCumulativeTransformOutputParser,
|
BaseCumulativeTransformOutputParser,
|
||||||
BaseGenerationOutputParser,
|
BaseGenerationOutputParser,
|
||||||
BaseLLMOutputParser,
|
BaseLLMOutputParser,
|
||||||
BaseOutputParser,
|
BaseOutputParser,
|
||||||
BaseTransformOutputParser,
|
BaseTransformOutputParser,
|
||||||
OutputParserException,
|
|
||||||
StrOutputParser,
|
StrOutputParser,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Backwards compatibility.
|
||||||
|
NoOpOutputParser = StrOutputParser
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseLLMOutputParser",
|
"BaseLLMOutputParser",
|
||||||
"BaseGenerationOutputParser",
|
"BaseGenerationOutputParser",
|
||||||
"BaseOutputParser",
|
"BaseOutputParser",
|
||||||
"BaseTransformOutputParser",
|
"BaseTransformOutputParser",
|
||||||
"BaseCumulativeTransformOutputParser",
|
"BaseCumulativeTransformOutputParser",
|
||||||
|
"NoOpOutputParser",
|
||||||
"StrOutputParser",
|
"StrOutputParser",
|
||||||
"OutputParserException",
|
"OutputParserException",
|
||||||
]
|
]
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from langchain_core.prompts import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
|
||||||
__all__ = ["PromptValue"]
|
__all__ = ["PromptValue"]
|
||||||
|
@ -12,6 +12,9 @@ from langchain_core.runnables.base import (
|
|||||||
coerce_to_runnable,
|
coerce_to_runnable,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Backwards compatibility.
|
||||||
|
RunnableMap = RunnableParallel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Runnable",
|
"Runnable",
|
||||||
"RunnableSerializable",
|
"RunnableSerializable",
|
||||||
@ -23,5 +26,6 @@ __all__ = [
|
|||||||
"RunnableEach",
|
"RunnableEach",
|
||||||
"RunnableBindingBase",
|
"RunnableBindingBase",
|
||||||
"RunnableBinding",
|
"RunnableBinding",
|
||||||
|
"RunnableMap",
|
||||||
"coerce_to_runnable",
|
"coerce_to_runnable",
|
||||||
]
|
]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
EmptyDict,
|
EmptyDict,
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
|
acall_func_with_variable_args,
|
||||||
call_func_with_variable_args,
|
call_func_with_variable_args,
|
||||||
ensure_config,
|
ensure_config,
|
||||||
get_async_callback_manager_for_config,
|
get_async_callback_manager_for_config,
|
||||||
@ -18,6 +19,7 @@ __all__ = [
|
|||||||
"get_config_list",
|
"get_config_list",
|
||||||
"patch_config",
|
"patch_config",
|
||||||
"merge_configs",
|
"merge_configs",
|
||||||
|
"acall_func_with_variable_args",
|
||||||
"call_func_with_variable_args",
|
"call_func_with_variable_args",
|
||||||
"get_callback_manager_for_config",
|
"get_callback_manager_for_config",
|
||||||
"get_async_callback_manager_for_config",
|
"get_async_callback_manager_for_config",
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from langchain_core.runnables.passthrough import (
|
from langchain_core.runnables.passthrough import (
|
||||||
RunnableAssign,
|
RunnableAssign,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
|
aidentity,
|
||||||
identity,
|
identity,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ["identity", "RunnablePassthrough", "RunnableAssign"]
|
__all__ = ["aidentity", "identity", "RunnablePassthrough", "RunnableAssign"]
|
||||||
|
@ -8,9 +8,12 @@ from langchain_core.runnables.utils import (
|
|||||||
IsFunctionArgDict,
|
IsFunctionArgDict,
|
||||||
IsLocalDict,
|
IsLocalDict,
|
||||||
SupportsAdd,
|
SupportsAdd,
|
||||||
|
aadd,
|
||||||
accepts_config,
|
accepts_config,
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
add,
|
add,
|
||||||
|
gated_coro,
|
||||||
|
gather_with_concurrency,
|
||||||
get_function_first_arg_dict_keys,
|
get_function_first_arg_dict_keys,
|
||||||
get_lambda_source,
|
get_lambda_source,
|
||||||
get_unique_config_specs,
|
get_unique_config_specs,
|
||||||
@ -34,4 +37,7 @@ __all__ = [
|
|||||||
"ConfigurableFieldMultiOption",
|
"ConfigurableFieldMultiOption",
|
||||||
"ConfigurableFieldSpec",
|
"ConfigurableFieldSpec",
|
||||||
"get_unique_config_specs",
|
"get_unique_config_specs",
|
||||||
|
"aadd",
|
||||||
|
"gated_coro",
|
||||||
|
"gather_with_concurrency",
|
||||||
]
|
]
|
||||||
|
13
libs/langchain/poetry.lock
generated
13
libs/langchain/poetry.lock
generated
@ -4128,13 +4128,13 @@ tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.0.1"
|
version = "0.0.2"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "langchain_core-0.0.1-py3-none-any.whl", hash = "sha256:cad923dd3bc39cd9fe24b9d6a9799c97719aeaafc9b19509fe1347109fcb65b3"},
|
{file = "langchain_core-0.0.2-py3-none-any.whl", hash = "sha256:dd448c7887c24105761a0763bee5ec5b072de905b4fbd83d693e7a181fd63208"},
|
||||||
{file = "langchain_core-0.0.1.tar.gz", hash = "sha256:488b72223e14849bf9588ed677a999b282904d1d5e1f81d12767ee1024220724"},
|
{file = "langchain_core-0.0.2.tar.gz", hash = "sha256:772600dfe3e707adb9055ed96797763b0de32b394eac8325bf609a7e5929dda2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -6515,6 +6515,8 @@ files = [
|
|||||||
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
|
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
|
||||||
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
|
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
|
||||||
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
|
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
|
||||||
|
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
|
||||||
|
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
|
||||||
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
|
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
|
||||||
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
|
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
|
||||||
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
|
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
|
||||||
@ -6557,6 +6559,7 @@ files = [
|
|||||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
||||||
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
||||||
@ -6565,6 +6568,8 @@ files = [
|
|||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
||||||
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
|
||||||
|
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
||||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
||||||
@ -11080,4 +11085,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "6781b828c8bc2b08ece3cbf82be799dc0e361b9f6f9c204ddefcfee70ab0db8b"
|
content-hash = "3660cb9106129dcf1a2fe157c1bdfe567923fc9ed408672bcbcb70f4345f1285"
|
||||||
|
@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
langchain-core = "^0.0.1"
|
langchain-core = "^0.0.2"
|
||||||
pydantic = ">=1,<3"
|
pydantic = ">=1,<3"
|
||||||
SQLAlchemy = ">=1.4,<3"
|
SQLAlchemy = ">=1.4,<3"
|
||||||
requests = "^2"
|
requests = "^2"
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
@ -16,6 +15,8 @@ def _chat_message_history(
|
|||||||
drop: bool = True,
|
drop: bool = True,
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> CassandraChatMessageHistory:
|
) -> CassandraChatMessageHistory:
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
|
|
||||||
keyspace = "cmh_test_keyspace"
|
keyspace = "cmh_test_keyspace"
|
||||||
table_name = "cmh_test_table"
|
table_name = "cmh_test_table"
|
||||||
# get db connection
|
# get db connection
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
import time
|
import time
|
||||||
from typing import List, Optional, Type
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores import Cassandra
|
from langchain.vectorstores import Cassandra
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
@ -19,6 +17,8 @@ def _vectorstore_from_texts(
|
|||||||
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
|
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
|
||||||
drop: bool = True,
|
drop: bool = True,
|
||||||
) -> Cassandra:
|
) -> Cassandra:
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
|
|
||||||
keyspace = "vector_test_keyspace"
|
keyspace = "vector_test_keyspace"
|
||||||
table_name = "vector_test_table"
|
table_name = "vector_test_table"
|
||||||
# get db connection
|
# get db connection
|
||||||
@ -154,12 +154,3 @@ def test_cassandra_delete() -> None:
|
|||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
output = docsearch.similarity_search("foo", k=10)
|
output = docsearch.similarity_search("foo", k=10)
|
||||||
assert len(output) == 0
|
assert len(output) == 0
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# test_cassandra()
|
|
||||||
# test_cassandra_with_score()
|
|
||||||
# test_cassandra_max_marginal_relevance_search()
|
|
||||||
# test_cassandra_add_extra()
|
|
||||||
# test_cassandra_no_drop()
|
|
||||||
# test_cassandra_delete()
|
|
||||||
|
20
libs/langchain/tests/unit_tests/schema/runnable/test_base.py
Normal file
20
libs/langchain/tests/unit_tests/schema/runnable/test_base.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from langchain.schema.runnable.base import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"Runnable",
|
||||||
|
"RunnableBinding",
|
||||||
|
"RunnableBindingBase",
|
||||||
|
"RunnableEach",
|
||||||
|
"RunnableEachBase",
|
||||||
|
"RunnableGenerator",
|
||||||
|
"RunnableLambda",
|
||||||
|
"RunnableMap",
|
||||||
|
"RunnableParallel",
|
||||||
|
"RunnableSequence",
|
||||||
|
"RunnableSerializable",
|
||||||
|
"coerce_to_runnable",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.branch import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RunnableBranch"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,19 @@
|
|||||||
|
from langchain.schema.runnable.config import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"EmptyDict",
|
||||||
|
"RunnableConfig",
|
||||||
|
"acall_func_with_variable_args",
|
||||||
|
"call_func_with_variable_args",
|
||||||
|
"ensure_config",
|
||||||
|
"get_async_callback_manager_for_config",
|
||||||
|
"get_callback_manager_for_config",
|
||||||
|
"get_config_list",
|
||||||
|
"get_executor_for_config",
|
||||||
|
"merge_configs",
|
||||||
|
"patch_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,13 @@
|
|||||||
|
from langchain.schema.runnable.configurable import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"DynamicRunnable",
|
||||||
|
"RunnableConfigurableAlternatives",
|
||||||
|
"RunnableConfigurableFields",
|
||||||
|
"StrEnum",
|
||||||
|
"make_options_spec",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.fallbacks import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RunnableWithFallbacks"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.history import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RunnableWithMessageHistory"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,26 @@
|
|||||||
|
from langchain.schema.runnable import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"ConfigurableField",
|
||||||
|
"ConfigurableFieldSingleOption",
|
||||||
|
"ConfigurableFieldMultiOption",
|
||||||
|
"patch_config",
|
||||||
|
"RouterInput",
|
||||||
|
"RouterRunnable",
|
||||||
|
"Runnable",
|
||||||
|
"RunnableSerializable",
|
||||||
|
"RunnableBinding",
|
||||||
|
"RunnableBranch",
|
||||||
|
"RunnableConfig",
|
||||||
|
"RunnableGenerator",
|
||||||
|
"RunnableLambda",
|
||||||
|
"RunnableMap",
|
||||||
|
"RunnableParallel",
|
||||||
|
"RunnablePassthrough",
|
||||||
|
"RunnableSequence",
|
||||||
|
"RunnableWithFallbacks",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.passthrough import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RunnableAssign", "RunnablePassthrough", "aidentity", "identity"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.retry import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RunnableRetry"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.runnable.router import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["RouterInput", "RouterRunnable"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,27 @@
|
|||||||
|
from langchain.schema.runnable.utils import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"AddableDict",
|
||||||
|
"ConfigurableField",
|
||||||
|
"ConfigurableFieldMultiOption",
|
||||||
|
"ConfigurableFieldSingleOption",
|
||||||
|
"ConfigurableFieldSpec",
|
||||||
|
"GetLambdaSource",
|
||||||
|
"IsFunctionArgDict",
|
||||||
|
"IsLocalDict",
|
||||||
|
"SupportsAdd",
|
||||||
|
"aadd",
|
||||||
|
"accepts_config",
|
||||||
|
"accepts_run_manager",
|
||||||
|
"add",
|
||||||
|
"gated_coro",
|
||||||
|
"gather_with_concurrency",
|
||||||
|
"get_function_first_arg_dict_keys",
|
||||||
|
"get_lambda_source",
|
||||||
|
"get_unique_config_specs",
|
||||||
|
"indent_lines_after_first",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_agent.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_agent.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.agent import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["AgentAction", "AgentActionMessageLog", "AgentFinish"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_cache.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_cache.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.cache import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseCache"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_chat.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_chat.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.chat import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["ChatSession"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.chat_history import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseChatMessageHistory"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_document.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_document.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.document import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseDocumentTransformer", "Document"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.embeddings import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["Embeddings"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.exceptions import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["LangChainException"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -18,6 +18,7 @@ EXPECTED_ALL = [
|
|||||||
"messages_from_dict",
|
"messages_from_dict",
|
||||||
"messages_to_dict",
|
"messages_to_dict",
|
||||||
"message_to_dict",
|
"message_to_dict",
|
||||||
|
"_message_to_dict",
|
||||||
"_message_from_dict",
|
"_message_from_dict",
|
||||||
"get_buffer_string",
|
"get_buffer_string",
|
||||||
"RunInfo",
|
"RunInfo",
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.language_model import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseLanguageModel", "_get_token_ids_default_method", "get_tokenizer"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_memory.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_memory.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.memory import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseMemory"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -1,101 +1,29 @@
|
|||||||
import pytest
|
from langchain.schema.messages import __all__
|
||||||
from langchain_core.messages import (
|
|
||||||
AIMessageChunk,
|
EXPECTED_ALL = [
|
||||||
ChatMessageChunk,
|
"AIMessage",
|
||||||
FunctionMessageChunk,
|
"AIMessageChunk",
|
||||||
HumanMessageChunk,
|
"BaseMessage",
|
||||||
)
|
"BaseMessageChunk",
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatMessageChunk",
|
||||||
|
"FunctionMessage",
|
||||||
|
"FunctionMessageChunk",
|
||||||
|
"HumanMessage",
|
||||||
|
"HumanMessageChunk",
|
||||||
|
"SystemMessage",
|
||||||
|
"SystemMessageChunk",
|
||||||
|
"ToolMessage",
|
||||||
|
"ToolMessageChunk",
|
||||||
|
"_message_from_dict",
|
||||||
|
"_message_to_dict",
|
||||||
|
"message_to_dict",
|
||||||
|
"get_buffer_string",
|
||||||
|
"merge_content",
|
||||||
|
"messages_from_dict",
|
||||||
|
"messages_to_dict",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_message_chunks() -> None:
|
def test_all_imports() -> None:
|
||||||
assert AIMessageChunk(content="I am") + AIMessageChunk(
|
assert set(__all__) == set(EXPECTED_ALL)
|
||||||
content=" indeed."
|
|
||||||
) == AIMessageChunk(
|
|
||||||
content="I am indeed."
|
|
||||||
), "MessageChunk + MessageChunk should be a MessageChunk"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
|
|
||||||
== AIMessageChunk(content="I am indeed.")
|
|
||||||
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
|
|
||||||
|
|
||||||
assert (
|
|
||||||
AIMessageChunk(content="", additional_kwargs={"foo": "bar"})
|
|
||||||
+ AIMessageChunk(content="", additional_kwargs={"baz": "foo"})
|
|
||||||
== AIMessageChunk(content="", additional_kwargs={"foo": "bar", "baz": "foo"})
|
|
||||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
|
||||||
|
|
||||||
assert (
|
|
||||||
AIMessageChunk(
|
|
||||||
content="", additional_kwargs={"function_call": {"name": "web_search"}}
|
|
||||||
)
|
|
||||||
+ AIMessageChunk(
|
|
||||||
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
|
|
||||||
)
|
|
||||||
+ AIMessageChunk(
|
|
||||||
content="",
|
|
||||||
additional_kwargs={
|
|
||||||
"function_call": {"arguments": ' "query": "turtles"\n}'}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
== AIMessageChunk(
|
|
||||||
content="",
|
|
||||||
additional_kwargs={
|
|
||||||
"function_call": {
|
|
||||||
"name": "web_search",
|
|
||||||
"arguments": '{\n "query": "turtles"\n}',
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_message_chunks() -> None:
|
|
||||||
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
|
||||||
role="User", content=" indeed."
|
|
||||||
) == ChatMessageChunk(
|
|
||||||
role="User", content="I am indeed."
|
|
||||||
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
|
||||||
role="Assistant", content=" indeed."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ChatMessageChunk(role="User", content="I am")
|
|
||||||
+ AIMessageChunk(content=" indeed.")
|
|
||||||
== ChatMessageChunk(role="User", content="I am indeed.")
|
|
||||||
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501
|
|
||||||
|
|
||||||
assert AIMessageChunk(content="I am") + ChatMessageChunk(
|
|
||||||
role="User", content=" indeed."
|
|
||||||
) == AIMessageChunk(
|
|
||||||
content="I am indeed."
|
|
||||||
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
def test_function_message_chunks() -> None:
|
|
||||||
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
|
||||||
name="hello", content=" indeed."
|
|
||||||
) == FunctionMessageChunk(
|
|
||||||
name="hello", content="I am indeed."
|
|
||||||
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
|
||||||
name="bye", content=" indeed."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ani_message_chunks() -> None:
|
|
||||||
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
|
||||||
example=True, content=" indeed."
|
|
||||||
) == AIMessageChunk(
|
|
||||||
example=True, content="I am indeed."
|
|
||||||
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
|
||||||
example=False, content=" indeed."
|
|
||||||
)
|
|
||||||
|
@ -1,60 +1,15 @@
|
|||||||
from langchain_core.messages import HumanMessageChunk
|
from langchain.schema.output import __all__
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"ChatGeneration",
|
||||||
|
"ChatGenerationChunk",
|
||||||
|
"ChatResult",
|
||||||
|
"Generation",
|
||||||
|
"GenerationChunk",
|
||||||
|
"LLMResult",
|
||||||
|
"RunInfo",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_generation_chunk() -> None:
|
def test_all_imports() -> None:
|
||||||
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
assert set(__all__) == set(EXPECTED_ALL)
|
||||||
text="world!"
|
|
||||||
) == GenerationChunk(
|
|
||||||
text="Hello, world!"
|
|
||||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
GenerationChunk(text="Hello, ")
|
|
||||||
+ GenerationChunk(text="world!", generation_info={"foo": "bar"})
|
|
||||||
== GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})
|
|
||||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
|
||||||
|
|
||||||
assert (
|
|
||||||
GenerationChunk(text="Hello, ")
|
|
||||||
+ GenerationChunk(text="world!", generation_info={"foo": "bar"})
|
|
||||||
+ GenerationChunk(text="!", generation_info={"baz": "foo"})
|
|
||||||
== GenerationChunk(
|
|
||||||
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
|
|
||||||
)
|
|
||||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_generation_chunk() -> None:
|
|
||||||
assert ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="Hello, ")
|
|
||||||
) + ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="world!")
|
|
||||||
) == ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="Hello, world!")
|
|
||||||
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
|
|
||||||
+ ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
|
||||||
)
|
|
||||||
== ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="Hello, world!"),
|
|
||||||
generation_info={"foo": "bar"},
|
|
||||||
)
|
|
||||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
|
|
||||||
+ ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
|
||||||
)
|
|
||||||
+ ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
|
|
||||||
)
|
|
||||||
== ChatGenerationChunk(
|
|
||||||
message=HumanMessageChunk(content="Hello, world!!"),
|
|
||||||
generation_info={"foo": "bar", "baz": "foo"},
|
|
||||||
)
|
|
||||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
|
||||||
|
16
libs/langchain/tests/unit_tests/schema/test_output_parser.py
Normal file
16
libs/langchain/tests/unit_tests/schema/test_output_parser.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from langchain.schema.output_parser import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseCumulativeTransformOutputParser",
|
||||||
|
"BaseGenerationOutputParser",
|
||||||
|
"BaseLLMOutputParser",
|
||||||
|
"BaseOutputParser",
|
||||||
|
"BaseTransformOutputParser",
|
||||||
|
"NoOpOutputParser",
|
||||||
|
"OutputParserException",
|
||||||
|
"StrOutputParser",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_prompt.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_prompt.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.prompt import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["PromptValue"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.prompt_template import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BasePromptTemplate", "format_document"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_retriever.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_retriever.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.retriever import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseRetriever"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_storage.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_storage.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.storage import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["BaseStore"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain.schema.vectorstore import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = ["VectorStore", "VectorStoreRetriever"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
@ -21,7 +21,7 @@ from langchain_core.messages import (
|
|||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
|
||||||
from langchain_core.prompts import ChatPromptValueConcrete, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user