mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
BUG: more core fixes (#13665)
Fix some circular deps: - move PromptValue into top level module bc both PromptTemplates and OutputParsers import - move tracer context vars to `tracers.context` and import them in functions in `callbacks.manager` - add core import tests
This commit is contained in:
parent
59df16ab92
commit
c61e30632e
40
.github/workflows/langchain_ci.yml
vendored
40
.github/workflows/langchain_ci.yml
vendored
@ -63,6 +63,46 @@ jobs:
|
||||
langchain-core-location: ../core
|
||||
secrets: inherit
|
||||
|
||||
# It's possible that langchain works fine with the latest *published* langchain-core,
|
||||
# but is broken with the langchain-core on `master`.
|
||||
#
|
||||
# We want to catch situations like that *before* releasing a new langchain-core, hence this test.
|
||||
test-with-latest-langchain-core:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ env.WORKDIR }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
name: test with unpublished langchain-core - Python ${{ matrix.python-version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ env.WORKDIR }}
|
||||
cache-key: unpublished-langchain-core
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
echo "Running tests with unpublished langchain, installing dependencies with poetry..."
|
||||
poetry install
|
||||
|
||||
echo "Editably installing langchain-core outside of poetry, to avoid messing up lockfile..."
|
||||
poetry run pip install -e ../core
|
||||
|
||||
- name: Run tests
|
||||
run: make test
|
||||
|
||||
extended-tests:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
|
@ -28,8 +28,6 @@ from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForToolRun,
|
||||
ParentRunManager,
|
||||
RunManager,
|
||||
env_var_is_set,
|
||||
register_configure_hook,
|
||||
)
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
@ -64,6 +62,4 @@ __all__ = [
|
||||
"AsyncCallbackManagerForChainGroup",
|
||||
"StdOutCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
"env_var_is_set",
|
||||
"register_configure_hook",
|
||||
]
|
||||
|
@ -3,13 +3,10 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Coroutine,
|
||||
@ -18,7 +15,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -26,14 +22,10 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import utils as ls_utils
|
||||
from langsmith.run_helpers import get_run_tree_context
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.agents import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
)
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
@ -48,32 +40,10 @@ from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers import run_collector
|
||||
from langchain_core.tracers.langchain import (
|
||||
LangChainTracer,
|
||||
)
|
||||
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
||||
from langchain_core.tracers.schemas import TracerSessionV1
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langsmith import Client as LangSmithClient
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
||||
"tracing_callback_v2", default=None
|
||||
)
|
||||
run_collector_var: ContextVar[
|
||||
Optional[run_collector.RunCollectorCallbackHandler]
|
||||
] = ContextVar( # noqa: E501
|
||||
"run_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def _get_debug() -> bool:
|
||||
from langchain_core.globals import get_debug
|
||||
@ -81,123 +51,6 @@ def _get_debug() -> bool:
|
||||
return get_debug()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSessionV1, None, None]:
|
||||
"""Get the Deprecated LangChainTracer in a context manager.
|
||||
|
||||
Args:
|
||||
session_name (str, optional): The name of the session.
|
||||
Defaults to "default".
|
||||
|
||||
Returns:
|
||||
TracerSessionV1: The LangChainTracer session.
|
||||
|
||||
Example:
|
||||
>>> with tracing_enabled() as session:
|
||||
... # Use the LangChainTracer session
|
||||
"""
|
||||
cb = LangChainTracerV1()
|
||||
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||
try:
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
finally:
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_v2_enabled(
|
||||
project_name: Optional[str] = None,
|
||||
*,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
client: Optional[LangSmithClient] = None,
|
||||
) -> Generator[LangChainTracer, None, None]:
|
||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||
|
||||
Args:
|
||||
project_name (str, optional): The name of the project.
|
||||
Defaults to "default".
|
||||
example_id (str or UUID, optional): The ID of the example.
|
||||
Defaults to None.
|
||||
tags (List[str], optional): The tags to add to the run.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Example:
|
||||
>>> with tracing_v2_enabled():
|
||||
... # LangChain code will automatically be traced
|
||||
|
||||
You can use this to fetch the LangSmith run URL:
|
||||
|
||||
>>> with tracing_v2_enabled() as cb:
|
||||
... chain.invoke("foo")
|
||||
... run_url = cb.get_run_url()
|
||||
"""
|
||||
if isinstance(example_id, str):
|
||||
example_id = UUID(example_id)
|
||||
cb = LangChainTracer(
|
||||
example_id=example_id,
|
||||
project_name=project_name,
|
||||
tags=tags,
|
||||
client=client,
|
||||
)
|
||||
try:
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield cb
|
||||
finally:
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
|
||||
"""Collect all run traces in context.
|
||||
|
||||
Returns:
|
||||
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
||||
|
||||
Example:
|
||||
>>> with collect_runs() as runs_cb:
|
||||
chain.invoke("foo")
|
||||
run_id = runs_cb.traced_runs[0].id
|
||||
"""
|
||||
cb = run_collector.RunCollectorCallbackHandler()
|
||||
run_collector_var.set(cb)
|
||||
yield cb
|
||||
run_collector_var.set(None)
|
||||
|
||||
|
||||
def _get_trace_callbacks(
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
|
||||
) -> Callbacks:
|
||||
if _tracing_v2_is_enabled():
|
||||
project_name_ = project_name or _get_tracer_project()
|
||||
tracer = tracing_v2_callback_var.get() or LangChainTracer(
|
||||
project_name=project_name_,
|
||||
example_id=example_id,
|
||||
)
|
||||
if callback_manager is None:
|
||||
cb = cast(Callbacks, [tracer])
|
||||
else:
|
||||
if not any(
|
||||
isinstance(handler, LangChainTracer)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(tracer, True)
|
||||
# If it already has a LangChainTracer, we don't need to add another one.
|
||||
# this would likely mess up the trace hierarchy.
|
||||
cb = callback_manager
|
||||
else:
|
||||
cb = None
|
||||
return cb
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_as_chain_group(
|
||||
group_name: str,
|
||||
@ -239,6 +92,8 @@ def trace_as_chain_group(
|
||||
res = llm.predict(llm_input, callbacks=manager)
|
||||
manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
from langchain_core.tracers.context import _get_trace_callbacks
|
||||
|
||||
cb = _get_trace_callbacks(
|
||||
project_name, example_id, callback_manager=callback_manager
|
||||
)
|
||||
@ -310,6 +165,8 @@ async def atrace_as_chain_group(
|
||||
res = await llm.apredict(llm_input, callbacks=manager)
|
||||
await manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
from langchain_core.tracers.context import _get_trace_callbacks
|
||||
|
||||
cb = _get_trace_callbacks(
|
||||
project_name, example_id, callback_manager=callback_manager
|
||||
)
|
||||
@ -1850,88 +1707,9 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||
|
||||
|
||||
def env_var_is_set(env_var: str) -> bool:
|
||||
"""Check if an environment variable is set.
|
||||
|
||||
Args:
|
||||
env_var (str): The name of the environment variable.
|
||||
|
||||
Returns:
|
||||
bool: True if the environment variable is set, False otherwise.
|
||||
"""
|
||||
return env_var in os.environ and os.environ[env_var] not in (
|
||||
"",
|
||||
"0",
|
||||
"false",
|
||||
"False",
|
||||
)
|
||||
|
||||
|
||||
def _tracing_v2_is_enabled() -> bool:
|
||||
return (
|
||||
env_var_is_set("LANGCHAIN_TRACING_V2")
|
||||
or tracing_v2_callback_var.get() is not None
|
||||
or get_run_tree_context() is not None
|
||||
)
|
||||
|
||||
|
||||
def _get_tracer_project() -> str:
|
||||
run_tree = get_run_tree_context()
|
||||
return getattr(
|
||||
run_tree,
|
||||
"session_name",
|
||||
getattr(
|
||||
# Note, if people are trying to nest @traceable functions and the
|
||||
# tracing_v2_enabled context manager, this will likely mess up the
|
||||
# tree structure.
|
||||
tracing_v2_callback_var.get(),
|
||||
"project",
|
||||
# Have to set this to a string even though it always will return
|
||||
# a string because `get_tracer_project` technically can return
|
||||
# None, but only when a specific argument is supplied.
|
||||
# Therefore, this just tricks the mypy type checker
|
||||
str(ls_utils.get_tracer_project()),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_configure_hooks: List[
|
||||
Tuple[
|
||||
ContextVar[Optional[BaseCallbackHandler]],
|
||||
bool,
|
||||
Optional[Type[BaseCallbackHandler]],
|
||||
Optional[str],
|
||||
]
|
||||
] = []
|
||||
|
||||
H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
|
||||
|
||||
|
||||
def register_configure_hook(
|
||||
context_var: ContextVar[Optional[Any]],
|
||||
inheritable: bool,
|
||||
handle_class: Optional[Type[BaseCallbackHandler]] = None,
|
||||
env_var: Optional[str] = None,
|
||||
) -> None:
|
||||
if env_var is not None and handle_class is None:
|
||||
raise ValueError(
|
||||
"If env_var is set, handle_class must also be set to a non-None value."
|
||||
)
|
||||
_configure_hooks.append(
|
||||
(
|
||||
# the typings of ContextVar do not have the generic arg set as covariant
|
||||
# so we have to cast it
|
||||
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
|
||||
inheritable,
|
||||
handle_class,
|
||||
env_var,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_configure_hook(run_collector_var, False)
|
||||
|
||||
|
||||
def _configure(
|
||||
callback_manager_cls: Type[T],
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
@ -1962,6 +1740,14 @@ def _configure(
|
||||
Returns:
|
||||
T: The configured callback manager.
|
||||
"""
|
||||
from langchain_core.tracers.context import (
|
||||
_configure_hooks,
|
||||
_get_tracer_project,
|
||||
_tracing_v2_is_enabled,
|
||||
tracing_callback_var,
|
||||
tracing_v2_callback_var,
|
||||
)
|
||||
|
||||
run_tree = get_run_tree_context()
|
||||
parent_run_id = None if run_tree is None else getattr(run_tree, "id")
|
||||
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
|
||||
@ -2009,6 +1795,10 @@ def _configure(
|
||||
tracer_project = _get_tracer_project()
|
||||
debug = _get_debug()
|
||||
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
if verbose and not any(
|
||||
isinstance(handler, StdOutCallbackHandler)
|
||||
for handler in callback_manager.handlers
|
||||
|
@ -19,7 +19,7 @@ _llm_cache: Optional["BaseCache"] = None
|
||||
def set_verbose(value: bool) -> None:
|
||||
"""Set a new value for the `verbose` global setting."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
@ -47,7 +47,7 @@ def set_verbose(value: bool) -> None:
|
||||
def get_verbose() -> bool:
|
||||
"""Get the value of the `verbose` global setting."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
@ -80,7 +80,7 @@ def get_verbose() -> bool:
|
||||
def set_debug(value: bool) -> None:
|
||||
"""Set a new value for the `debug` global setting."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
@ -106,7 +106,7 @@ def set_debug(value: bool) -> None:
|
||||
def get_debug() -> bool:
|
||||
"""Get the value of the `debug` global setting."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
@ -137,7 +137,7 @@ def get_debug() -> bool:
|
||||
def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
||||
"""Set a new LLM cache, overwriting the previous value, if any."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
@ -165,7 +165,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
||||
def get_llm_cache() -> "BaseCache":
|
||||
"""Get the value of the `llm_cache` global setting."""
|
||||
try:
|
||||
import langchain
|
||||
import langchain # type: ignore[import]
|
||||
|
||||
# We're about to run some deprecated code, don't report warnings from it.
|
||||
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
|
||||
|
@ -1,4 +1,8 @@
|
||||
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
|
||||
from langchain_core.language_models.base import (
|
||||
BaseLanguageModel,
|
||||
LanguageModelInput,
|
||||
get_tokenizer,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.language_models.llms import LLM, BaseLLM
|
||||
|
||||
@ -9,4 +13,5 @@ __all__ = [
|
||||
"BaseLLM",
|
||||
"LLM",
|
||||
"LanguageModelInput",
|
||||
"get_tokenizer",
|
||||
]
|
||||
|
@ -17,7 +17,7 @@ from typing_extensions import TypeAlias
|
||||
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.prompts import PromptValue
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
@lru_cache(maxsize=None) # Cache the tokenizer
|
||||
def get_tokenizer() -> Any:
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
from transformers import GPT2TokenizerFast # type: ignore[import]
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import transformers python package. "
|
||||
@ -74,8 +74,10 @@ class BaseLanguageModel(
|
||||
@property
|
||||
def InputType(self) -> TypeAlias:
|
||||
"""Get the input type for this runnable."""
|
||||
from langchain_core.prompts.chat import ChatPromptValueConcrete
|
||||
from langchain_core.prompts.string import StringPromptValue
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
StringPromptValue,
|
||||
)
|
||||
|
||||
# This is a version of LanguageModelInput which replaces the abstract
|
||||
# base class BaseMessage with a union of its subclasses, which makes
|
||||
|
@ -39,7 +39,7 @@ from langchain_core.outputs import (
|
||||
LLMResult,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
@ -50,7 +50,7 @@ from langchain_core.language_models.base import BaseLanguageModel, LanguageModel
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||
from langchain_core.runnables import RunnableConfig, get_config_list
|
||||
|
||||
|
@ -9,7 +9,7 @@ from langchain_core.output_parsers.list import (
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.str import StrOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.transform import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseTransformOutputParser,
|
||||
|
@ -17,11 +17,8 @@ from typing import (
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
)
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
|
||||
T = TypeVar("T")
|
||||
|
76
libs/core/langchain_core/prompt_values.py
Normal file
76
libs/core/langchain_core/prompt_values.py
Normal file
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
class StringPromptValue(PromptValue):
|
||||
"""String prompt value."""
|
||||
|
||||
text: str
|
||||
"""Prompt text."""
|
||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
"""Chat prompt value.
|
||||
|
||||
A type of a prompt value that is built from messages.
|
||||
"""
|
||||
|
||||
messages: Sequence[BaseMessage]
|
||||
"""List of messages."""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
return list(self.messages)
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
For use in external schemas."""
|
||||
|
||||
messages: Sequence[AnyMessage]
|
||||
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
@ -23,9 +23,6 @@ from multiple components. Prompt classes and functions make constructing
|
||||
AIMessagePromptTemplate
|
||||
SystemMessagePromptTemplate
|
||||
|
||||
PromptValue --> StringPromptValue
|
||||
ChatPromptValue
|
||||
|
||||
""" # noqa: E501
|
||||
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.chat import (
|
||||
@ -33,8 +30,6 @@ from langchain_core.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
ChatPromptValueConcrete,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
@ -47,7 +42,13 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemp
|
||||
from langchain_core.prompts.loading import load_prompt
|
||||
from langchain_core.prompts.pipeline import PipelinePromptTemplate
|
||||
from langchain_core.prompts.prompt import Prompt, PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, StringPromptValue
|
||||
from langchain_core.prompts.string import (
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
jinja2_formatter,
|
||||
validate_jinja2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AIMessagePromptTemplate",
|
||||
@ -55,8 +56,6 @@ __all__ = [
|
||||
"BasePromptTemplate",
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"ChatPromptValue",
|
||||
"ChatPromptValueConcrete",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
@ -65,12 +64,12 @@ __all__ = [
|
||||
"PipelinePromptTemplate",
|
||||
"Prompt",
|
||||
"PromptTemplate",
|
||||
"PromptValue",
|
||||
"StringPromptValue",
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"format_document",
|
||||
"check_valid_template",
|
||||
"get_template_variables",
|
||||
"jinja2_formatter",
|
||||
"validate_jinja2",
|
||||
]
|
||||
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
|
@ -8,8 +8,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
|
||||
import yaml
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
|
||||
@ -40,8 +40,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
from langchain_core.prompts.chat import ChatPromptValueConcrete
|
||||
from langchain_core.prompts.string import StringPromptValue
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
StringPromptValue,
|
||||
)
|
||||
|
||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||
|
||||
|
@ -8,7 +8,6 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
@ -27,12 +26,11 @@ from langchain_core.messages import (
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
|
||||
|
||||
@ -277,33 +275,6 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
"""Chat prompt value.
|
||||
|
||||
A type of a prompt value that is built from messages.
|
||||
"""
|
||||
|
||||
messages: Sequence[BaseMessage]
|
||||
"""List of messages."""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
return list(self.messages)
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
For use in external schemas."""
|
||||
|
||||
messages: Sequence[AnyMessage]
|
||||
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
||||
|
||||
|
||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import Callable, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
|
@ -4,11 +4,10 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Literal, Set
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.value import PromptValue
|
||||
from langchain_core.utils.formatting import formatter
|
||||
|
||||
|
||||
@ -149,22 +148,6 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
return sorted(input_variables)
|
||||
|
||||
|
||||
class StringPromptValue(PromptValue):
|
||||
"""String prompt value."""
|
||||
|
||||
text: str
|
||||
"""Prompt text."""
|
||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""String prompt that exposes the format method, returning a prompt."""
|
||||
|
||||
|
@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
@ -9,7 +9,7 @@ from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
|
226
libs/core/langchain_core/tracers/context.py
Normal file
226
libs/core/langchain_core/tracers/context.py
Normal file
@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import utils as ls_utils
|
||||
from langsmith.run_helpers import get_run_tree_context
|
||||
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
|
||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||
from langchain_core.tracers.schemas import TracerSessionV1
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langsmith import Client as LangSmithClient
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
||||
"tracing_callback_v2", default=None
|
||||
)
|
||||
run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVar( # noqa: E501
|
||||
"run_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSessionV1, None, None]:
|
||||
"""Get the Deprecated LangChainTracer in a context manager.
|
||||
|
||||
Args:
|
||||
session_name (str, optional): The name of the session.
|
||||
Defaults to "default".
|
||||
|
||||
Returns:
|
||||
TracerSessionV1: The LangChainTracer session.
|
||||
|
||||
Example:
|
||||
>>> with tracing_enabled() as session:
|
||||
... # Use the LangChainTracer session
|
||||
"""
|
||||
cb = LangChainTracerV1()
|
||||
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||
try:
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
finally:
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_v2_enabled(
|
||||
project_name: Optional[str] = None,
|
||||
*,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
client: Optional[LangSmithClient] = None,
|
||||
) -> Generator[LangChainTracer, None, None]:
|
||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||
|
||||
Args:
|
||||
project_name (str, optional): The name of the project.
|
||||
Defaults to "default".
|
||||
example_id (str or UUID, optional): The ID of the example.
|
||||
Defaults to None.
|
||||
tags (List[str], optional): The tags to add to the run.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Example:
|
||||
>>> with tracing_v2_enabled():
|
||||
... # LangChain code will automatically be traced
|
||||
|
||||
You can use this to fetch the LangSmith run URL:
|
||||
|
||||
>>> with tracing_v2_enabled() as cb:
|
||||
... chain.invoke("foo")
|
||||
... run_url = cb.get_run_url()
|
||||
"""
|
||||
if isinstance(example_id, str):
|
||||
example_id = UUID(example_id)
|
||||
cb = LangChainTracer(
|
||||
example_id=example_id,
|
||||
project_name=project_name,
|
||||
tags=tags,
|
||||
client=client,
|
||||
)
|
||||
try:
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield cb
|
||||
finally:
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
|
||||
"""Collect all run traces in context.
|
||||
|
||||
Returns:
|
||||
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
||||
|
||||
Example:
|
||||
>>> with collect_runs() as runs_cb:
|
||||
chain.invoke("foo")
|
||||
run_id = runs_cb.traced_runs[0].id
|
||||
"""
|
||||
cb = RunCollectorCallbackHandler()
|
||||
run_collector_var.set(cb)
|
||||
yield cb
|
||||
run_collector_var.set(None)
|
||||
|
||||
|
||||
def _get_trace_callbacks(
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
|
||||
) -> Callbacks:
|
||||
if _tracing_v2_is_enabled():
|
||||
project_name_ = project_name or _get_tracer_project()
|
||||
tracer = tracing_v2_callback_var.get() or LangChainTracer(
|
||||
project_name=project_name_,
|
||||
example_id=example_id,
|
||||
)
|
||||
if callback_manager is None:
|
||||
from langchain_core.callbacks.base import Callbacks
|
||||
|
||||
cb = cast(Callbacks, [tracer])
|
||||
else:
|
||||
if not any(
|
||||
isinstance(handler, LangChainTracer)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(tracer, True)
|
||||
# If it already has a LangChainTracer, we don't need to add another one.
|
||||
# this would likely mess up the trace hierarchy.
|
||||
cb = callback_manager
|
||||
else:
|
||||
cb = None
|
||||
return cb
|
||||
|
||||
|
||||
def _tracing_v2_is_enabled() -> bool:
|
||||
return (
|
||||
env_var_is_set("LANGCHAIN_TRACING_V2")
|
||||
or tracing_v2_callback_var.get() is not None
|
||||
or get_run_tree_context() is not None
|
||||
)
|
||||
|
||||
|
||||
def _get_tracer_project() -> str:
|
||||
run_tree = get_run_tree_context()
|
||||
return getattr(
|
||||
run_tree,
|
||||
"session_name",
|
||||
getattr(
|
||||
# Note, if people are trying to nest @traceable functions and the
|
||||
# tracing_v2_enabled context manager, this will likely mess up the
|
||||
# tree structure.
|
||||
tracing_v2_callback_var.get(),
|
||||
"project",
|
||||
# Have to set this to a string even though it always will return
|
||||
# a string because `get_tracer_project` technically can return
|
||||
# None, but only when a specific argument is supplied.
|
||||
# Therefore, this just tricks the mypy type checker
|
||||
str(ls_utils.get_tracer_project()),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_configure_hooks: List[
|
||||
Tuple[
|
||||
ContextVar[Optional[BaseCallbackHandler]],
|
||||
bool,
|
||||
Optional[Type[BaseCallbackHandler]],
|
||||
Optional[str],
|
||||
]
|
||||
] = []
|
||||
|
||||
|
||||
def register_configure_hook(
|
||||
context_var: ContextVar[Optional[Any]],
|
||||
inheritable: bool,
|
||||
handle_class: Optional[Type[BaseCallbackHandler]] = None,
|
||||
env_var: Optional[str] = None,
|
||||
) -> None:
|
||||
if env_var is not None and handle_class is None:
|
||||
raise ValueError(
|
||||
"If env_var is set, handle_class must also be set to a non-None value."
|
||||
)
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
|
||||
_configure_hooks.append(
|
||||
(
|
||||
# the typings of ContextVar do not have the generic arg set as covariant
|
||||
# so we have to cast it
|
||||
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
|
||||
inheritable,
|
||||
handle_class,
|
||||
env_var,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_configure_hook(run_collector_var, False)
|
@ -11,9 +11,9 @@ from uuid import UUID
|
||||
import langsmith
|
||||
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
|
||||
|
||||
from langchain_core.callbacks import manager
|
||||
from langchain_core.tracers import langchain as langchain_tracer
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.context import tracing_v2_enabled
|
||||
from langchain_core.tracers.langchain import _get_executor
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
@ -115,7 +115,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
if self.project_name is None:
|
||||
eval_result = self.client.evaluate_run(run, evaluator)
|
||||
eval_results = [eval_result]
|
||||
with manager.tracing_v2_enabled(
|
||||
with tracing_v2_enabled(
|
||||
project_name=self.project_name, tags=["eval"], client=self.client
|
||||
) as cb:
|
||||
reference_example = (
|
||||
|
@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
import jsonpatch
|
||||
import jsonpatch # type: ignore[import]
|
||||
from anyio import create_memory_object_stream
|
||||
|
||||
from langchain_core.load import load
|
||||
|
20
libs/core/langchain_core/utils/env.py
Normal file
20
libs/core/langchain_core/utils/env.py
Normal file
@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def env_var_is_set(env_var: str) -> bool:
|
||||
"""Check if an environment variable is set.
|
||||
|
||||
Args:
|
||||
env_var (str): The name of the environment variable.
|
||||
|
||||
Returns:
|
||||
bool: True if the environment variable is set, False otherwise.
|
||||
"""
|
||||
return env_var in os.environ and os.environ[env_var] not in (
|
||||
"",
|
||||
"0",
|
||||
"false",
|
||||
"False",
|
||||
)
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-core"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = "Building applications with LLMs through composability"
|
||||
authors = []
|
||||
license = "MIT"
|
||||
@ -51,7 +51,6 @@ select = [
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = "True"
|
||||
disallow_untyped_defs = "True"
|
||||
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
||||
|
||||
|
0
libs/core/tests/unit_tests/callbacks/__init__.py
Normal file
0
libs/core/tests/unit_tests/callbacks/__init__.py
Normal file
37
libs/core/tests/unit_tests/callbacks/test_imports.py
Normal file
37
libs/core/tests/unit_tests/callbacks/test_imports.py
Normal file
@ -0,0 +1,37 @@
|
||||
from langchain_core.callbacks import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"RetrieverManagerMixin",
|
||||
"LLMManagerMixin",
|
||||
"ChainManagerMixin",
|
||||
"ToolManagerMixin",
|
||||
"Callbacks",
|
||||
"CallbackManagerMixin",
|
||||
"RunManagerMixin",
|
||||
"BaseCallbackHandler",
|
||||
"AsyncCallbackHandler",
|
||||
"BaseCallbackManager",
|
||||
"BaseRunManager",
|
||||
"RunManager",
|
||||
"ParentRunManager",
|
||||
"AsyncRunManager",
|
||||
"AsyncParentRunManager",
|
||||
"CallbackManagerForLLMRun",
|
||||
"AsyncCallbackManagerForLLMRun",
|
||||
"CallbackManagerForChainRun",
|
||||
"AsyncCallbackManagerForChainRun",
|
||||
"CallbackManagerForToolRun",
|
||||
"AsyncCallbackManagerForToolRun",
|
||||
"CallbackManagerForRetrieverRun",
|
||||
"AsyncCallbackManagerForRetrieverRun",
|
||||
"CallbackManager",
|
||||
"CallbackManagerForChainGroup",
|
||||
"AsyncCallbackManager",
|
||||
"AsyncCallbackManagerForChainGroup",
|
||||
"StdOutCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
13
libs/core/tests/unit_tests/example_selectors/test_imports.py
Normal file
13
libs/core/tests/unit_tests/example_selectors/test_imports.py
Normal file
@ -0,0 +1,13 @@
|
||||
from langchain_core.example_selectors import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseExampleSelector",
|
||||
"LengthBasedExampleSelector",
|
||||
"MaxMarginalRelevanceExampleSelector",
|
||||
"SemanticSimilarityExampleSelector",
|
||||
"sorted_values",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
15
libs/core/tests/unit_tests/language_models/test_imports.py
Normal file
15
libs/core/tests/unit_tests/language_models/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
||||
from langchain_core.language_models import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseLanguageModel",
|
||||
"BaseChatModel",
|
||||
"SimpleChatModel",
|
||||
"BaseLLM",
|
||||
"LLM",
|
||||
"LanguageModelInput",
|
||||
"get_tokenizer",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/load/__init__.py
Normal file
0
libs/core/tests/unit_tests/load/__init__.py
Normal file
7
libs/core/tests/unit_tests/load/test_imports.py
Normal file
7
libs/core/tests/unit_tests/load/test_imports.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain_core.load import __all__
|
||||
|
||||
EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/messages/__init__.py
Normal file
0
libs/core/tests/unit_tests/messages/__init__.py
Normal file
28
libs/core/tests/unit_tests/messages/test_imports.py
Normal file
28
libs/core/tests/unit_tests/messages/test_imports.py
Normal file
@ -0,0 +1,28 @@
|
||||
from langchain_core.messages import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"AnyMessage",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"get_buffer_string",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
"merge_content",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
18
libs/core/tests/unit_tests/output_parsers/test_imports.py
Normal file
18
libs/core/tests/unit_tests/output_parsers/test_imports.py
Normal file
@ -0,0 +1,18 @@
|
||||
from langchain_core.output_parsers import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseLLMOutputParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseOutputParser",
|
||||
"ListOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"StrOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
0
libs/core/tests/unit_tests/outputs/__init__.py
Normal file
0
libs/core/tests/unit_tests/outputs/__init__.py
Normal file
15
libs/core/tests/unit_tests/outputs/test_imports.py
Normal file
15
libs/core/tests/unit_tests/outputs/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
||||
from langchain_core.outputs import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatGeneration",
|
||||
"ChatGenerationChunk",
|
||||
"ChatResult",
|
||||
"Generation",
|
||||
"GenerationChunk",
|
||||
"LLMResult",
|
||||
"RunInfo",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -10,6 +10,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
@ -17,7 +18,6 @@ from langchain_core.prompts.chat import (
|
||||
ChatMessage,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
_convert_to_message,
|
||||
|
@ -6,14 +6,10 @@ EXPECTED_ALL = [
|
||||
"BasePromptTemplate",
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"ChatPromptValueConcrete",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"format_document",
|
||||
"ChatPromptValue",
|
||||
"PromptValue",
|
||||
"StringPromptValue",
|
||||
"HumanMessagePromptTemplate",
|
||||
"MessagesPlaceholder",
|
||||
"PipelinePromptTemplate",
|
||||
@ -22,6 +18,10 @@ EXPECTED_ALL = [
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"check_valid_template",
|
||||
"get_template_variables",
|
||||
"jinja2_formatter",
|
||||
"validate_jinja2",
|
||||
]
|
||||
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
28
libs/core/tests/unit_tests/runnables/test_imports.py
Normal file
28
libs/core/tests/unit_tests/runnables/test_imports.py
Normal file
@ -0,0 +1,28 @@
|
||||
from langchain_core.runnables import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableSerializable",
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
"RunnableGenerator",
|
||||
"RunnableLambda",
|
||||
"RunnableMap",
|
||||
"RunnableParallel",
|
||||
"RunnablePassthrough",
|
||||
"RunnableSequence",
|
||||
"RunnableWithFallbacks",
|
||||
"get_config_list",
|
||||
"add",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -23,7 +23,6 @@ from typing_extensions import TypedDict
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
atrace_as_chain_group,
|
||||
collect_runs,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
@ -39,13 +38,12 @@ from langchain_core.output_parsers import (
|
||||
CommaSeparatedListOutputParser,
|
||||
StrOutputParser,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
StringPromptValue,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
@ -75,6 +73,7 @@ from langchain_core.tracers import (
|
||||
RunLog,
|
||||
RunLogPatch,
|
||||
)
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.chat_model import FakeListChatModel
|
||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||
|
||||
|
15
libs/core/tests/unit_tests/tracers/test_imports.py
Normal file
15
libs/core/tests/unit_tests/tracers/test_imports.py
Normal file
@ -0,0 +1,15 @@
|
||||
from langchain_core.tracers import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseTracer",
|
||||
"EvaluatorCallbackHandler",
|
||||
"LangChainTracer",
|
||||
"ConsoleCallbackHandler",
|
||||
"Run",
|
||||
"RunLog",
|
||||
"RunLogPatch",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -2,12 +2,12 @@
|
||||
|
||||
import uuid
|
||||
|
||||
from langchain.callbacks import collect_runs
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
def test_collect_runs() -> None:
|
||||
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
with collect_runs() as cb:
|
||||
llm.predict("hi")
|
||||
assert cb.traced_runs
|
@ -2,9 +2,8 @@ import uuid
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
from langchain.prompts.chat import ChatPromptValue
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
|
||||
|
||||
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
||||
from langchain_experimental.comprehend_moderation.prompt_safety import (
|
||||
|
@ -29,14 +29,16 @@ from langchain_core.callbacks.manager import (
|
||||
RunManager,
|
||||
ahandle_event,
|
||||
atrace_as_chain_group,
|
||||
collect_runs,
|
||||
env_var_is_set,
|
||||
handle_event,
|
||||
register_configure_hook,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
from langchain_core.tracers.context import (
|
||||
collect_runs,
|
||||
register_configure_hook,
|
||||
tracing_enabled,
|
||||
tracing_v2_enabled,
|
||||
)
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||
@ -122,6 +124,6 @@ __all__ = [
|
||||
"trace_as_chain_group",
|
||||
"handle_event",
|
||||
"ahandle_event",
|
||||
"env_var_is_set",
|
||||
"Callbacks",
|
||||
"env_var_is_set",
|
||||
]
|
||||
|
@ -12,8 +12,8 @@ from langchain_core.load.dump import dumpd
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptValue
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
|
@ -9,7 +9,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.prompts import PromptValue
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
|
@ -13,7 +13,7 @@ from typing import (
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.prompts import PromptValue
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
|
@ -5,8 +5,8 @@ from typing import Any, TypeVar
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptValue
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
|
||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||
{prompt}
|
||||
|
@ -1,7 +1,6 @@
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.base import (
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
StringPromptTemplate,
|
||||
StringPromptValue,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
jinja2_formatter,
|
||||
@ -13,7 +12,6 @@ __all__ = [
|
||||
"validate_jinja2",
|
||||
"check_valid_template",
|
||||
"get_template_variables",
|
||||
"StringPromptValue",
|
||||
"StringPromptTemplate",
|
||||
"BasePromptTemplate",
|
||||
]
|
||||
|
@ -5,8 +5,6 @@ from langchain_core.prompts.chat import (
|
||||
BaseStringMessagePromptTemplate,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
ChatPromptValueConcrete,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
@ -20,8 +18,6 @@ __all__ = [
|
||||
"HumanMessagePromptTemplate",
|
||||
"AIMessagePromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"ChatPromptValue",
|
||||
"ChatPromptValueConcrete",
|
||||
"BaseChatPromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
]
|
||||
|
@ -31,12 +31,16 @@ from langchain_core.outputs import (
|
||||
LLMResult,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptValue, format_document
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
RUN_KEY = "__run"
|
||||
|
||||
# Backwards compatibility.
|
||||
Memory = BaseMemory
|
||||
_message_to_dict = message_to_dict
|
||||
|
||||
__all__ = [
|
||||
"BaseCache",
|
||||
@ -56,6 +60,7 @@ __all__ = [
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
"_message_to_dict",
|
||||
"_message_from_dict",
|
||||
"get_buffer_string",
|
||||
"RunInfo",
|
||||
|
@ -16,14 +16,16 @@ from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForToolRun,
|
||||
ParentRunManager,
|
||||
RunManager,
|
||||
collect_runs,
|
||||
env_var_is_set,
|
||||
handle_event,
|
||||
register_configure_hook,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
from langchain_core.tracers.context import (
|
||||
collect_runs,
|
||||
register_configure_hook,
|
||||
tracing_enabled,
|
||||
tracing_v2_enabled,
|
||||
)
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
__all__ = [
|
||||
"tracing_enabled",
|
||||
@ -48,6 +50,6 @@ __all__ = [
|
||||
"CallbackManagerForChainGroup",
|
||||
"AsyncCallbackManager",
|
||||
"AsyncCallbackManagerForChainGroup",
|
||||
"env_var_is_set",
|
||||
"register_configure_hook",
|
||||
"env_var_is_set",
|
||||
]
|
||||
|
@ -1,3 +1,4 @@
|
||||
from langchain_core.documents import BaseDocumentTransformer, Document
|
||||
from langchain_core.document_transformers import BaseDocumentTransformer
|
||||
from langchain_core.documents import Document
|
||||
|
||||
__all__ = ["Document", "BaseDocumentTransformer"]
|
||||
|
@ -1,3 +1,4 @@
|
||||
from langchain_core.language_models import BaseLanguageModel, get_tokenizer
|
||||
from langchain_core.language_models.base import _get_token_ids_default_method
|
||||
|
||||
__all__ = ["get_tokenizer", "BaseLanguageModel"]
|
||||
__all__ = ["get_tokenizer", "BaseLanguageModel", "_get_token_ids_default_method"]
|
||||
|
@ -13,12 +13,17 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
_message_from_dict,
|
||||
get_buffer_string,
|
||||
merge_content,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
|
||||
# Backwards compatibility.
|
||||
_message_to_dict = message_to_dict
|
||||
|
||||
__all__ = [
|
||||
"get_buffer_string",
|
||||
"BaseMessage",
|
||||
@ -38,4 +43,7 @@ __all__ = [
|
||||
"ChatMessageChunk",
|
||||
"messages_to_dict",
|
||||
"messages_from_dict",
|
||||
"_message_to_dict",
|
||||
"_message_from_dict",
|
||||
"message_to_dict",
|
||||
]
|
||||
|
@ -1,19 +1,23 @@
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseGenerationOutputParser,
|
||||
BaseLLMOutputParser,
|
||||
BaseOutputParser,
|
||||
BaseTransformOutputParser,
|
||||
OutputParserException,
|
||||
StrOutputParser,
|
||||
)
|
||||
|
||||
# Backwards compatibility.
|
||||
NoOpOutputParser = StrOutputParser
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMOutputParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
"NoOpOutputParser",
|
||||
"StrOutputParser",
|
||||
"OutputParserException",
|
||||
]
|
||||
|
@ -1,3 +1,3 @@
|
||||
from langchain_core.prompts import PromptValue
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
__all__ = ["PromptValue"]
|
||||
|
@ -12,6 +12,9 @@ from langchain_core.runnables.base import (
|
||||
coerce_to_runnable,
|
||||
)
|
||||
|
||||
# Backwards compatibility.
|
||||
RunnableMap = RunnableParallel
|
||||
|
||||
__all__ = [
|
||||
"Runnable",
|
||||
"RunnableSerializable",
|
||||
@ -23,5 +26,6 @@ __all__ = [
|
||||
"RunnableEach",
|
||||
"RunnableBindingBase",
|
||||
"RunnableBinding",
|
||||
"RunnableMap",
|
||||
"coerce_to_runnable",
|
||||
]
|
||||
|
@ -1,6 +1,7 @@
|
||||
from langchain_core.runnables.config import (
|
||||
EmptyDict,
|
||||
RunnableConfig,
|
||||
acall_func_with_variable_args,
|
||||
call_func_with_variable_args,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
@ -18,6 +19,7 @@ __all__ = [
|
||||
"get_config_list",
|
||||
"patch_config",
|
||||
"merge_configs",
|
||||
"acall_func_with_variable_args",
|
||||
"call_func_with_variable_args",
|
||||
"get_callback_manager_for_config",
|
||||
"get_async_callback_manager_for_config",
|
||||
|
@ -1,7 +1,8 @@
|
||||
from langchain_core.runnables.passthrough import (
|
||||
RunnableAssign,
|
||||
RunnablePassthrough,
|
||||
aidentity,
|
||||
identity,
|
||||
)
|
||||
|
||||
__all__ = ["identity", "RunnablePassthrough", "RunnableAssign"]
|
||||
__all__ = ["aidentity", "identity", "RunnablePassthrough", "RunnableAssign"]
|
||||
|
@ -8,9 +8,12 @@ from langchain_core.runnables.utils import (
|
||||
IsFunctionArgDict,
|
||||
IsLocalDict,
|
||||
SupportsAdd,
|
||||
aadd,
|
||||
accepts_config,
|
||||
accepts_run_manager,
|
||||
add,
|
||||
gated_coro,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_lambda_source,
|
||||
get_unique_config_specs,
|
||||
@ -34,4 +37,7 @@ __all__ = [
|
||||
"ConfigurableFieldMultiOption",
|
||||
"ConfigurableFieldSpec",
|
||||
"get_unique_config_specs",
|
||||
"aadd",
|
||||
"gated_coro",
|
||||
"gather_with_concurrency",
|
||||
]
|
||||
|
13
libs/langchain/poetry.lock
generated
13
libs/langchain/poetry.lock
generated
@ -4128,13 +4128,13 @@ tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.0.1"
|
||||
version = "0.0.2"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langchain_core-0.0.1-py3-none-any.whl", hash = "sha256:cad923dd3bc39cd9fe24b9d6a9799c97719aeaafc9b19509fe1347109fcb65b3"},
|
||||
{file = "langchain_core-0.0.1.tar.gz", hash = "sha256:488b72223e14849bf9588ed677a999b282904d1d5e1f81d12767ee1024220724"},
|
||||
{file = "langchain_core-0.0.2-py3-none-any.whl", hash = "sha256:dd448c7887c24105761a0763bee5ec5b072de905b4fbd83d693e7a181fd63208"},
|
||||
{file = "langchain_core-0.0.2.tar.gz", hash = "sha256:772600dfe3e707adb9055ed96797763b0de32b394eac8325bf609a7e5929dda2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -6515,6 +6515,8 @@ files = [
|
||||
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
|
||||
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
|
||||
@ -6557,6 +6559,7 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
||||
@ -6565,6 +6568,8 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
||||
@ -11080,4 +11085,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "6781b828c8bc2b08ece3cbf82be799dc0e361b9f6f9c204ddefcfee70ab0db8b"
|
||||
content-hash = "3660cb9106129dcf1a2fe157c1bdfe567923fc9ed408672bcbcb70f4345f1285"
|
||||
|
@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.0.1"
|
||||
langchain-core = "^0.0.2"
|
||||
pydantic = ">=1,<3"
|
||||
SQLAlchemy = ">=1.4,<3"
|
||||
requests = "^2"
|
||||
|
@ -2,7 +2,6 @@ import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
@ -16,6 +15,8 @@ def _chat_message_history(
|
||||
drop: bool = True,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> CassandraChatMessageHistory:
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
keyspace = "cmh_test_keyspace"
|
||||
table_name = "cmh_test_table"
|
||||
# get db connection
|
||||
|
@ -2,8 +2,6 @@
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores import Cassandra
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
@ -19,6 +17,8 @@ def _vectorstore_from_texts(
|
||||
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
|
||||
drop: bool = True,
|
||||
) -> Cassandra:
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
keyspace = "vector_test_keyspace"
|
||||
table_name = "vector_test_table"
|
||||
# get db connection
|
||||
@ -154,12 +154,3 @@ def test_cassandra_delete() -> None:
|
||||
time.sleep(0.3)
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_cassandra()
|
||||
# test_cassandra_with_score()
|
||||
# test_cassandra_max_marginal_relevance_search()
|
||||
# test_cassandra_add_extra()
|
||||
# test_cassandra_no_drop()
|
||||
# test_cassandra_delete()
|
||||
|
20
libs/langchain/tests/unit_tests/schema/runnable/test_base.py
Normal file
20
libs/langchain/tests/unit_tests/schema/runnable/test_base.py
Normal file
@ -0,0 +1,20 @@
|
||||
from langchain.schema.runnable.base import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"Runnable",
|
||||
"RunnableBinding",
|
||||
"RunnableBindingBase",
|
||||
"RunnableEach",
|
||||
"RunnableEachBase",
|
||||
"RunnableGenerator",
|
||||
"RunnableLambda",
|
||||
"RunnableMap",
|
||||
"RunnableParallel",
|
||||
"RunnableSequence",
|
||||
"RunnableSerializable",
|
||||
"coerce_to_runnable",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.branch import __all__
|
||||
|
||||
EXPECTED_ALL = ["RunnableBranch"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,19 @@
|
||||
from langchain.schema.runnable.config import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"EmptyDict",
|
||||
"RunnableConfig",
|
||||
"acall_func_with_variable_args",
|
||||
"call_func_with_variable_args",
|
||||
"ensure_config",
|
||||
"get_async_callback_manager_for_config",
|
||||
"get_callback_manager_for_config",
|
||||
"get_config_list",
|
||||
"get_executor_for_config",
|
||||
"merge_configs",
|
||||
"patch_config",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,13 @@
|
||||
from langchain.schema.runnable.configurable import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"DynamicRunnable",
|
||||
"RunnableConfigurableAlternatives",
|
||||
"RunnableConfigurableFields",
|
||||
"StrEnum",
|
||||
"make_options_spec",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.fallbacks import __all__
|
||||
|
||||
EXPECTED_ALL = ["RunnableWithFallbacks"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.history import __all__
|
||||
|
||||
EXPECTED_ALL = ["RunnableWithMessageHistory"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,26 @@
|
||||
from langchain.schema.runnable import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableSerializable",
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
"RunnableGenerator",
|
||||
"RunnableLambda",
|
||||
"RunnableMap",
|
||||
"RunnableParallel",
|
||||
"RunnablePassthrough",
|
||||
"RunnableSequence",
|
||||
"RunnableWithFallbacks",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.passthrough import __all__
|
||||
|
||||
EXPECTED_ALL = ["RunnableAssign", "RunnablePassthrough", "aidentity", "identity"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.retry import __all__
|
||||
|
||||
EXPECTED_ALL = ["RunnableRetry"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.runnable.router import __all__
|
||||
|
||||
EXPECTED_ALL = ["RouterInput", "RouterRunnable"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,27 @@
|
||||
from langchain.schema.runnable.utils import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AddableDict",
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldSpec",
|
||||
"GetLambdaSource",
|
||||
"IsFunctionArgDict",
|
||||
"IsLocalDict",
|
||||
"SupportsAdd",
|
||||
"aadd",
|
||||
"accepts_config",
|
||||
"accepts_run_manager",
|
||||
"add",
|
||||
"gated_coro",
|
||||
"gather_with_concurrency",
|
||||
"get_function_first_arg_dict_keys",
|
||||
"get_lambda_source",
|
||||
"get_unique_config_specs",
|
||||
"indent_lines_after_first",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_agent.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_agent.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.agent import __all__
|
||||
|
||||
EXPECTED_ALL = ["AgentAction", "AgentActionMessageLog", "AgentFinish"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_cache.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_cache.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.cache import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseCache"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_chat.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_chat.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.chat import __all__
|
||||
|
||||
EXPECTED_ALL = ["ChatSession"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.chat_history import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseChatMessageHistory"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_document.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_document.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.document import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseDocumentTransformer", "Document"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.embeddings import __all__
|
||||
|
||||
EXPECTED_ALL = ["Embeddings"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.exceptions import __all__
|
||||
|
||||
EXPECTED_ALL = ["LangChainException"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -18,6 +18,7 @@ EXPECTED_ALL = [
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
"_message_to_dict",
|
||||
"_message_from_dict",
|
||||
"get_buffer_string",
|
||||
"RunInfo",
|
||||
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.language_model import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseLanguageModel", "_get_token_ids_default_method", "get_tokenizer"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_memory.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_memory.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.memory import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseMemory"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -1,101 +1,29 @@
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain.schema.messages import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"_message_from_dict",
|
||||
"_message_to_dict",
|
||||
"message_to_dict",
|
||||
"get_buffer_string",
|
||||
"merge_content",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
]
|
||||
|
||||
|
||||
def test_message_chunks() -> None:
|
||||
assert AIMessageChunk(content="I am") + AIMessageChunk(
|
||||
content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
content="I am indeed."
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk"
|
||||
|
||||
assert (
|
||||
AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
|
||||
== AIMessageChunk(content="I am indeed.")
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
|
||||
|
||||
assert (
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": "bar"})
|
||||
+ AIMessageChunk(content="", additional_kwargs={"baz": "foo"})
|
||||
== AIMessageChunk(content="", additional_kwargs={"foo": "bar", "baz": "foo"})
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||
|
||||
assert (
|
||||
AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "web_search"}}
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": ' "query": "turtles"\n}'}
|
||||
},
|
||||
)
|
||||
== AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "web_search",
|
||||
"arguments": '{\n "query": "turtles"\n}',
|
||||
}
|
||||
},
|
||||
)
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||
|
||||
|
||||
def test_chat_message_chunks() -> None:
|
||||
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="User", content=" indeed."
|
||||
) == ChatMessageChunk(
|
||||
role="User", content="I am indeed."
|
||||
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="Assistant", content=" indeed."
|
||||
)
|
||||
|
||||
assert (
|
||||
ChatMessageChunk(role="User", content="I am")
|
||||
+ AIMessageChunk(content=" indeed.")
|
||||
== ChatMessageChunk(role="User", content="I am indeed.")
|
||||
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501
|
||||
|
||||
assert AIMessageChunk(content="I am") + ChatMessageChunk(
|
||||
role="User", content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
content="I am indeed."
|
||||
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501
|
||||
|
||||
|
||||
def test_function_message_chunks() -> None:
|
||||
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="hello", content=" indeed."
|
||||
) == FunctionMessageChunk(
|
||||
name="hello", content="I am indeed."
|
||||
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="bye", content=" indeed."
|
||||
)
|
||||
|
||||
|
||||
def test_ani_message_chunks() -> None:
|
||||
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=True, content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
example=True, content="I am indeed."
|
||||
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=False, content=" indeed."
|
||||
)
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
|
@ -1,60 +1,15 @@
|
||||
from langchain_core.messages import HumanMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from langchain.schema.output import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatGeneration",
|
||||
"ChatGenerationChunk",
|
||||
"ChatResult",
|
||||
"Generation",
|
||||
"GenerationChunk",
|
||||
"LLMResult",
|
||||
"RunInfo",
|
||||
]
|
||||
|
||||
|
||||
def test_generation_chunk() -> None:
|
||||
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||
text="world!"
|
||||
) == GenerationChunk(
|
||||
text="Hello, world!"
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
|
||||
|
||||
assert (
|
||||
GenerationChunk(text="Hello, ")
|
||||
+ GenerationChunk(text="world!", generation_info={"foo": "bar"})
|
||||
== GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
assert (
|
||||
GenerationChunk(text="Hello, ")
|
||||
+ GenerationChunk(text="world!", generation_info={"foo": "bar"})
|
||||
+ GenerationChunk(text="!", generation_info={"baz": "foo"})
|
||||
== GenerationChunk(
|
||||
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
|
||||
)
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
|
||||
def test_chat_generation_chunk() -> None:
|
||||
assert ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, ")
|
||||
) + ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!")
|
||||
) == ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!")
|
||||
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
|
||||
|
||||
assert (
|
||||
ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
|
||||
+ ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||
)
|
||||
== ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!"),
|
||||
generation_info={"foo": "bar"},
|
||||
)
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
assert (
|
||||
ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
|
||||
+ ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||
)
|
||||
+ ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
|
||||
)
|
||||
== ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!!"),
|
||||
generation_info={"foo": "bar", "baz": "foo"},
|
||||
)
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
|
16
libs/langchain/tests/unit_tests/schema/test_output_parser.py
Normal file
16
libs/langchain/tests/unit_tests/schema/test_output_parser.py
Normal file
@ -0,0 +1,16 @@
|
||||
from langchain.schema.output_parser import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"NoOpOutputParser",
|
||||
"OutputParserException",
|
||||
"StrOutputParser",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_prompt.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_prompt.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.prompt import __all__
|
||||
|
||||
EXPECTED_ALL = ["PromptValue"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.prompt_template import __all__
|
||||
|
||||
EXPECTED_ALL = ["BasePromptTemplate", "format_document"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_retriever.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_retriever.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.retriever import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseRetriever"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
7
libs/langchain/tests/unit_tests/schema/test_storage.py
Normal file
7
libs/langchain/tests/unit_tests/schema/test_storage.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain.schema.storage import __all__
|
||||
|
||||
EXPECTED_ALL = ["BaseStore"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -0,0 +1,7 @@
|
||||
from langchain.schema.vectorstore import __all__
|
||||
|
||||
EXPECTED_ALL = ["VectorStore", "VectorStoreRetriever"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -21,7 +21,7 @@ from langchain_core.messages import (
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
|
||||
from langchain_core.prompts import ChatPromptValueConcrete, StringPromptValue
|
||||
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user